Skip to content

Commit

Permalink
Rep4, SPDZ-wise, MNIST training.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Oct 28, 2020
1 parent 53f9b02 commit f42e614
Show file tree
Hide file tree
Showing 184 changed files with 5,833 additions and 816 deletions.
2 changes: 1 addition & 1 deletion BMR/RealProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
}
else
{
T::read_or_generate_mac_key(prep_dir, N, mac_key);
T::read_or_generate_mac_key(prep_dir, *P, mac_key);
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
}

Expand Down
2 changes: 0 additions & 2 deletions BMR/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

#include <unistd.h>

ostream& EvalRegister::out = cout;

int Register::counter = 0;

void Register::init(int n_parties)
Expand Down
6 changes: 3 additions & 3 deletions BMR/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using namespace std;
#include "Tools/FlexBuffer.h"
#include "Tools/PointerVector.h"
#include "Tools/Bundle.h"
#include "Tools/SwitchableOutput.h"

//#define PAD_TO_8(n) (n+8-n%8)
#define PAD_TO_8(n) (n)
Expand Down Expand Up @@ -199,6 +200,7 @@ class BlackHole
BlackHole& operator<<(T) { return *this; }
BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; }
void activate(bool) {}
void redirect_to_file(ostream&) {}
};
inline BlackHole& endl(BlackHole& b) { return b; }
inline BlackHole& flush(BlackHole& b) { return b; }
Expand All @@ -211,7 +213,6 @@ class Phase
typedef NoMemory DynamicMemory;

typedef BlackHole out_type;
static BlackHole out;

static const bool actual_inputs = true;

Expand Down Expand Up @@ -353,8 +354,7 @@ class EvalRegister : public ProgramRegister

typedef EvalInputter Input;

typedef ostream& out_type;
static ostream& out;
typedef SwitchableOutput out_type;

static const bool actual_inputs = true;

Expand Down
21 changes: 20 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.2.0 (Oct 28, 2020)

- Rep4: honest-majority four-party computation with malicious security
- SY/SPDZ-wise: honest-majority computation with malicious security based on replicated or Shamir secret sharing
- Training with a sequence of dense layers
- Training and inference for multi-class classification
- Local share conversion for semi-honest protocols based on additive secret sharing modulo a power of two
- edaBit generation based on local share conversion
- Optimize exponentation with local share conversion
- Optimize Shamir pseudo-random secret sharing using a hyper-invertible matrix
- Mathematical functions (exponentation, logarithm, square root, and trigonometric functions) with binary circuits
- Direct construction of fixed-point values from any type, breaking `sfix(x)` where `x` is the integer representation of a fixed-point number. Use `sfix._new(x)` instead.
- Optimized dot product for `sfix`
- Matrix multiplication via operator overloading uses VM-optimized multiplication.
- Fake preprocessing for daBits and edaBits
- Fixed security bug: insufficient randomness in SemiBin random bit generation.
- Fixed security bug: insufficient randomization of FKOS15 inputs.
- Fixed security bug in binary computation with SPDZ(2k).

## 0.1.9 (Aug 24, 2020)

- Streamline inputs to binary circuits
- Improved private output
- Emulator for arithmetic circuits
- Efficient dot product with Shamir's secret sharing
- Lower memory usage for TensorFlow inference
- This version breaks bytecode compatibilty.
- This version breaks bytecode compatibility.

## 0.1.8 (June 15, 2020)

Expand Down
4 changes: 3 additions & 1 deletion CONFIG
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ USE_GF2N_LONG = 1
# AVX/AVX2 is required for replicated binary secret sharing
# BMI2 is used to optimize multiplication modulo a prime
# ADX is used to optimize big integer additions
# delete the second line to compile for a platform that supports everything
ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
ARCH = -march=native

# allow to set compiler in CONFIG.mine
CXX = g++
Expand Down Expand Up @@ -60,7 +62,7 @@ else
BOOST = -lboost_thread $(MY_BOOST)
endif

CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror
CPPFLAGS = $(CFLAGS)
LD = $(CXX)

Expand Down
49 changes: 36 additions & 13 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def compose(cls, items, bit_length=1):
return cls.bit_compose(sum([util.bit_decompose(item, bit_length) for item in items], []))
@classmethod
def bit_compose(cls, bits):
bits = list(bits)
if len(bits) == 1:
return bits[0]
bits = list(bits)
Expand Down Expand Up @@ -72,7 +73,7 @@ def bit_decompose(self, bit_length=None):
res = [self.bit_type() for i in range(n)]
self.bitdec(self, *res)
else:
res = self.trans([self])
res = self.bit_type.trans([self])
self.decomposed = res
return res + suffix
else:
Expand All @@ -83,8 +84,8 @@ def bit_decompose_clear(a, n_bits):
cbits.conv_cint_vec(a, *res)
return res
@classmethod
def malloc(cls, size):
return Program.prog.malloc(size, cls)
def malloc(cls, size, creator_tape=None):
return Program.prog.malloc(size, cls, creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
Expand Down Expand Up @@ -430,6 +431,8 @@ def reveal(self):
def equal(self, other, n=None):
bits = (~(self + other)).bit_decompose()
return reduce(operator.mul, bits)
def right_shift(self, m, k, security=None, signed=True):
return self.TruncPr(k, m)
def TruncPr(self, k, m, kappa=None):
if k > self.n:
raise Exception('TruncPr overflow: %d > %d' % (k, self.n))
Expand Down Expand Up @@ -481,8 +484,8 @@ class sbitvec(_vec):
def get_type(cls, n):
class sbitvecn(cls, _structure):
@staticmethod
def malloc(size):
return sbit.malloc(size * n)
def malloc(size, creator_tape=None):
return sbit.malloc(size * n, creator_tape=creator_tape)
@staticmethod
def n_elements():
return n
Expand Down Expand Up @@ -566,7 +569,8 @@ def __init__(self, elements=None, length=None):
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
v = x.v
self.v = v[:length]
elif elements is not None:
elif elements is not None and not (util.is_constant(elements) and \
elements == 0):
self.v = sbits.trans(elements)
def popcnt(self):
res = sbitint.wallace_tree([[b] for b in self.v])
Expand Down Expand Up @@ -606,7 +610,10 @@ def conv(cls, other):
return cls.from_vec(other.v)
@property
def size(self):
return self.v[0].n
if not self.v or util.is_constant(self.v[0]):
return 1
else:
return self.v[0].n
@property
def n_bits(self):
return len(self.v)
Expand Down Expand Up @@ -725,6 +732,8 @@ def cast(self, n):
return self.get_type(n).bit_compose(bits)
def round(self, k, m, kappa=None, nearest=None, signed=None):
bits = self.bit_decompose()
if signed:
bits += [bits[-1]] * (k - len(bits))
res_bits = self.bit_adder(bits[m:k], [bits[m-1]])
return self.get_type(k - m).compose(res_bits)
def int_div(self, other, bit_length=None):
Expand Down Expand Up @@ -781,7 +790,7 @@ def set_length(*args):
@classmethod
def bit_compose(cls, bits):
# truncate and extend bits
bits = bits[:cls.n]
bits = list(bits)[:cls.n]
bits += [0] * (cls.n - len(bits))
return super(sbitint, cls).bit_compose(bits)
def force_bit_decompose(self, n_bits=None):
Expand All @@ -801,6 +810,7 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False):
b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits)))
product = a * b
res_bits = product.bit_decompose()[m:k]
res_bits += [res_bits[-1]] * (self.n - len(res_bits))
t = self.combo_type(other)
return t.bit_compose(res_bits)
def __mul__(self, other):
Expand All @@ -824,6 +834,15 @@ def get_bit_matrix(cls, self_bits, other):
else:
res.append([(x & bit) for x in other.bit_decompose(n - i)])
return res
@classmethod
def popcnt_bits(cls, bits):
res = sbitvec.from_vec(bits).popcnt().elements()[0]
res = cls.conv(res)
return res
def pow2(self, k):
l = int(math.ceil(math.log(k, 2)))
bits = [self.equal(i, l) for i in range(k)]
return self.bit_compose(bits)

