Skip to content

Commit

Permalink
Maintenance.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed May 9, 2023
1 parent c62ab2c commit 6cc3fcc
Show file tree
Hide file tree
Showing 135 changed files with 1,658 additions and 1,062 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
[submodule "SimpleOT"]
path = deps/SimpleOT
url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = deps/mpir
url = https://github.com/wbhart/mpir
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion
Expand Down
2 changes: 0 additions & 2 deletions BMR/RealProgramParty.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>

bool one_shot;

size_t data_sent;

public:
static RealProgramParty& s();

Expand Down
3 changes: 1 addition & 2 deletions BMR/RealProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);

MC->Check(*P);
data_sent = P->total_comm().sent;

if (online_opts.verbose)
P->total_comm().print();
Expand Down Expand Up @@ -216,7 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
delete prep;
delete garble_inputter;
delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
garble_machine.print_comm(*this->P, this->P->total_comm());
T::MAC_Check::teardown();
}

Expand Down
11 changes: 7 additions & 4 deletions BMR/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ class BaseKeyVector
#endif
};
#else
class BaseKeyVector : public vector<Key>
class BaseKeyVector : public CheckVector<Key>
{
typedef CheckVector<Key> super;

public:
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {}
void resize(int size) { vector<Key>::resize(size, Key(0)); }
BaseKeyVector(int size = 0) : super(size, Key(0)) {}
void resize(int size) { super::resize(size, Key(0)); }
};
#endif

Expand Down Expand Up @@ -296,7 +298,8 @@ class ProgramRegister : public Phase, public Register
static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); }

static void run_tapes(const vector<int>&) { throw not_implemented(); }
static void run_tapes(const vector<int>&)
{ throw runtime_error("multi-threading not implemented"); }

// most BMR phases don't need actual input
template<class T>
Expand Down
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.3.6 (May 9, 2023)

- More extensive benchmarking outputs
- Replace MPIR by GMP
- Secure reading of edaBits from files
- Semi-honest client communication
- Back-propagation for average pooling
- Parallelized convolution
- Probabilistic truncation as in ABY3
- More balanced communication in Shamir secret sharing
- Avoid unnecessary communication in Dealer protocol
- Linear solver using Cholesky decomposition
- Accept .py files for compilation
- Fixed security bug: proper accounting for random elements

## 0.3.5 (Feb 16, 2023)

- Easier-to-use machine learning interface
Expand Down
21 changes: 20 additions & 1 deletion CONFIG
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,32 @@ ARM := $(shell uname -m | grep x86; echo $$?)
OS := $(shell uname -s)
ifeq ($(MACHINE), x86_64)
ifeq ($(OS), Linux)
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
AVX_OT = 1
else
AVX_OT = 0
endif
else
AVX_OT = 0
endif
else
ARCH =
AVX_OT = 0
endif

ifeq ($(OS), Darwin)
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
endif

ifeq ($(OS), Linux)
ifeq ($(ARM), 1)
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
ARCH = -march=armv8.2-a+crypto
endif
endif
endif

USE_KOS = 0

# allow to set compiler in CONFIG.mine
Expand All @@ -66,7 +83,8 @@ endif
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5

LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
LDLIBS += $(BREW_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto

Expand All @@ -88,6 +106,7 @@ BOOST = -lboost_thread $(MY_BOOST)
endif

CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(BREW_CFLAGS)
CPPFLAGS = $(CFLAGS)
LD = $(CXX)

Expand Down
2 changes: 2 additions & 0 deletions Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb'
name = 'sbit'
class ClearBitsAF(base.RegisterArgFormat):
reg_type = 'cb'
name = 'cbit'

base.ArgFormats['sb'] = SecretBitsAF
base.ArgFormats['sbw'] = SecretBitsAF
Expand Down
33 changes: 12 additions & 21 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,19 @@ def add_edge(i, j):
d[j] = d[i]

def read(reg, n):
last_read[reg] = n
for dup in reg.duplicates:
if last_def[dup] != -1:
if last_def[dup] not in (-1, n):
add_edge(last_def[dup], n)
last_read[reg] = n

def write(reg, n):
last_def[reg] = n
for dup in reg.duplicates:
if last_read[dup] not in (-1, n):
add_edge(last_read[dup], n)
if id(dup) in [id(x) for x in block.instructions[n].get_used()] and \
last_read[dup] not in (-1, n):
add_edge(last_read[dup], n)
last_def[reg] = n

def handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind):
Expand Down Expand Up @@ -434,19 +437,19 @@ def keep_text_order(inst, n):
# if options.debug:
# col = colordict[instr.__class__.__name__]
# G.add_node(n, color=col, label=str(instr))
for reg in inputs:
for reg in outputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
write(i, n)
else:
read(reg, n)
write(reg, n)

for reg in outputs:
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
write(i, n)
read(i, n)
else:
write(reg, n)
read(reg, n)