class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
def __add__(self, other):
Expand Down Expand Up @@ -867,8 +886,11 @@ class cbitfix(object):
conv = staticmethod(lambda x: x)
load_mem = classmethod(lambda cls, *args: cls(cbits.load_mem(*args)))
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
def __init__(self, value):
self.v = value
@classmethod
def _new(cls, value):
res = cls()
res.v = value
return res
def output(self):
v = self.v
if self.k < v.unit:
Expand Down Expand Up @@ -897,10 +919,10 @@ def get_input_from(cls, player):
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
def __xor__(self, other):
return type(self)(self.v ^ other.v)
return type(self)._new(self.v ^ other.v)
def __mul__(self, other):
if isinstance(other, sbit):
return type(self)(self.int_type(other * self.v))
return type(self)._new(self.int_type(other * self.v))
elif isinstance(other, sbitfixvec):
return other * self
else:
Expand All @@ -911,10 +933,11 @@ def __mul__(self, other):
def multipliable(other, k, f, size):
class cls(_fix):
int_type = sbitint.get_type(k)
clear_type = cbitfix
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)

sbitfix.set_precision(20, 41)
sbitfix.set_precision(16, 31)

class sbitfixvec(_fix):
int_type = sbitintvec
Expand Down
2 changes: 2 additions & 0 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def __init__(self, block, options, merge_classes):
else:
self.max_parallel_open = float('inf')
self.counter = defaultdict(lambda: 0)
self.rounds = defaultdict(lambda: 0)
self.dependency_graph(merge_classes)

def do_merge(self, merges_iter):
Expand Down Expand Up @@ -271,6 +272,7 @@ def longest_paths_merge(self):
merge = merges[i]
t = type(self.instructions[merge[0]])
self.counter[t] += len(merge)
self.rounds[t] += 1
if len(merge) > 10000:
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))
Expand Down
36 changes: 23 additions & 13 deletions Compiler/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,27 +135,37 @@ def Trunc(d, a, k, m, kappa, signed):
mulm(d, t, c[2])

def TruncRing(d, a, k, m, signed):
if program.use_split() == 3:
program.curr_tape.require_bit_length(1)
if program.use_split() in (2, 3):
if signed:
a += (1 << (k - 1))
from Compiler.types import sint
from .GC.types import sbitint
length = int(program.options.ring)
summands = a.split_to_n_summands(length, 3)
summands = a.split_to_n_summands(length, program.use_split())
x = sbitint.wallace_tree_without_finish(summands, True)
if m == 1:
low = x[1][1]
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
sint.conv(x[0][-1])
if program.use_split() == 2:
carries = sbitint.get_carries(*x)
low = carries[m]
high = sint.conv(carries[length])
else:
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
low = sint.conv(mid_carry) + sint.conv(x[0][m])
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
for xx, yy in zip(x[1][m:-1],
x[0][m:-1])))
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
high = top_carry + sint.conv(x[0][-1])
if m == 1:
low = x[1][1]
high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \
sint.conv(x[0][-1])
else:
mid_carry = CarryOutRawLE(x[1][:m], x[0][:m])
low = sint.conv(mid_carry) + sint.conv(x[0][m])
tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy)
for xx, yy in zip(x[1][m:-1],
x[0][m:-1])))
top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1])
high = top_carry + sint.conv(x[0][-1])
shifted = sint()
shrsi(shifted, a, m)
res = shifted + sint.conv(low) - (high << (length - m))
if signed:
res -= (1 << (k - m - 1))
else:
a_prime = Mod2mRing(None, a, k, m, signed)
a -= a_prime
Expand Down
Loading

0 comments on commit f42e614

Please sign in to comment.