# will be merged
if isinstance(instr, TextInputInstruction):
Expand Down Expand Up @@ -556,18 +559,6 @@ def eliminate(i):
if unused_result:
eliminate(i)
count += 1
# remove unnecessary stack instructions
# left by optimization with budget
if isinstance(inst, popint_class) and \
(not G.degree(i) or (G.degree(i) == 1 and
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
and \
inst.args[0].can_eliminate and \
len(G.pred[i]) == 1 and \
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
eliminate(list(G.pred[i])[0])
eliminate(i)
count += 2
if count > 0 and self.block.parent.program.verbose:
print('Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats)))
Expand Down
3 changes: 3 additions & 0 deletions Compiler/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def set_variant(options):
do_precomp = False
elif variant is not None:
raise CompilerError('Unknown comparison variant: %s' % variant)
if const_rounds and instructions_base.program.options.binary:
raise CompilerError(
'Comparison variant choice incompatible with binary circuits')

def ld2i(c, n):
""" Load immediate 2^n into clear GF(p) register c """
Expand Down
15 changes: 8 additions & 7 deletions Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, custom_args=None, usage=None, execute=False):
self.custom_args = custom_args
self.build_option_parser()
self.VARS = {}
self.root = os.path.dirname(__file__) + '/..'

def build_option_parser(self):
parser = OptionParser(usage=self.usage)
Expand Down Expand Up @@ -269,7 +270,7 @@ def build_program(self, name=None):
self.prog = Program(self.args, self.options, name=name)
if self.execute:
if self.options.execute in \
("emulate", "ring", "rep-field", "semi2k"):
("emulate", "ring", "rep-field"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring",):
self.prog.use_split(3)
Expand Down Expand Up @@ -405,7 +406,7 @@ def compile_file(self):
infile = open(self.prog.infile)

# make compiler modules directly accessible
sys.path.insert(0, "Compiler")
sys.path.insert(0, "%s/Compiler" % self.root)
# create the tapes
exec(compile(infile.read(), infile.name, "exec"), self.VARS)

Expand Down Expand Up @@ -477,15 +478,15 @@ def executable_from_protocol(protocol):

def local_execution(self, args=[]):
executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists(executable):
if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...")
try:
subprocess.run(["make", executable], check=True)
subprocess.run(["make", executable], check=True, cwd=self.root)
except:
raise CompilerError(
"Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.")
vm = 'Scripts/%s.sh' % self.options.execute
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
os.execl(vm, vm, self.prog.name, *args)

def remote_execution(self, args=[]):
Expand All @@ -496,7 +497,7 @@ def remote_execution(self, args=[]):
from fabric import Connection
import subprocess
print("Creating static binary for virtual machine...")
subprocess.run(["make", "static/%s" % vm], check=True)
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)

# transfer files
import glob
Expand All @@ -519,7 +520,7 @@ def run(i):
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest)
# executable
connection.put("static/%s" % vm, dest)
connection.put("%s/static/%s" % (self.root, vm), dest)
# program
dest += "/"
connection.put("Programs/Schedules/%s.sch" % self.prog.name,
Expand Down
4 changes: 2 additions & 2 deletions Compiler/floatingpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def BitDecRingRaw(a, k, m):
def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]

def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size)
Expand All @@ -306,7 +306,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):

def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
return [types.sint.conv(bit) for bit in res]
return [types.sintbit.conv(bit) for bit in res]


@instructions_base.ret_cisc
Expand Down
31 changes: 23 additions & 8 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,17 @@ class reqbl(base.Instruction):
code = base.opcodes['REQBL']
arg_format = ['int']

class active(base.Instruction):
""" Indicate whether program is compatible with malicious-security
protocols.
:param: 0 for no, 1 for yes
"""
code = base.opcodes['ACTIVE']
arg_format = ['int']

class time(base.IOInstruction):

""" Output time since start of computation. """
code = base.opcodes['TIME']
arg_format = []
Expand Down Expand Up @@ -2418,9 +2428,10 @@ def add_usage(self, req_node):
super(matmulsm, self).add_usage(req_node)
req_node.increment(('matmul', tuple(self.args[3:6])), 1)

class conv2ds(base.DataInstruction):
class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
""" Secret 2D convolution.
:param: number of arguments to follow (int)
:param: result (sint vector in row-first order)
:param: inputs (sint vector in row-first order)
:param: weights (sint vector in row-first order)
Expand All @@ -2436,10 +2447,12 @@ class conv2ds(base.DataInstruction):
:param: padding height (int)
:param: padding width (int)
:param: batch size (int)
:param: repeat from result...
"""
code = base.opcodes['CONV2DS']
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
'int','int','int','int']
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
'int','int','int','int','int','int','int'])
data_type = 'triple'
is_vec = lambda self: True

Expand All @@ -2450,14 +2463,16 @@ def __init__(self, *args, **kwargs):
assert args[2].size == args[7] * args[8] * args[11]

def get_repeat(self):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
self.args[11] * self.args[14]
args = self.args
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
args[i+11] * args[i+14] for i in range(0, len(args), 15))

def add_usage(self, req_node):
super(conv2ds, self).add_usage(req_node)
args = self.args
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
for i in range(0, len(self.args), 15):
args = self.args[i:i + 15]
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)

@base.vectorize
class trunc_pr(base.VarArgsInstruction):
Expand Down
Loading

0 comments on commit 6cc3fcc

Please sign in to comment.