diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 93ef4fb3d..7872ec526 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -104,7 +104,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : } else { - T::read_or_generate_mac_key(prep_dir, N, mac_key); + T::read_or_generate_mac_key(prep_dir, *P, mac_key); prep = new Sub_Data_Files(N, prep_dir, usage); } diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 99d1bda74..04af155f4 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -24,8 +24,6 @@ #include -ostream& EvalRegister::out = cout; - int Register::counter = 0; void Register::init(int n_parties) diff --git a/BMR/Register.h b/BMR/Register.h index 09b29d189..50a4cb677 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -22,6 +22,7 @@ using namespace std; #include "Tools/FlexBuffer.h" #include "Tools/PointerVector.h" #include "Tools/Bundle.h" +#include "Tools/SwitchableOutput.h" //#define PAD_TO_8(n) (n+8-n%8) #define PAD_TO_8(n) (n) @@ -199,6 +200,7 @@ class BlackHole BlackHole& operator<<(T) { return *this; } BlackHole& operator<<(BlackHole& (*__pf)(BlackHole&)) { (void)__pf; return *this; } void activate(bool) {} + void redirect_to_file(ostream&) {} }; inline BlackHole& endl(BlackHole& b) { return b; } inline BlackHole& flush(BlackHole& b) { return b; } @@ -211,7 +213,6 @@ class Phase typedef NoMemory DynamicMemory; typedef BlackHole out_type; - static BlackHole out; static const bool actual_inputs = true; @@ -353,8 +354,7 @@ class EvalRegister : public ProgramRegister typedef EvalInputter Input; - typedef ostream& out_type; - static ostream& out; + typedef SwitchableOutput out_type; static const bool actual_inputs = true; diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fd7a9f07..0cbe8a95f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,24 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.0 (Oct 28, 2020) + +- Rep4: honest-majority four-party computation with malicious security +- SY/SPDZ-wise: honest-majority computation with malicious security based on replicated or Shamir secret sharing +- Training with a sequence of dense layers +- Training and inference for multi-class classification +- Local share conversion for semi-honest protocols based on additive secret sharing modulo a power of two +- edaBit generation based on local share conversion +- Optimize exponentation with local share conversion +- Optimize Shamir pseudo-random secret sharing using a hyper-invertible matrix +- Mathematical functions (exponentation, logarithm, square root, and trigonometric functions) with binary circuits +- Direct construction of fixed-point values from any type, breaking `sfix(x)` where `x` is the integer representation of a fixed-point number. Use `sfix._new(x)` instead. +- Optimized dot product for `sfix` +- Matrix multiplication via operator overloading uses VM-optimized multiplication. +- Fake preprocessing for daBits and edaBits +- Fixed security bug: insufficient randomness in SemiBin random bit generation. +- Fixed security bug: insufficient randomization of FKOS15 inputs. +- Fixed security bug in binary computation with SPDZ(2k). + ## 0.1.9 (Aug 24, 2020) - Streamline inputs to binary circuits @@ -7,7 +26,7 @@ The changelog explains changes pulled through from the private development repos - Emulator for arithmetic circuits - Efficient dot product with Shamir's secret sharing - Lower memory usage for TensorFlow inference -- This version breaks bytecode compatibilty. +- This version breaks bytecode compatibility. ## 0.1.8 (June 15, 2020) diff --git a/CONFIG b/CONFIG index 329bfbf0f..5e7b7eba9 100644 --- a/CONFIG +++ b/CONFIG @@ -24,7 +24,9 @@ USE_GF2N_LONG = 1 # AVX/AVX2 is required for replicated binary secret sharing # BMI2 is used to optimize multiplication modulo a prime # ADX is used to optimize big integer additions +# delete the second line to compile for a platform that supports everything ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx +ARCH = -march=native # allow to set compiler in CONFIG.mine CXX = g++ @@ -60,7 +62,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index c3842665c..3ed976679 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -45,6 +45,7 @@ def compose(cls, items, bit_length=1): return cls.bit_compose(sum([util.bit_decompose(item, bit_length) for item in items], [])) @classmethod def bit_compose(cls, bits): + bits = list(bits) if len(bits) == 1: return bits[0] bits = list(bits) @@ -72,7 +73,7 @@ def bit_decompose(self, bit_length=None): res = [self.bit_type() for i in range(n)] self.bitdec(self, *res) else: - res = self.trans([self]) + res = self.bit_type.trans([self]) self.decomposed = res return res + suffix else: @@ -83,8 +84,8 @@ def bit_decompose_clear(a, n_bits): cbits.conv_cint_vec(a, *res) return res @classmethod - def malloc(cls, size): - return Program.prog.malloc(size, cls) + def malloc(cls, size, creator_tape=None): + return Program.prog.malloc(size, cls, creator_tape=creator_tape) @staticmethod def n_elements(): return 1 @@ -430,6 +431,8 @@ def reveal(self): def equal(self, other, n=None): bits = (~(self + other)).bit_decompose() return reduce(operator.mul, bits) + def right_shift(self, m, k, security=None, signed=True): + return self.TruncPr(k, m) def TruncPr(self, k, m, kappa=None): if k > self.n: raise Exception('TruncPr overflow: %d > %d' % (k, self.n)) @@ -481,8 +484,8 @@ class sbitvec(_vec): def get_type(cls, n): class sbitvecn(cls, _structure): @staticmethod - def malloc(size): - return sbit.malloc(size * n) + def malloc(size, creator_tape=None): + return sbit.malloc(size * n, creator_tape=creator_tape) @staticmethod def n_elements(): return n @@ -566,7 +569,8 @@ def __init__(self, elements=None, length=None): x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb) v = x.v self.v = v[:length] - elif elements is not None: + elif elements is not None and not (util.is_constant(elements) and \ + elements == 0): self.v = sbits.trans(elements) def popcnt(self): res = sbitint.wallace_tree([[b] for b in self.v]) @@ -606,7 +610,10 @@ def conv(cls, other): return cls.from_vec(other.v) @property def size(self): - return self.v[0].n + if not self.v or util.is_constant(self.v[0]): + return 1 + else: + return self.v[0].n @property def n_bits(self): return len(self.v) @@ -725,6 +732,8 @@ def cast(self, n): return self.get_type(n).bit_compose(bits) def round(self, k, m, kappa=None, nearest=None, signed=None): bits = self.bit_decompose() + if signed: + bits += [bits[-1]] * (k - len(bits)) res_bits = self.bit_adder(bits[m:k], [bits[m-1]]) return self.get_type(k - m).compose(res_bits) def int_div(self, other, bit_length=None): @@ -781,7 +790,7 @@ def set_length(*args): @classmethod def bit_compose(cls, bits): # truncate and extend bits - bits = bits[:cls.n] + bits = list(bits)[:cls.n] bits += [0] * (cls.n - len(bits)) return super(sbitint, cls).bit_compose(bits) def force_bit_decompose(self, n_bits=None): @@ -801,6 +810,7 @@ def TruncMul(self, other, k, m, kappa=None, nearest=False): b = t.bit_compose(other_bits + [other_bits[-1]] * (k - len(other_bits))) product = a * b res_bits = product.bit_decompose()[m:k] + res_bits += [res_bits[-1]] * (self.n - len(res_bits)) t = self.combo_type(other) return t.bit_compose(res_bits) def __mul__(self, other): @@ -824,6 +834,15 @@ def get_bit_matrix(cls, self_bits, other): else: res.append([(x & bit) for x in other.bit_decompose(n - i)]) return res + @classmethod + def popcnt_bits(cls, bits): + res = sbitvec.from_vec(bits).popcnt().elements()[0] + res = cls.conv(res) + return res + def pow2(self, k): + l = int(math.ceil(math.log(k, 2))) + bits = [self.equal(i, l) for i in range(k)] + return self.bit_compose(bits) class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): def __add__(self, other): @@ -867,8 +886,11 @@ class cbitfix(object): conv = staticmethod(lambda x: x) load_mem = classmethod(lambda cls, *args: cls(cbits.load_mem(*args))) store_in_mem = lambda self, *args: self.v.store_in_mem(*args) - def __init__(self, value): - self.v = value + @classmethod + def _new(cls, value): + res = cls() + res.v = value + return res def output(self): v = self.v if self.k < v.unit: @@ -897,10 +919,10 @@ def get_input_from(cls, player): inst.inputb(player, cls.k, cls.f, v) return cls._new(v) def __xor__(self, other): - return type(self)(self.v ^ other.v) + return type(self)._new(self.v ^ other.v) def __mul__(self, other): if isinstance(other, sbit): - return type(self)(self.int_type(other * self.v)) + return type(self)._new(self.int_type(other * self.v)) elif isinstance(other, sbitfixvec): return other * self else: @@ -911,10 +933,11 @@ def __mul__(self, other): def multipliable(other, k, f, size): class cls(_fix): int_type = sbitint.get_type(k) + clear_type = cbitfix cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -sbitfix.set_precision(20, 41) +sbitfix.set_precision(16, 31) class sbitfixvec(_fix): int_type = sbitintvec diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 5b647c3bc..1ae37a61c 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -220,6 +220,7 @@ def __init__(self, block, options, merge_classes): else: self.max_parallel_open = float('inf') self.counter = defaultdict(lambda: 0) + self.rounds = defaultdict(lambda: 0) self.dependency_graph(merge_classes) def do_merge(self, merges_iter): @@ -271,6 +272,7 @@ def longest_paths_merge(self): merge = merges[i] t = type(self.instructions[merge[0]]) self.counter[t] += len(merge) + self.rounds[t] += 1 if len(merge) > 10000: print('Merging %d %s in round %d/%d' % \ (len(merge), t.__name__, i, len(merges))) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 5e542d798..3005036a2 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -135,27 +135,37 @@ def Trunc(d, a, k, m, kappa, signed): mulm(d, t, c[2]) def TruncRing(d, a, k, m, signed): - if program.use_split() == 3: + program.curr_tape.require_bit_length(1) + if program.use_split() in (2, 3): + if signed: + a += (1 << (k - 1)) from Compiler.types import sint from .GC.types import sbitint length = int(program.options.ring) - summands = a.split_to_n_summands(length, 3) + summands = a.split_to_n_summands(length, program.use_split()) x = sbitint.wallace_tree_without_finish(summands, True) - if m == 1: - low = x[1][1] - high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \ - sint.conv(x[0][-1]) + if program.use_split() == 2: + carries = sbitint.get_carries(*x) + low = carries[m] + high = sint.conv(carries[length]) else: - mid_carry = CarryOutRawLE(x[1][:m], x[0][:m]) - low = sint.conv(mid_carry) + sint.conv(x[0][m]) - tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy) - for xx, yy in zip(x[1][m:-1], - x[0][m:-1]))) - top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1]) - high = top_carry + sint.conv(x[0][-1]) + if m == 1: + low = x[1][1] + high = sint.conv(CarryOutLE(x[1][:-1], x[0][:-1])) + \ + sint.conv(x[0][-1]) + else: + mid_carry = CarryOutRawLE(x[1][:m], x[0][:m]) + low = sint.conv(mid_carry) + sint.conv(x[0][m]) + tmp = util.tree_reduce(carry, (sbitint.half_adder(xx, yy) + for xx, yy in zip(x[1][m:-1], + x[0][m:-1]))) + top_carry = sint.conv(carry([None, mid_carry], tmp, False)[1]) + high = top_carry + sint.conv(x[0][-1]) shifted = sint() shrsi(shifted, a, m) res = shifted + sint.conv(low) - (high << (length - m)) + if signed: + res -= (1 << (k - m - 1)) else: a_prime = Mod2mRing(None, a, k, m, signed) a -= a_prime diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 80b046526..1fde0aab4 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -53,6 +53,12 @@ def maskField(a, k, kappa): @instructions_base.ret_cisc def EQZ(a, k, kappa): + prog = program.Program.prog + if prog.use_split(): + from GC.types import sbitvec + v = sbitvec(a, k).v + bit = util.tree_reduce(operator.and_, (~b for b in v)) + return types.sint.conv(bit) if program.Program.prog.options.ring: c, r = maskRing(a, k) else: @@ -307,16 +313,22 @@ def BitDec(a, k, m, kappa, bits_to_compute=None): def BitDecRing(a, k, m): n_shift = int(program.Program.prog.options.ring) - m assert(n_shift >= 0) - if program.Program.prog.use_dabit: - r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) - r = types.sint.bit_compose(r) + if program.Program.prog.use_split(): + x = a.split_to_two_summands(m) + bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False) + # reversing to reduce number of rounds + return [types.sint.conv(bit) for bit in reversed(bits)][::-1] else: - r_bits = [types.sint.get_random_bit() for i in range(m)] - r = types.sint.bit_compose(r_bits) - shifted = ((a - r) << n_shift).reveal() - masked = shifted >> n_shift - bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) - return [types.sint.conv(bit) for bit in bits] + if program.Program.prog.use_dabit: + r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) + r = types.sint.bit_compose(r) + else: + r_bits = [types.sint.get_random_bit() for i in range(m)] + r = types.sint.bit_compose(r_bits) + shifted = ((a - r) << n_shift).reveal() + masked = shifted >> n_shift + bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) + return [types.sint.conv(bit) for bit in bits] def BitDecField(a, k, m, kappa, bits_to_compute=None): r_dprime = types.sint() @@ -476,22 +488,20 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa): def Int2FL(a, gamma, l, kappa): lam = gamma - 1 - s = types.sint() - comparison.LTZ(s, a, gamma, kappa) - z = EQZ(a, gamma, kappa) - a = (1 - 2 * s) * a - a_bits = BitDec(a, lam, lam, kappa) + s = a.less_than(0, gamma, security=kappa) + z = a.equal(0, gamma, security=kappa) + a = s.if_else(-a, a) + a_bits = a.bit_decompose(lam, security=kappa) a_bits.reverse() b = PreOR(a_bits, kappa) - t = a * (1 + sum(2**i * (1 - b_i) for i,b_i in enumerate(b))) - p = - (lam - sum(b)) + t = a * (1 + a.bit_compose(1 - b_i for b_i in b)) + p = a.popcnt_bits(b) - lam if gamma - 1 > l: if types.sfloat.round_nearest: v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa) p = p + overflow else: - v = types.sint() - comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False) + v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) else: v = 2**(l-gamma+1) * t p = (p + gamma - 1 - l) * (1 -z) @@ -539,6 +549,7 @@ def TruncPrRing(a, k, m, signed=True): n_ring = int(program.Program.prog.options.ring) assert n_ring >= k, '%d too large' % k if k == n_ring: + program.Program.prog.curr_tape.require_bit_length(1) if program.Program.prog.use_edabit(): a += types.sint.get_edabit(m, True)[0] else: @@ -555,7 +566,8 @@ def TruncPrRing(a, k, m, signed=True): else: # extra bit to mask overflow prog = program.Program.prog - if prog.use_edabit() or prog.use_split() == 3: + prog.curr_tape.require_bit_length(1) + if prog.use_edabit() or prog.use_split() > 2: lower = sint.get_random_int(m) upper = sint.get_random_int(k - m) msb = sint.get_random_bit() diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d66f5a522..9c213e466 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -441,6 +441,8 @@ def new_instructions(self, size, regs): program.options.cisc = True reset_global_vector_size() program.curr_tape = old_tape + for x, bl in tape.req_bit_length.items(): + old_tape.require_bit_length(bl, x) from Compiler.allocator import Merger merger = Merger(block, program.options, tuple(program.to_merge)) @@ -523,25 +525,26 @@ def wrapper(*args, **kwargs): def sfix_cisc(function): from Compiler.types import sfix, sint, cfix, copy_doc - def instruction(res, arg, k, f): + def instruction(res, arg, k, f, *args): assert k is not None assert f is not None old = sfix.k, sfix.f, cfix.k, cfix.f sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4 - res.mov(res, function(sfix._new(arg, k=k, f=f)).v) + res.mov(res, function(sfix._new(arg, k=k, f=f), *args).v) sfix.k, sfix.f, cfix.k, cfix.f = old instruction.__name__ = function.__name__ instruction = cisc(instruction) def wrapper(*args, **kwargs): if isinstance(args[0], sfix): - assert len(args) == 1 + for arg in args[1:]: + assert util.is_constant(arg) assert not kwargs assert args[0].size == args[0].v.size k = args[0].k f = args[0].f res = sfix._new(sint(size=args[0].size), k=k, f=f) - instruction(res.v, args[0].v, k, f) + instruction(res.v, args[0].v, k, f, *args[1:]) return res else: return function(*args, **kwargs) diff --git a/Compiler/library.py b/Compiler/library.py index c1658de45..fb939bf36 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -134,6 +134,7 @@ def print_ln_if(cond, ss, *args): print_str_if(cond, ss + '\n', *args) def print_str_if(cond, ss, *args): + """ Print string conditionally. See :py:func:`print_ln_if` for details. """ if util.is_constant(cond): if cond: print_ln(ss, *args) @@ -160,7 +161,8 @@ def print_str_if(cond, ss, *args): def print_ln_to(player, ss, *args): """ Print line at :py:obj:`player` only. Note that printing is - disabled by default except at player 0. + disabled by default except at player 0. Activate interactive mode + with `-I` to enable it for all players. :param player: int :param ss: Python string @@ -814,8 +816,8 @@ def range_loop(loop_body, start, stop=None, step=None): if step is None: step = 1 def loop_fn(i): - loop_body(i) - return i + step + res = loop_body(i) + return util.if_else(res == 0, stop, i + step) if isinstance(step, int): if step > 0: condition = lambda x: x < stop @@ -840,7 +842,9 @@ def for_range(start, stop=None, step=None): in Python :py:func:`range`, but they can by any public integer. Information has to be passed out via container types such as :py:class:`Compiler.types.Array` or declaring registers as - :py:obj:`global`. + :py:obj:`global`. Note that changing Python data structures such + as lists within the loop is not possible, but the compiler cannot + warn about this. :param start/stop/step: regint/cint/int @@ -1057,7 +1061,7 @@ def f(i, j): """ return for_range_multithread(n_threads, None, n_loops) -def multithread(n_threads, n_items): +def multithread(n_threads, n_items, max_size=None): """ Distribute the computation of :py:obj:`n_items` to :py:obj:`n_threads` threads, but leave the in-thread repetition up @@ -1075,8 +1079,19 @@ def multithread(n_threads, n_items): def f(base, size): ... """ - return map_reduce(n_threads, None, n_items, initializer=lambda: [], - reducer=None, looping=False) + if max_size is None: + return map_reduce(n_threads, None, n_items, initializer=lambda: [], + reducer=None, looping=False) + else: + def wrapper(function): + @multithread(n_threads, n_items) + def new_function(base, size): + for i in range(0, size, max_size): + part_base = base + i + part_size = min(max_size, size - i) + function(part_base, part_size) + break_point() + return wrapper def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \ thread_mem_req={}, looping=True): @@ -1563,8 +1578,8 @@ def cint_cint_division(a, b, k, f): theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) - sign_b = cint(1) - 2 * cint(b < 0) - sign_a = cint(1) - 2 * cint(a < 0) + sign_b = cint(1) - 2 * cint(b.less_than(0, k)) + sign_a = cint(1) - 2 * cint(a.less_than(0, k)) absolute_b = b * sign_b absolute_a = a * sign_a w0 = approximate_reciprocal(absolute_b, k, f, theta) @@ -1632,9 +1647,12 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): f = max((k - nearest) // 2 + 1, f) assert 2 * f > k - nearest theta = int(ceil(log(k/3.5) / log(2))) + + base.set_global_vector_size(b.size) alpha = b.get_type(2 * k).two_power(2*f) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w + base.reset_global_vector_size() y = a.extend(2 *k) * w y = y.round(2*k, f, kappa, nearest, signed=True) diff --git a/Compiler/ml.py b/Compiler/ml.py index 2060d22e9..6814d15a6 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -42,6 +42,7 @@ """ import math +import re from Compiler import mpc_math, util from Compiler.types import * @@ -58,13 +59,13 @@ def log_e(x): def exp(x): return mpc_math.pow_fx(math.e, x) -def sanitize(x, raw, lower, upper): +def get_limit(x): exp_limit = 2 ** (x.k - x.f - 1) - limit = math.log(exp_limit) - if get_program().options.ring: - res = raw - else: - res = (x > limit).if_else(upper, raw) + return math.log(exp_limit) + +def sanitize(x, raw, lower, upper): + limit = get_limit(x) + res = (x > limit).if_else(upper, raw) return (x < -limit).if_else(lower, res) def sigmoid(x): @@ -137,10 +138,12 @@ def op(a, b): return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1]) return tree_reduce(op, enumerate(x))[0] +report_progress = False + def progress(x): - return - print_ln(x) - time() + if report_progress: + print_ln(x) + time() def set_n_threads(n_threads): Layer.n_threads = n_threads @@ -159,6 +162,10 @@ def __getitem__(self, *args): self.alloc() return super(Tensor, self).__getitem__(*args) + def assign_vector(self, *args): + self.alloc() + return super(Tensor, self).assign_vector(*args) + class Layer: n_threads = 1 inputs = [] @@ -190,12 +197,23 @@ def Y(self, value): class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None -class Output(Layer): + nablas = lambda self: () + reset = lambda self: None + +class Output(NoVariableLayer): """ Fixed-point logistic regression output layer. :param N: number of examples :param approx: :py:obj:`False` (default) or parameter for :py:obj:`approx_sigmoid` """ + n_outputs = 2 + + @classmethod + def from_args(cls, N, program): + res = cls(N, approx='approx' in program.args) + res.compute_loss = not 'no_loss' in program.args + return res + def __init__(self, N, debug=False, approx=False): self.N = N self.X = sfix.Array(N) @@ -206,9 +224,7 @@ def __init__(self, N, debug=False, approx=False): self.debug = debug self.weights = None self.approx = approx - - nablas = lambda self: () - reset = lambda self: None + self.compute_loss = True def divisor(self, divisor, size): return cfix(1.0 / divisor, size=size) @@ -224,11 +240,13 @@ def _(base, size): x = self.X.get_vector(base, size) y = self.Y.get(batch.get_vector(base, size)) if self.approx: - lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base) + if self.compute_loss: + lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base) return e_x = exp(-x) self.e_x.assign(e_x, base) - lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base) + if self.compute_loss: + lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base) self.l.write(sum(lse) * \ self.divisor(N, 1)) @@ -246,13 +264,10 @@ def _(base, size): diff = self.eval(size, base) - \ self.Y.get(batch.get_vector(base, size)) assert sfix.f == cfix.f - if self.weights is None: - diff *= self.divisor(N, size) - else: + if self.weights is not None: assert N == len(self.weights) diff *= self.weights.get_vector(base, size) - if self.weight_total != 1: - diff *= self.divisor(self.weight_total, size) + assert self.weight_total == N self.nabla_X.assign(diff, base) # @for_range_opt(len(diff)) # def _(i): @@ -271,6 +286,244 @@ def set_weights(self, weights): self.weights.assign(weights) self.weight_total = sum(weights) + def average_loss(self, N): + return self.l.reveal() + + def reveal_correctness(self, n=None, Y=None, debug=False): + if n is None: + n = self.X.sizes[0] + if Y is None: + Y = self.Y + n_correct = MemValue(0) + n_printed = MemValue(0) + @for_range_opt(n) + def _(i): + truth = Y[i].reveal() + b = self.X[i].reveal() + if debug: + nabla = self.nabla_X[i].reveal() + guess = b > 0 + correct = truth == guess + n_correct.iadd(correct) + if debug: + to_print = (1 - correct) * (n_printed < 10) + n_printed.iadd(to_print) + print_ln_if(to_print, '%s: %s %s %s %s', + i, truth, guess, b, nabla) + return n_correct + +class MultiOutputBase(NoVariableLayer): + def __init__(self, N, d_out, approx=False, debug=False): + self.X = sfix.Matrix(N, d_out) + self.Y = sint.Matrix(N, d_out) + self.nabla_X = sfix.Matrix(N, d_out) + self.l = MemValue(sfix(-1)) + self.losses = sfix.Array(N) + self.approx = None + self.N = N + self.d_out = d_out + self.compute_loss = True + + def eval(self, N): + d_out = self.X.sizes[1] + res = sfix.Matrix(N, d_out) + res.assign_vector(self.X.get_part_vector(0, N)) + return res + + def average_loss(self, N): + return sum(self.losses.get_vector(0, N)).reveal() / N + + def reveal_correctness(self, n=None, Y=None, debug=False): + if n is None: + n = self.X.sizes[0] + if Y is None: + Y = self.Y + n_correct = MemValue(0) + n_printed = MemValue(0) + @for_range_opt(n) + def _(i): + a = Y[i].reveal_list() + b = self.X[i].reveal_list() + if debug: + loss = self.losses[i].reveal() + exp = self.get_extra_debugging(i) + nabla = self.nabla_X[i].reveal_list() + truth = argmax(a) + guess = argmax(b) + correct = truth == guess + n_correct.iadd(correct) + if debug: + to_print = (1 - correct) * (n_printed < 10) + n_printed.iadd(to_print) + print_ln_if(to_print, '%s: %s %s %s %s %s %s', + i, truth, guess, loss, b, exp, nabla) + return n_correct + + @property + def n_outputs(self): + return self.d_out + + def get_extra_debugging(self, i): + return '' + + @staticmethod + def from_args(program, N, n_output): + if 'relu_out' in program.args: + res = ReluMultiOutput(N, n_output) + else: + res = MultiOutput(N, n_output, approx='approx' in program.args) + res.cheaper_loss = 'mse' in program.args + res.compute_loss = not 'no_loss' in program.args + return res + +class MultiOutput(MultiOutputBase): + """ + Output layer for multi-class classification with softmax and cross entropy. + + :param N: number of examples + :param d_out: number of classes + :param approx: use ReLU division instead of softmax for the loss + """ + def __init__(self, N, d_out, approx=False, debug=False): + MultiOutputBase.__init__(self, N, d_out) + self.exp = sfix.Matrix(N, d_out) + self.approx = approx + self.positives = sint.Matrix(N, d_out) + self.relus = sfix.Matrix(N, d_out) + self.cheaper_loss = False + self.debug = debug + self.true_X = sfix.Array(N) + + def forward(self, batch): + N = len(batch) + d_out = self.X.sizes[1] + tmp = self.losses + @for_range_opt_multithread(self.n_threads, N) + def _(i): + if self.approx: + positives = self.X[i].get_vector() > (0 if self.cheaper_loss else 0.1) + relus = positives.if_else(self.X[i].get_vector(), 0) + self.positives[i].assign_vector(positives) + self.relus[i].assign_vector(relus) + if self.compute_loss: + if self.cheaper_loss: + s = sum(relus) + tmp[i] = sum((self.Y[batch[i]][j] * s - relus[j]) ** 2 + for j in range(d_out)) / s ** 2 * 0.5 + else: + div = relus / sum(relus).expand_to_vector(d_out) + self.losses[i] = -sfix.dot_product( + self.Y[batch[i]].get_vector(), log_e(div)) + else: + m = util.max(self.X[i]) + mv = m.expand_to_vector(d_out) + x = self.X[i].get_vector() + e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0) + self.exp[i].assign_vector(e) + if self.compute_loss: + true_X = sfix.dot_product(self.Y[batch[i]], self.X[i]) + tmp[i] = m + log_e(sum(e)) - true_X + self.true_X[i] = true_X + self.l.write(sum(tmp.get_vector(0, N)) / N) + + def eval(self, N): + d_out = self.X.sizes[1] + res = sfix.Matrix(N, d_out) + if self.approx: + @for_range_opt_multithread(self.n_threads, N) + def _(i): + relus = (self.X[i].get_vector() > 0).if_else( + self.X[i].get_vector(), 0) + res[i].assign_vector(relus / sum(relus).expand_to_vector(d_out)) + return res + @for_range_opt_multithread(self.n_threads, N) + def _(i): + e = exp(self.X[i].get_vector()) + res[i].assign_vector(e / sum(e).expand_to_vector(d_out)) + return res + + def backward(self, batch): + d_out = self.X.sizes[1] + if self.approx: + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + if self.cheaper_loss: + s = sum(self.relus[i]) + ss = s * s * s + inv = 1 / ss + @for_range_opt(d_out) + def _(j): + res = 0 + for k in range(d_out): + relu = self.relus[i][k] + summand = relu - self.Y[batch[i]][k] * s + summand *= (sfix.from_sint(j == k) - relu) + res += summand + fallback = -self.Y[batch[i]][j] + res *= inv + self.nabla_X[i][j] = self.positives[i][j].if_else(res, fallback) + return + relus = self.relus[i].get_vector() + positives = self.positives[i].get_vector() + inv = (1 / sum(relus)).expand_to_vector(d_out) + truths = self.Y[batch[i]].get_vector() + raw = truths / relus - inv + self.nabla_X[i] = -positives.if_else(raw, truths) + self.maybe_debug_backward(batch) + return + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + for j in range(d_out): + dividend = self.exp[i][j] + divisor = sum(self.exp[i]) + div = (divisor > 0.1).if_else(dividend / divisor, 0) + self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div) + self.maybe_debug_backward(batch) + + def maybe_debug_backward(self, batch): + if self.debug: + @for_range(len(batch)) + def _(i): + check = 0 + for j in range(self.X.sizes[1]): + to_check = self.nabla_X[i][j].reveal() + check += (to_check > len(batch)) + (to_check < -len(batch)) + print_ln_if(check, 'X %s', self.X[i].reveal_nested()) + print_ln_if(check, 'exp %s', self.exp[i].reveal_nested()) + print_ln_if(check, 'nabla X %s', + self.nabla_X[i].reveal_nested()) + + def get_extra_debugging(self, i): + if self.approx: + return self.relus[i].reveal_list() + else: + return self.exp[i].reveal_list() + +class ReluMultiOutput(MultiOutputBase): + """ + Output layer for multi-class classification with back-propagation + based on ReLU division. + + :param N: number of examples + :param d_out: number of classes + """ + def forward(self, batch): + self.l.write(999) + + def backward(self, batch): + N = len(batch) + d_out = self.X.sizes[1] + relus = sfix.Matrix(N, d_out) + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + positives = self.X[i].get_vector() > 0 + relus = positives.if_else(self.X[i].get_vector(), 0) + s = sum(relus) + inv = 1 / s + prod = relus * inv + res = prod - self.Y[batch[i]].get_vector() + self.nabla_X[i].assign_vector(res) + class DenseBase(Layer): thetas = lambda self: (self.W, self.b) nablas = lambda self: (self.nabla_W, self.nabla_b) @@ -279,26 +532,20 @@ def backward_params(self, f_schur_Y, batch): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) - assert self.d == 1 - if self.d_out == 1: - @multithread(self.n_threads, self.d_in) - def _(base, size): - A = sfix.Matrix(1, self.N, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) - mp = A.direct_mul(B, reduce=False, - indices=(regint(0, size=1), - regint.inc(N), - batch.get_vector(), - regint.inc(size, base))) - tmp.assign_vector(mp, base) - else: - @for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out]) - def _(j, k): - a = [f_schur_Y[i][0][k] for i in range(N)] - b = [self.X[i][0][j] for i in batch] - tmp[j][k] = sfix.unreduced_dot_product(a, b) - - if self.d_in * self.d_out < 100000: + @multithread(self.n_threads, self.d_in) + def _(base, size): + A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address) + B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + mp = B.direct_trans_mul(A, reduce=False, + indices=(regint.inc(size, base), + batch.get_vector(), + regint.inc(N), + regint.inc(self.d_out))) + tmp.assign_part_vector(mp, base) + + progress('nabla W (matmul)') + + if self.d_in * self.d_out < 200000: print('reduce at once') @multithread(self.n_threads, self.d_in * self.d_out) def _(base, size): @@ -309,10 +556,46 @@ def _(base, size): def _(i): self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() - self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N)) - for j in range(self.d)) for i in range(self.d_out)) + progress('nabla W') + + self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector() + for k in range(N)) + for j in range(self.d))) - progress('nabla W/b') + progress('nabla b') + + if self.debug: + limit = N * self.debug + @for_range_opt(self.d_in) + def _(i): + @for_range_opt(self.d_out) + def _(j): + to_check = self.nabla_W[i][j].reveal() + check = sum(to_check > limit) + sum(to_check < -limit) + @if_(check) + def _(): + print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check) + print_ln('Y %s', [f_schur_Y[k][0][j].reveal() + for k in range(N)]) + print_ln('X %s', [self.X[k][0][i].reveal() + for k in range(N)]) + @for_range_opt(self.d_out) + def _(j): + to_check = self.nabla_b[j].reveal() + check = sum(to_check > limit) + sum(to_check < -limit) + @if_(check) + def _(): + print_ln('nabla b %s %s: %s', j, len(self.b), to_check) + print_ln('Y %s', [f_schur_Y[k][0][j].reveal() + for k in range(N)]) + @for_range_opt(len(batch)) + def _(i): + to_check = self.nabla_X[i].get_vector().reveal() + check = sum(to_check > limit) + sum(to_check < -limit) + @if_(check) + def _(): + print_ln('X %s %s', i, self.X[i].reveal_nested()) + print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested()) class Dense(DenseBase): """ Fixed-point dense (matrix multiplication) layer. @@ -321,7 +604,7 @@ class Dense(DenseBase): :param d_in: input dimension :param d_out: output dimension """ - def __init__(self, N, d_in, d_out, d=1, activation='id'): + def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): self.activation = activation if activation == 'id': self.f = lambda x: x @@ -349,15 +632,13 @@ def __init__(self, N, d_in, d_out, d=1, activation='id'): self.f_input = MultiArray([N, d, d_out], sfix) + self.debug = debug + def reset(self): d_in = self.d_in d_out = self.d_out r = math.sqrt(6.0 / (d_in + d_out)) - @for_range(d_in) - def _(i): - @for_range(d_out) - def _(j): - self.W[i][j] = sfix.get_random(-r, r) + self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) self.b.assign_all(0) def input_from(self, player, raw=False): @@ -372,15 +653,14 @@ def compute_f_input(self, batch): prod = MultiArray([N, self.d, self.d_out], sfix) else: prod = self.f_input - @multithread(self.n_threads, N) + max_size = program.Program.prog.budget // self.d_out + @multithread(self.n_threads, N, max_size) def _(base, size): X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) - prod.assign_vector( - X_sub.direct_mul(self.W, indices=(batch.get_vector(base, size), - regint.inc(self.d_in), - regint.inc(self.d_in), - regint.inc(self.d_out))), - base) + prod.assign_part_vector( + X_sub.direct_mul(self.W, indices=( + batch.get_vector(base, size), regint.inc(self.d_in), + regint.inc(self.d_in), regint.inc(self.d_out))), base) if self.input_bias: if self.d_out == 1: @@ -389,7 +669,7 @@ def _(base, size): v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) self.f_input.assign_vector(v, base) else: - @for_range_opt_multithread(self.n_threads, N) + @for_range_multithread(self.n_threads, 100, N) def _(i): v = prod[i].get_vector() + self.b.get_vector() self.f_input[i].assign_vector(v) @@ -397,8 +677,24 @@ def _(i): def forward(self, batch=None): self.compute_f_input(batch=batch) - self.Y.assign_vector(self.f( - self.f_input.get_part_vector(0, len(batch)))) + @multithread(self.n_threads, len(batch), 128) + def _(base, size): + self.Y.assign_part_vector(self.f( + self.f_input.get_part_vector(base, size)), base) + if self.debug: + limit = self.debug + @for_range_opt(len(batch)) + def _(i): + @for_range_opt(self.d_out) + def _(j): + to_check = self.Y[i][0][j].reveal() + check = to_check > limit + @if_(check) + def _(): + print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check) + print_ln('X %s', self.X[i].reveal_nested()) + print_ln('W %s', + [self.W[k][j].reveal() for k in range(self.d_in)]) def backward(self, compute_nabla_X=True, batch=None): N = len(batch) @@ -419,26 +715,31 @@ def backward(self, compute_nabla_X=True, batch=None): f_prime_bit = MultiArray([N, d, d_out], sint) f_schur_Y = MultiArray([N, d, d_out], sfix) - self.compute_f_input() - f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector())) + @multithread(self.n_threads, f_prime_bit.total_size()) + def _(base, size): + f_prime_bit.assign_vector( + self.f_prime(self.f_input.get_vector(base, size)), base) progress('f prime') - @for_range_opt(N) - def _(i): - i = batch[i] - f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i]) + @multithread(self.n_threads, f_prime_bit.total_size()) + def _(base, size): + f_schur_Y.assign_vector(nabla_Y.get_vector(base, size) * + f_prime_bit.get_vector(base, size), + base) progress('f prime schur Y') if compute_nabla_X: - @for_range_opt(N) - def _(i): - i = batch[i] - if self.activation == 'id': - nabla_X[i] = nabla_Y[i].mul_trans(W) - else: - nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W) + @multithread(self.n_threads, N) + def _(base, size): + B = sfix.Matrix(N, d_out, address=f_schur_Y.address) + nabla_X.assign_part_vector( + B.direct_mul_trans(W, indices=(regint.inc(size, base), + regint.inc(self.d_out), + regint.inc(self.d_out), + regint.inc(self.d_in))), + base) progress('nabla X') @@ -1151,6 +1452,7 @@ def comp(left, right): class Optimizer: """ Base class for graphs of layers. """ n_threads = Layer.n_threads + always_shuffle = True @property def layers(self): @@ -1175,6 +1477,14 @@ def set_layers_with_inputs(self, layers): layer.last_used = list(filter(lambda x: x not in used, layer.inputs)) used.update(layer.inputs) + def batch_for(self, layer, batch): + if layer in (self.layers[0], self.layers[-1]): + return batch + else: + batch = regint.Array(len(batch)) + batch.assign(regint.inc(len(batch))) + return batch + def forward(self, N=None, batch=None, keep_intermediate=True, model_from=None): """ Compute graph. @@ -1193,7 +1503,7 @@ def forward(self, N=None, batch=None, keep_intermediate=True, if model_from is not None: layer.input_from(model_from) break_point() - layer.forward(batch=batch) + layer.forward(batch=self.batch_for(layer, batch)) break_point() if not keep_intermediate: for l in layer.last_used: @@ -1212,26 +1522,30 @@ def backward(self, batch): """ Compute backward propagation. """ for layer in reversed(self.layers): if len(layer.inputs) == 0: - layer.backward(compute_nabla_X=False, batch=batch) + layer.backward(compute_nabla_X=False, + batch=self.batch_for(layer, batch)) else: - layer.backward(batch=batch) + layer.backward(batch=self.batch_for(layer, batch)) if len(layer.inputs) == 1: + layer.inputs[0].nabla_Y.alloc() layer.inputs[0].nabla_Y.assign_vector( layer.nabla_X.get_part_vector(0, len(batch))) - def run(self, batch_size=None): + def run(self, batch_size=None, stop_on_loss=0): """ Run training. :param batch_size: batch size (defaults to example size of first layer) """ + if self.n_epochs == 0: + return if batch_size is not None: N = batch_size else: N = self.layers[0].N - i = MemValue(0) + i = self.i_epoch n_iterations = MemValue(0) - @do_while - def _(): + @for_range(self.n_epochs) + def _(_): if self.X_by_label is None: self.X_by_label = [[None] * self.layers[0].N] assert len(self.X_by_label) in (1, 2) @@ -1239,16 +1553,18 @@ def _(): n = N // len(self.X_by_label) n_per_epoch = int(math.ceil(1. * max(len(X) for X in self.X_by_label) / n)) - n_iterations.iadd(n_per_epoch) print('%d runs per epoch' % n_per_epoch) indices_by_label = [] for label, X in enumerate(self.X_by_label): indices = regint.Array(n * n_per_epoch) indices_by_label.append(indices) indices.assign(regint.inc(len(indices), 0, 1, 1, len(X))) - indices.shuffle() + if self.always_shuffle or n_per_epoch > 1: + indices.shuffle() + loss_sum = MemValue(sfix(0)) @for_range(n_per_epoch) def _(j): + n_iterations.iadd(1) batch = regint.Array(N) for label, X in enumerate(self.X_by_label): indices = indices_by_label[label] @@ -1257,20 +1573,84 @@ def _(j): label * n) self.forward(batch=batch) self.backward(batch=batch) - self.update(i) - loss = self.layers[-1].l + self.update(i, batch=batch) + loss_sum.iadd(self.layers[-1].l) + if self.print_loss_reduction: + before = self.layers[-1].average_loss(N) + self.forward(batch=batch) + after = self.layers[-1].average_loss(N) + print_ln('loss reduction in batch %s: %s (%s - %s)', j, + before - after, before, after) + elif self.print_losses: + print_ln('loss in batch %s: %s', j, self.layers[-1].average_loss(N)) + if stop_on_loss: + loss = self.layers[-1].average_loss(N) + res = (loss < stop_on_loss) * (loss >= 0) + self.stopped_on_loss.write(1 - res) + return res if self.report_loss and self.layers[-1].approx != 5: - print_ln('loss after epoch %s: %s', i, loss.reveal()) + print_ln('loss in epoch %s: %s', i, + (loss_sum.reveal() * cfix(1 / n_per_epoch))) else: print_ln('done with epoch %s', i) time() i.iadd(1) - res = (i < self.n_epochs) + res = True if self.tol > 0: res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() return res print_ln('finished after %s epochs and %s iterations', i, n_iterations) + def run_by_args(self, program, n_runs, batch_size, test_X, test_Y): + for arg in program.args: + m = re.match('rate(.*)', arg) + if m: + self.gamma = MemValue(cfix(float(m.group(1)))) + if 'nomom' in program.args: + self.momentum = 0 + model_input = 'model_input' in program.args + if model_input: + for layer in self.layers: + layer.input_from(0) + else: + self.reset() + @for_range(n_runs) + def _(i): + if not model_input: + start_timer(1) + self.run(batch_size, stop_on_loss=100) + stop_timer(1) + if 'no_acc' in program.args: + return + N = self.layers[0].X.sizes[0] + self.forward(N) + batch = regint.Array(N) + batch.assign_vector(regint.inc(N)) + self.layers[-1].backward(batch) + n_correct = self.layers[-1].reveal_correctness(N, debug=True) + print_ln('train_acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / N, + n_correct, N) + training_address = self.layers[0].X.address + self.layers[0].X.address = test_X.address + n_test = len(test_Y) + self.forward(n_test) + self.layers[0].X.address = training_address + n_correct = self.layers[-1].reveal_correctness(n_test, test_Y) + print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / n_test, + n_correct, n_test) + if model_input: + start_timer(1) + self.run(batch_size) + stop_timer(1) + else: + @if_(util.or_op(self.stopped_on_loss, n_correct < + int(n_test // self.layers[-1].n_outputs * 1.1))) + def _(): + self.gamma.imul(.5) + self.reset() + print_ln('reset after reducing learning rate to %s', + self.gamma) + class Adam(Optimizer): def __init__(self, layers, n_epochs): self.alpha = .001 @@ -1318,7 +1698,7 @@ class SGD(Optimizer): :param n_epochs: number of epochs for training :param report_loss: disclose and print loss """ - def __init__(self, layers, n_epochs, debug=False, report_loss=False): + def __init__(self, layers, n_epochs, debug=False, report_loss=None): self.momentum = 0.9 self.layers = layers self.n_epochs = n_epochs @@ -1330,11 +1710,19 @@ def __init__(self, layers, n_epochs, debug=False, report_loss=False): self.thetas.extend(layer.thetas()) for theta in layer.thetas(): self.delta_thetas.append(theta.same_shape()) - self.gamma = MemValue(sfix(0.01)) + self.gamma = MemValue(cfix(0.01)) self.debug = debug - self.report_loss = report_loss + if report_loss is None: + self.report_loss = layers[-1].compute_loss + else: + self.report_loss = report_loss self.tol = 0.000 self.X_by_label = None + self.print_update_average = False + self.print_losses = False + self.print_loss_reduction = False + self.i_epoch = MemValue(0) + self.stopped_on_loss = MemValue(0) def reset(self, X_by_label=None): """ Reset layer parameters. @@ -1353,40 +1741,64 @@ def _(i): y.assign_all(0) for layer in self.layers: layer.reset() + self.i_epoch.write(0) + self.stopped_on_loss.write(0) - def update(self, i_epoch): + def update(self, i_epoch, batch): for nabla, theta, delta_theta in zip(self.nablas, self.thetas, self.delta_thetas): - @multithread(self.n_threads, len(nabla)) + @multithread(self.n_threads, nabla.total_size()) def _(base, size): old = delta_theta.get_vector(base, size) red_old = self.momentum * old - new = self.gamma * nabla.get_vector(base, size) + rate = self.gamma.expand_to_vector(size) + nabla_vector = nabla.get_vector(base, size) + log_batch_size = math.log(len(batch), 2) + # divide by len(batch) by truncation + # increased rate if len(batch) is not a power of two + pre_trunc = nabla_vector.v * rate.v + k = nabla_vector.k + rate.k + m = rate.f + int(log_batch_size) + v = pre_trunc.round(k, m, signed=True, + nearest=sfix.round_nearest) + new = nabla_vector._new(v) diff = red_old - new delta_theta.assign_vector(diff, base) theta.assign_vector(theta.get_vector(base, size) + delta_theta.get_vector(base, size), base) - if self.debug: - for x, name in (old, 'old'), (red_old, 'red_old'), \ - (new, 'new'), (diff, 'diff'): - x = x.reveal() - print_ln_if((x > 1000) + (x < -1000), - name + ': %s %s %s %s', - *[y.v.reveal() for y in (old, red_old, \ - new, diff)]) + if self.print_update_average: + vec = abs(delta_theta.get_vector().reveal()) + print_ln('update average: %s (%s)', + sum(vec) * cfix(1 / len(vec)), len(vec)) if self.debug: + limit = int(self.debug) d = delta_theta.get_vector().reveal() - a = cfix.Array(len(d.v)) + aa = [cfix.Array(len(d.v)) for i in range(3)] + a = aa[0] a.assign(d) @for_range(len(a)) def _(i): x = a[i] - print_ln_if((x > 1000) + (x < -1000), - 'update len=%d' % len(nabla)) + print_ln_if((x > limit) + (x < -limit), + 'update epoch=%s %s index=%s %s', + i_epoch.read(), str(delta_theta), i, x) + a = aa[1] a.assign(nabla.get_vector().reveal()) @for_range(len(a)) def _(i): x = a[i] - print_ln_if((x > 1000) + (x < -1000), - 'nabla len=%d' % len(nabla)) + print_ln_if((x > len(batch) * limit) + (x < -len(batch) * limit), + 'nabla epoch=%s %s index=%s %s', + i_epoch.read(), str(nabla), i, x) + a = aa[2] + a.assign(theta.get_vector().reveal()) + @for_range(len(a)) + def _(i): + x = a[i] + print_ln_if((x > limit) + (x < -limit), + 'theta epoch=%s %s index=%s %s', + i_epoch.read(), str(theta), i, x) + index = regint.get_random(64) % len(a) + print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, + aa[1][index], aa[0][index], aa[2][index]) self.gamma.imul(1 - 10 ** - 6) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index a8014266b..918b0c591 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -13,6 +13,7 @@ from Compiler import comparison from Compiler import program from Compiler import instructions_base +from Compiler import library, util # polynomials as enumerated on Hart's book ## @@ -33,11 +34,8 @@ 0.00000000000000000040] ## # @private -p_1045 = [1.000000077443021686, 0.693147180426163827795756, - 0.224022651071017064605384, 0.055504068620466379157744, - 0.009618341225880462374977, 0.001332730359281437819329, - 0.000155107460590052573978, 0.000014197847399765606711, - 0.000001863347724137967076] +p_1045 = [math.log(2) ** i / math.factorial(i) for i in range(12)] + ## # @private p_2524 = [-2.05466671951, -8.8626599391, @@ -92,8 +90,8 @@ # # @return truncated sint value of x def trunc(x): - if type(x) is types.sfix: - return floatingpoint.Trunc(x.v, x.k, x.f, x.kappa, signed=True) + if isinstance(x, types._fix): + return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True) elif type(x) is types.sfloat: v, p, z, s = floatingpoint.FLRound(x, 0) #return types.sfloat(v, p, z, s, x.err) @@ -125,7 +123,7 @@ def load_sint(x, l_type): # @return the evaluation of the polynomial. return type depends on inputs. def p_eval(p_c, x): degree = len(p_c) - 1 - if type(x) is types.sfix: + if isinstance(x, types._fix): # ignore coefficients smaller than precision for c in reversed(p_c): if c < 2 ** -(x.f + 1): @@ -160,10 +158,10 @@ def sTrigSub(x): y = x - (f) * x.coerce(2 * pi) # reduction to \pi b1 = y > pi - w = b1 * ((2 * pi - y) - y) + y + w = b1.if_else(2 * pi - y, y) # reduction to \pi/2 b2 = w > pi_over_2 - w = b2 * ((pi - w) - w) + w + w = b2.if_else(pi - w, w) # returns scaled angle and boolean flags return w, b1, b2 @@ -182,9 +180,8 @@ def ssin(w, s): v = w * (1.0 / pi_over_2) v_2 = v ** 2 # adjust sign according to the movement in the reduction - b = s * (-2) + 1 # calculate the sin using polynomial evaluation - local_sin = b * v * p_eval(p_3307, v_2) + local_sin = s.if_else(-v, v) * p_eval(p_3307, v_2) return local_sin @@ -203,10 +200,10 @@ def scos(w, s): # calculates the v of the w. v = w v_2 = v ** 2 - # adjust sign according to the movement in the reduction - b = s * (-2) + 1 # calculate the cos using polynomial evaluation - local_cos = b * p_eval(p_3508, v_2) + tmp = p_eval(p_3508, v_2) + # adjust sign according to the movement in the reduction + local_cos = s.if_else(-tmp, tmp) return local_cos @@ -264,11 +261,12 @@ def tan(x): @types.vectorize @instructions_base.sfix_cisc -def exp2_fx(a): +def exp2_fx(a, zero_output=False): """ Power of two for fixed-point numbers. :param a: exponent for :math:`2^a` (sfix) + :param zero_output: whether to output zero for very small values. If not, the result will be undefined. :return: :math:`2^a` if it is within the range. Undefined otherwise """ @@ -279,54 +277,95 @@ def exp2_fx(a): n_int_bits = int(math.ceil(math.log(a.k - a.f, 2))) n_bits = a.f + n_int_bits n_shift = int(types.program.options.ring) - a.k - if types.program.use_edabit(): - l = sint.get_edabit(a.f, True) - u = sint.get_edabit(a.k - a.f, True) - r_bits = l[1] + u[1] - r = l[0] + (u[0] << a.f) - lower_r = l[0] + if types.program.use_split(): + assert not zero_output + from Compiler.GC.types import sbitvec + if types.program.use_split() == 3: + x = a.v.split_to_two_summands(a.k) + bits = types._bitint.carry_lookahead_adder(x[0], x[1], + fewer_inv=False) + # converting MSB first reduces the number of rounds + s = sint.conv(bits[-1]) + lower_overflow = sint.conv(x[0][a.f]) + \ + sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f]) + lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f) + elif types.program.use_split() == 4: + x = list(zip(*a.v.split_to_n_summands(a.k, 4))) + bi = types._bitint + red = bi.wallace_reduction + sums1, carries1 = red(*x[:3], get_carry=False) + sums2, carries2 = red(x[3], sums1, carries1, False) + bits = bi.carry_lookahead_adder(sums2, carries2, + fewer_inv=False) + overflows = bi.full_adder(carries1[a.f], carries2[a.f], + bits[a.f] ^ sums2[a.f] ^ carries2[a.f]) + overflows = reversed(list((sint.conv(x) + for x in reversed(overflows)))) + lower_overflow = sint.bit_compose(sint.conv(x) + for x in overflows) + s = sint.conv(bits[-1]) + lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f) + else: + bits = sbitvec(a.v, a.k) + s = sint.conv(bits[-1]) + lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) + higher_bits = bits[a.f:n_bits] else: - r_bits = [sint.get_random_bit() for i in range(a.k)] - r = sint.bit_compose(r_bits) - lower_r = sint.bit_compose(r_bits[:a.f]) - shifted = ((a.v - r) << n_shift).reveal() - masked_bits = (shifted >> n_shift).bit_decompose(a.k) - lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], - r_bits[a.f-1::-1]) - lower_masked = sint.bit_compose(masked_bits[:a.f]) - lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f)) + if types.program.use_edabit(): + l = sint.get_edabit(a.f, True) + u = sint.get_edabit(a.k - a.f, True) + r_bits = l[1] + u[1] + r = l[0] + (u[0] << a.f) + lower_r = l[0] + else: + r_bits = [sint.get_random_bit() for i in range(a.k)] + r = sint.bit_compose(r_bits) + lower_r = sint.bit_compose(r_bits[:a.f]) + shifted = ((a.v - r) << n_shift).reveal() + masked_bits = (shifted >> n_shift).bit_decompose(a.k) + lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], + r_bits[a.f-1::-1]) + lower_masked = sint.bit_compose(masked_bits[:a.f]) + lower = lower_r + lower_masked - \ + (sint.conv(lower_overflow) << (a.f)) + higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits], + masked_bits[a.f:n_bits], + carry_in=lower_overflow, + get_carry=True) + carry = comparison.CarryOutLE(masked_bits[n_bits:-1], + r_bits[n_bits:-1], + higher_bits[-1]) + if zero_output: + # should be for free + highest_bits = r_bits[0].ripple_carry_adder( + masked_bits[n_bits:-1], [0] * (a.k - n_bits), + carry_in=higher_bits[-1]) + bits_to_check = [x.bit_xor(y) + for x, y in zip(highest_bits[:-1], + r_bits[n_bits:-1])] + t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), + bits_to_check)) + # sign + s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1]) + del higher_bits[-1] c = types.sfix._new(lower, k=a.k, f=a.f) - higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits], - masked_bits[a.f:n_bits], - carry_in=lower_overflow, - get_carry=True) - assert(len(higher_bits) == n_bits - a.f + 1) + assert(len(higher_bits) == n_bits - a.f) pow2_bits = [sint.conv(x) for x in higher_bits] - d = floatingpoint.Pow2_from_bits(pow2_bits[:-1]) + d = floatingpoint.Pow2_from_bits(pow2_bits) e = p_eval(p_1045, c) g = d * e small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits, 2 ** n_int_bits, signed=False, nearest=types.sfix.round_nearest), k=a.k, f=a.f) - carry = comparison.CarryOutLE(masked_bits[n_bits:-1], - r_bits[n_bits:-1], - higher_bits[-1]) - # should be for free - highest_bits = r_bits[0].ripple_carry_adder( - masked_bits[n_bits:-1], [0] * (a.k - n_bits), - carry_in=higher_bits[-1]) - bits_to_check = [x.bit_xor(y) - for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])] - t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y), - bits_to_check)) - # sign - s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1]) - return s.if_else(t.if_else(small_result, 0), g) + if zero_output: + small_result = t.if_else(small_result, 0) + return s.if_else(small_result, g) else: + assert not zero_output # obtain absolute value of a s = a < 0 - a = (s * (-2) + 1) * a + a = s.if_else(-a, a) # isolates fractional part of number b = trunc(a) c = a - b @@ -335,7 +374,7 @@ def exp2_fx(a): # evaluates fractional part of a in p_1045 e = p_eval(p_1045, c) g = d * e - return (1 - s) * g + s / g + return s.if_else(1 / g, g) @types.vectorize @@ -353,19 +392,20 @@ def log2_fx(x): :return: (sfix) the value of :math:`\log_2(x)` """ - if type(x) is types.sfix: + if isinstance(x, types._fix): # transforms sfix to f*2^n, where f is [o.5,1] bounded # obtain number bounded by [0,5 and 1] by transforming input to sfloat v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa) p -= x.f vlen = x.f + v = x._new(v, k=x.k, f=x.f) else: d = types.sfloat(x) v, p, vlen = d.v, d.p, d.vlen + w = x.coerce(1.0 / (2 ** (vlen))) + v *= w # isolates mantisa of d, now the n can be also substituted by the # secret shared p from d in the expresion above. - w = x.coerce(1.0 / (2 ** (vlen))) - v = v * w # polynomials for the log_2 evaluation of f are calculated P = p_eval(p_2524, v) Q = p_eval(q_2524, v) @@ -384,7 +424,7 @@ def pow_fx(x, y): :param y: (sfix, clear types) secret shared exponent. - :return: :math:`x^y` (sfix) + :return: :math:`x^y` (sfix) if positive and in range """ log2_x =0 # obtains log2(x) @@ -456,9 +496,6 @@ def floor_fx(x): def MSB(b, k): # calculation of z # x in order 0 - k - if (k > types.program.bit_length): - raise OverflowError("The supported bit \ - lenght of the application is smaller than k") x_order = b.bit_decompose(k) x = [0] * k @@ -511,9 +548,7 @@ def norm_simplified_SQ(b, k): w_array[i] = z[2 * i - 1] + z[2 * i] # w aggregation - w = types.sint(0) - for i in range(k_over_2): - w += (2 ** i) * w_array[i] + w = b.bit_compose(w_array) # return computed values #return m_odd, m, w @@ -538,9 +573,9 @@ def sqrt_simplified_fx(x): # process to set up the precision and allocate correct 2**f if x.f % 2 == 1: m_odd = (1 - 2 * m_odd) + m_odd - w = (w * 2 - w) * (1-m_odd) + w + w = m_odd.if_else(w, 2 * w) # map number to use sfix format and instantiate the number - w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f) + w = x._new(w << ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f) # obtains correct 2 ** (m/2) w = (w * (2 ** (1/2.0)) - w) * m_odd + w # produce x/ 2^(m/2) @@ -739,15 +774,15 @@ def atan(x): """ # obtain absolute value of x s = x < 0 - x_abs = (s * (-2) + 1) * x + x_abs = s.if_else(-x, x) # angle isolation b = x_abs > 1 v = 1 / x_abs - v = (1 - b) * (x_abs - v) + v + v = b.if_else(v, x_abs) v_2 =v*v # range of polynomial coefficients - assert x.k - x.f >= 15 + assert x.k - x.f >= 19 P = p_eval(p_5102, v_2) Q = p_eval(q_5102, v_2) @@ -756,8 +791,8 @@ def atan(x): y_pi_over_two = pi_over_2 - y # sign correction - y = (1 - b) * (y - y_pi_over_two) + y_pi_over_two - y = (1 - s) * (y - (-y)) + (-y) + y = b.if_else(y_pi_over_two, y) + y = s.if_else(-y, y) return y diff --git a/Compiler/program.py b/Compiler/program.py index 326dc796e..38364bad4 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -98,6 +98,8 @@ def __init__(self, args, options): self.use_dabit = options.mixed self._edabit = options.edabit self._split = False + if options.split: + self.use_split(int(options.split)) self._square = False self._always_raw = False Program.prog = self @@ -243,10 +245,12 @@ def curr_block(self): """ The basic block that is currently being created. """ return self.curr_tape.active_basicblock - def malloc(self, size, mem_type, reg_type=None): + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): """ Allocate memory from the top """ if not isinstance(size, int): raise CompilerError('size must be known at compile time') + if (creator_tape or self.curr_tape) != self.tapes[0]: + raise CompilerError('cannot allocate memory outside main thread') if size == 0: return if isinstance(mem_type, type): @@ -330,7 +334,9 @@ def use_split(self, change=None): if change is None: return self._split else: - assert change in (2, 3) + if change and not self.options.ring: + raise CompilerError('splitting only supported for rings') + assert change > 1 self._split = change def use_square(self, change=None): @@ -350,8 +356,12 @@ def options_from_args(self): self.use_trunc_pr = True if 'split' in self.args or 'split3' in self.args: self.use_split(3) + if 'split4' in self.args: + self.use_split(4) if 'raw' in self.args: self.always_raw(True) + if 'edabit' in self.args: + self.use_edabit(True) class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ @@ -559,12 +569,14 @@ def optimize(self, options): numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) - if numrounds > 0 and self.program.verbose: - print('Program requires %d rounds of communication' % numrounds) if merger.counter and self.program.verbose: print('Block requires', \ ', '.join('%d %s' % (y, x.__name__) \ for x, y in list(merger.counter.items()))) + if merger.counter and self.program.verbose: + print('Block requires %s rounds' % \ + ', '.join('%d %s' % (y, x.__name__) \ + for x, y in list(merger.rounds.items()))) # free memory merger = None if options.dead_code_elimination: diff --git a/Compiler/types.py b/Compiler/types.py index 91b41ad6c..41bedba00 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -10,6 +10,8 @@ and thread-specific. The memory is allocated statically and shared between threads. This means that memory-based types such as :py:class:`Array` can be used to transfer information between threads. +Note that creating memory-based types outside the main thread is not +supported. If viewing this documentation in processed form, many function signatures appear generic because of the use of decorators. See the source code for the @@ -74,6 +76,7 @@ from .util import is_zero, is_one import operator from functools import reduce +import re class ClientMessageType: @@ -118,17 +121,6 @@ def join(self): program.join_tape(self.run_handles.pop(0)) -def copy_doc(a, b): - try: - a.__doc__ = b.__doc__ - except: - pass - -def no_doc(operation): - def wrapper(*args, **kwargs): - return operation(*args, **kwargs) - return wrapper - def copy_doc(a, b): try: a.__doc__ = b.__doc__ @@ -316,6 +308,14 @@ def _(i): res.iadd(res.value_type.conv(aa[i] * bb[i])) return res.read() + def __abs__(self): + """ Absolute value. """ + return (self < 0).if_else(-self, self) + + @staticmethod + def popcnt_bits(bits): + return sum(bits) + class _int(object): """ Integer functionality. """ @@ -534,11 +534,11 @@ def bit_compose(bits): return sum(b << i for i,b in enumerate(bits)) @classmethod - def malloc(cls, size): + def malloc(cls, size, creator_tape=None): """ Allocate memory (statically). :param size: compile-time (int) """ - return program.malloc(size, cls) + return program.malloc(size, cls, creator_tape=creator_tape) @classmethod def free(cls, addr): @@ -833,6 +833,20 @@ def __rmod__(self, other): :param other: cint/regint/int """ return self.coerce_op(other, modc, True) + def less_than(self, other, bit_length): + """ Clear comparison for particular bit length. + + :param other: cint/regint/int + :param bit_length: signed bit length of inputs + :return: 0/1 (regint), undefined if inputs outside range """ + if bit_length <= 64: + return self < other + else: + diff = self - other + shifted = diff >> (bit_length - 1) + res = regint(shifted & 1) + return res + def __lt__(self, other): """ Clear 64-bit comparison. @@ -1732,15 +1746,20 @@ def get_random_int(cls, bits): """ Secret random n-bit number according to security model. :param bits: compile-time integer (int) """ - if program.use_split() == 3: + if program.use_edabit(): + return sint.get_edabit(bits, True)[0] + elif program.use_split() > 2: tmp = sint() randoms(tmp, bits) x = tmp.split_to_two_summands(bits, True) - overflow = comparison.CarryOutLE(x[1][:-1], x[0][:-1]) + \ - sint.conv(x[0][-1]) + carry = comparison.CarryOutRawLE(x[1][:bits], x[0][:bits]) + if program.use_split() > 3: + from .GC.types import sbitint + x = sbitint.full_adder(carry, x[0][bits], x[1][bits]) + overflow = sint.conv(x[1]) * 2 + sint.conv(x[0]) + else: + overflow = sint.conv(carry) + sint.conv(x[0][bits]) return tmp - (overflow << bits) - elif program.use_edabit(): - return sint.get_edabit(bits, True)[0] res = sint() comparison.PRandInt(res, bits) return res @@ -1791,7 +1810,7 @@ def long_one(): def bit_decompose_clear(a, n_bits): return floatingpoint.bits(a, n_bits) - @classmethod + @vectorized_classmethod def get_raw_input_from(cls, player): res = cls() rawinput(player, res) @@ -1980,7 +1999,7 @@ def __lshift__(self, other, bit_length=None, security=None): @vectorize @read_mem_value - def __rshift__(self, other, bit_length=None, security=None): + def __rshift__(self, other, bit_length=None, security=None, signed=True): """ Secret right shift. :param other: secret or public integer (sint/cint/regint/int) """ @@ -1990,7 +2009,7 @@ def __rshift__(self, other, bit_length=None, security=None): if other == 0: return self res = sint() - comparison.Trunc(res, self, bit_length, other, security, True) + comparison.Trunc(res, self, bit_length, other, security, signed) return res elif isinstance(other, sint): return floatingpoint.Trunc(self, bit_length, other, security) @@ -2092,6 +2111,15 @@ def split_to_two_summands(self, length, get_carry=False): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) + @vectorize + def raw_right_shift(self, length): + res = sint() + shrsi(res, self, length) + return res + + def raw_mod2m(self, m): + return self - (self.raw_right_shift(m) << m) + @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -2304,6 +2332,14 @@ def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0, b.pop(0) else: break + carries = cls.get_carries(a, b, fewer_inv=fewer_inv, carry_in=carry_in) + res = lower + cls.sum_from_carries(a, b, carries) + if get_carry: + res += [carries[-1]] + return res + + @classmethod + def get_carries(cls, a, b, fewer_inv=False, carry_in=0): d = [cls.half_adder(ai, bi) for (ai,bi) in zip(a,b)] carry = floatingpoint.carry if fewer_inv: @@ -2314,10 +2350,7 @@ def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0, carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1] else: carries = [] - res = lower + cls.sum_from_carries(a, b, carries) - if get_carry: - res += [carries[-1]] - return res + return carries @staticmethod def sum_from_carries(a, b, carries): @@ -2469,6 +2502,18 @@ def wallace_tree_from_columns(cls, columns, get_carry=True): def wallace_tree(cls, rows): return cls.wallace_tree_from_columns([list(x) for x in zip(*rows)]) + @classmethod + def wallace_reduction(cls, a, b, c, get_carry=True): + assert len(a) == len(b) == len(c) + tmp = zip(*(cls.full_adder(*x) for x in zip(a, b, c))) + sums, carries = (list(x) for x in tmp) + carries = [0] + carries + if get_carry: + sums += [0] + else: + del carries[-1] + return sums, carries + def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') @@ -2497,7 +2542,7 @@ def __lshift__(self, other): def __rshift__(self, other): return self.compose(self.bit_decompose()[other:]) - def bit_decompose(self, n_bits=None, *args): + def bit_decompose(self, n_bits=None, security=None): if self.bits is None: self.bits = self.force_bit_decompose(self.n_bits) if n_bits is None: @@ -2541,14 +2586,16 @@ def __ge__(self, other): def __gt__(self, other): return 1 - (self <= other) - def __eq__(self, other): + def __eq__(self, other, bit_length=None, security=None): diff = self ^ other - diff_bits = [1 - x for x in diff.bit_decompose()] + diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]] return floatingpoint.KMul(diff_bits) def __ne__(self, other): return 1 - (self == other) + equal = __eq__ + def __neg__(self): return 1 + self.compose(1 ^ b for b in self.bit_decompose()) @@ -2752,9 +2799,9 @@ def parse_type(other, k=None, f=None): class cfix(_number, _structure): """ Clear fixed-point number represented as clear integer. """ - __slots__ = ['value', 'f', 'k', 'size'] + __slots__ = ['value', 'f', 'k'] reg_type = 'c' - scalars = (int, float, regint) + scalars = (int, float, regint, cint) @classmethod def set_precision(cls, f, k = None): """ Set the precision of the integer representation. Note that some @@ -2779,7 +2826,7 @@ def set_precision(cls, f, k = None): @vectorized_classmethod def load_mem(cls, address, mem_type=None): """ Load from memory by public address. """ - return cls(cint.load_mem(address)) + return cls._new(cint.load_mem(address)) @vectorized_classmethod def read_from_socket(cls, client_id, n=1): @@ -2787,7 +2834,7 @@ def read_from_socket(cls, client_id, n=1): Sender will have already bit shifted and sent as cints.""" cint_input = cint.read_from_socket(client_id, n) if n == 1: - return cfix(cint_inputs) + return cfix._new(cint_inputs) else: return list(map(cfix, cint_inputs)) @@ -2805,34 +2852,47 @@ def cfix_to_cint(fix_val): writesocketc(client_id, message_type, *cint_values) @staticmethod - def malloc(size): - return program.malloc(size, cint) + def malloc(size, creator_tape=None): + return program.malloc(size, cint, creator_tape=creator_tape) @staticmethod def n_elements(): return 1 + @classmethod + def from_int(cls, other): + res = cls() + res.load_int(other) + return res + + @classmethod + def _new(cls, other, k=None, f=None): + res = cls(k=k, f=f) + res.v = cint.conv(other) + return res + + @staticmethod + def int_rep(v, f): + v = v * (2 ** f) + try: + v = int(round(v)) + except TypeError: + pass + return v + @vectorize_init + @read_mem_value def __init__(self, v=None, k=None, f=None, size=None): """ :param v: cfix/float/int """ f = self.f if f is None else f k = self.k if k is None else k self.f = f self.k = k - self.size = get_global_vector_size() - if isinstance(v, cint): - self.v = cint(v,size=self.size) - elif isinstance(v, cfix.scalars): - v = v * (2 ** f) - try: - v = int(round(v)) - except TypeError: - pass - self.v = cint(v, size=self.size) + if isinstance(v, cfix.scalars): + v = self.int_rep(v, f) + self.v = cint(v, size=size) elif isinstance(v, cfix): self.v = v.v - elif isinstance(v, MemValue): - self.v = v elif v is None: self.v = cint(0) else: @@ -2840,7 +2900,10 @@ def __init__(self, v=None, k=None, f=None, size=None): def __iter__(self): for x in self.v: - yield type(self)(x, self.k, self.f) + yield self._new(x, self.k, self.f) + + def __len__(self): + return len(self.v) @vectorize def load_int(self, v): @@ -2863,6 +2926,10 @@ def store_in_mem(self, address): """ Store in memory by public address. """ self.v.store_in_mem(address) + @property + def size(self): + return self.v.size + def sizeof(self): return self.size * 4 @@ -2873,7 +2940,7 @@ def add(self, other): :param other: cfix/cint/regint/int """ other = parse_type(other) if isinstance(other, cfix): - return cfix(self.v + other.v) + return cfix._new(self.v + other.v) else: return NotImplemented @@ -2884,18 +2951,26 @@ def mul(self, other): :param other: cfix/cint/regint/int/sint """ if isinstance(other, sint): return sfix._new(self.v * other, k=self.k, f=self.f) + if isinstance(other, (int, regint, cint)): + return cfix._new(self.v * cint(other), k=self.k, f=self.f) other = parse_type(other) if isinstance(other, cfix): assert self.f == other.f - sgn = cint(1 - 2 * (self.v * other.v < 0)) + sgn = cint(1 - 2 * ((self < 0) ^ (other < 0))) absolute = self.v * other.v * sgn val = sgn * (absolute >> self.f) - return cfix(val) + return cfix._new(val) elif isinstance(other, sfix): return NotImplemented else: raise CompilerError('Invalid type %s for cfix.__mul__' % type(other)) - + + def positive_mul(self, other): + assert isinstance(other, float) + assert other >= 0 + v = self.v * int(round(other * 2 ** self.f)) + return self._new(v >> self.f, k=self.k, f=self.f) + @vectorize def __sub__(self, other): """ Clear fixed-point subtraction. @@ -2903,9 +2978,9 @@ def __sub__(self, other): :param other: cfix/cint/regint/int """ other = parse_type(other) if isinstance(other, cfix): - return cfix(self.v - other.v) + return cfix._new(self.v - other.v) elif isinstance(other, sfix): - return sfix(self.v - other.v) + return sfix._new(self.v - other.v) else: raise NotImplementedError @@ -2913,7 +2988,7 @@ def __sub__(self, other): def __neg__(self): """ Clear fixed-point negation. """ # cfix type always has .v - return cfix(-self.v) + return cfix._new(-self.v) def __rsub__(self, other): return -self + other @@ -2939,7 +3014,8 @@ def __lt__(self, other): """ Clear fixed-point comparison. """ other = parse_type(other) if isinstance(other, cfix): - return self.v < other.v + assert self.k == other.k + return self.v.less_than(other.v, self.k) elif isinstance(other, sfix): if(self.k != other.k or self.f != other.f): raise TypeError('Incompatible fixed point types in comparison') @@ -2952,7 +3028,7 @@ def __le__(self, other): """ Clear fixed-point comparison. """ other = parse_type(other) if isinstance(other, cfix): - return self.v <= other.v + return 1 - (self > other) elif isinstance(other, sfix): return other.v.greater_equal(self.v, self.k, other.kappa) else: @@ -2963,7 +3039,7 @@ def __gt__(self, other): """ Clear fixed-point comparison. """ other = parse_type(other) if isinstance(other, cfix): - return self.v > other.v + return other.__lt__(self) elif isinstance(other, sfix): return other.v.less_than(self.v, self.k, other.kappa) else: @@ -2974,7 +3050,7 @@ def __ge__(self, other): """ Clear fixed-point comparison. """ other = parse_type(other) if isinstance(other, cfix): - return self.v >= other.v + return 1 - (self < other) elif isinstance(other, sfix): return other.v.less_equal(self.v, self.k, other.kappa) else: @@ -3000,9 +3076,10 @@ def __truediv__(self, other): """ Clear fixed-point division. :param other: cfix/cint/regint/int """ - other = parse_type(other) + other = parse_type(other, self.k, self.f) if isinstance(other, cfix): - return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f)) + return cfix._new(library.cint_cint_division( + self.v, other.v, self.k, self.f), k=self.k, f=self.f) elif isinstance(other, sfix): assert self.k == other.k assert self.f == other.f @@ -3016,11 +3093,11 @@ def __truediv__(self, other): def print_plain(self): """ Clear fixed-point output. """ if self.k > 64: - raise CompilerError('Printing of fixed-point numbers not ' + - 'implemented for more than 64-bit precision') - tmp = regint() - convmodp(tmp, self.v, bitlength=self.k) - sign = cint(tmp < 0) + sign = (((self.v + (1 << (self.k - 1))) >> self.k) & 1) + else: + tmp = regint() + convmodp(tmp, self.v, bitlength=self.k) + sign = cint(tmp < 0) abs_v = sign.if_else(-self.v, self.v) print_float_plain(cint(abs_v), cint(-self.f), \ cint(0), cint(sign), cint(0)) @@ -3064,8 +3141,8 @@ def coerce(cls, other): return cls.conv(other) @classmethod - def malloc(cls, size): - return program.malloc(size, cls.int_type) + def malloc(cls, size, creator_tape=None): + return program.malloc(size, cls.int_type, creator_tape=creator_tape) @classmethod def free(cls, addr): @@ -3192,7 +3269,7 @@ def __ne__(self, other): class _fix(_single): """ Secret fixed point type. """ - __slots__ = ['v', 'f', 'k', 'size'] + __slots__ = ['v', 'f', 'k'] def set_precision(cls, f, k = None): cls.f = f @@ -3200,12 +3277,28 @@ def set_precision(cls, f, k = None): if k is None: cls.k = 2 * f else: - if k < f: - raise CompilerError('bit length cannot be less than precision') cls.k = k set_precision.__doc__ = cfix.set_precision.__doc__ set_precision = classmethod(set_precision) + @classmethod + def set_precision_from_args(cls, program): + f = None + k = None + for arg in program.args: + m = re.match('f([0-9]+)$', arg) + if m: + f = int(m.group(1)) + m = re.match('k([0-9]+)$', arg) + if m: + k = int(m.group(1)) + if f is not None: + print ('Setting fixed-point precision to %d/%s' % (f, k)) + cls.set_precision(f, k) + cfix.set_precision(f, k) + elif k is not None: + raise CompilerError('need to set fractional precision') + @classmethod def coerce(cls, other): if isinstance(other, (_fix, cls.clear_type)): @@ -3224,13 +3317,13 @@ def from_sint(cls, other, k=None, f=None): @classmethod def _new(cls, other, k=None, f=None): - res = cls(other, k=k, f=f) + res = cls(k=k, f=f) + res.v = cls.int_type.conv(other) return res @vectorize_init def __init__(self, _v=None, k=None, f=None, size=None): - """ :params _v: compile-time value (int/float) """ - self.size = get_global_vector_size() + """ :params _v: int/float/regint/cint/sint/sfloat """ if k is None: k = self.k else: @@ -3241,15 +3334,12 @@ def __init__(self, _v=None, k=None, f=None, size=None): self.f = f assert k is not None assert f is not None - # warning: don't initialize a sfix from a sint, this is only used in internal methods; - # for external initialization use load_int. if _v is None: self.v = self.int_type(0) elif isinstance(_v, self.int_type): - self.v = _v - self.size = _v.size + self.load_int(_v) elif isinstance(_v, cfix.scalars): - self.v = self.int_type(int(round(_v * (2 ** f))), size=self.size) + self.v = self.int_type(cfix.int_rep(_v, f=f), size=size) elif isinstance(_v, self.float_type): p = (f + _v.p) b = (p.greater_equal(0, _v.vlen)) @@ -3265,7 +3355,6 @@ def __init__(self, _v=None, k=None, f=None, size=None): if not isinstance(self.v, self.int_type): raise CompilerError('sfix conversion failure: %s/%s' % (_v, self.v)) - @vectorize def load_int(self, v): self.v = self.int_type(v) << self.f @@ -3365,7 +3454,7 @@ def reveal(self): class revealed_fix(self.clear_type): f = self.f k = self.k - return revealed_fix(val) + return revealed_fix._new(val) class sfix(_fix): """ Secret fixed-point number represented as secret integer. @@ -3410,6 +3499,24 @@ def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None): res = res.reduce_after_mul() return res + @classmethod + def dot_product(cls, x, y, res_params=None): + """ Secret dot product. + + :param x: iterable of appropriate secret type + :param y: iterable of appropriate secret type and same length """ + x, y = list(x), list(y) + if res_params is None: + if isinstance(x[0], cls.int_type): + x, y = y, x + if isinstance(y[0], cls.int_type): + return cls._new(cls.int_type.dot_product((xx.v for xx in x), y), + k=x[0].k, f=x[0].f) + return super().dot_product(x, y, res_params) + + def expand_to_vector(self, size): + return self._new(self.v.expand_to_vector(size), k=self.k, f=self.f) + def coerce(self, other): return parse_type(other, k=self.k, f=self.f) @@ -3426,7 +3533,7 @@ def unreduced(self, v, other=None, res_params=None, n_summands=1): @staticmethod def multipliable(v, k, f, size): - return cfix(cint.conv(v, size=size), k, f) + return cfix._new(cint.conv(v, size=size), k, f) def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -3436,8 +3543,8 @@ def reveal_to(self, player): :param player: public integer (int/regint/cint) :returns: value to be used with :py:func:`Compiler.library.print_ln_to` """ - return personal(player, cfix(self.v.reveal_to(player)._v, - self.k, self.f)) + return personal(player, cfix._new(self.v.reveal_to(player)._v, + self.k, self.f)) class unreduced_sfix(_single): int_type = sint @@ -3466,10 +3573,9 @@ def __add__(self, other): @vectorize def reduce_after_mul(self): - return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa, - nearest=sfix.round_nearest, - signed=True), - k=self.k // 2, f=self.m) + v = sfix.int_type.round(self.v, self.k, self.m, self.kappa, + nearest=sfix.round_nearest, signed=True) + return sfix._new(v, k=self.k // 2, f=self.m) sfix.unreduced_type = unreduced_sfix @@ -3696,8 +3802,9 @@ def n_elements(): return 4 @classmethod - def malloc(cls, size): - return program.malloc(size * cls.n_elements(), sint) + def malloc(cls, size, creator_tape=None): + return program.malloc(size * cls.n_elements(), sint, + creator_tape=creator_tape) @classmethod def is_address_tuple(cls, address): @@ -4142,12 +4249,14 @@ def __init__(self, length, value_type, address=None, debug=None, alloc=True): self.address = address self.address_cache = {} self.debug = debug + self.creator_tape = program.curr_tape if alloc: self.alloc() def alloc(self): if self.address is None: - self.address = self.value_type.malloc(self.length) + self.address = self.value_type.malloc(self.length, + self.creator_tape) def delete(self): self.value_type.free(self.address) @@ -4371,6 +4480,10 @@ def reveal_list(self): reveal_nested = reveal_list + def __str__(self): + return '%s array of length %s at %s' % (self.value_type, len(self), + self.address) + sint.dynamic_array = Array sgf2n.dynamic_array = Array @@ -4471,6 +4584,12 @@ def get_part_vector(self, base=0, size=None): return self.value_type.load_mem(self.address + base * part_size, size=size) + def assign_part_vector(self, vector, base=0): + assert self.value_type.n_elements() == 1 + part_size = reduce(operator.mul, self.sizes[1:]) + assert vector.size <= self.total_size() + vector.store_in_mem(self.address + base * part_size) + def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 assert len(indices) == len(self.sizes) @@ -4592,13 +4711,16 @@ class t(self.value_type): t = self.value_type res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: - if max(res_matrix.sizes) > 1000: - raise AttributeError() - A = self.get_vector() - B = other.get_vector() - res_matrix.assign_vector( - self.value_type.matrix_mul(A, B, self.sizes[1], - res_params)) + try: + res_matrix.assign_vector(self.direct_mul(other)) + except AttributeError: + if max(res_matrix.sizes) > 1000: + raise AttributeError() + A = self.get_vector() + B = other.get_vector() + res_matrix.assign_vector( + self.value_type.matrix_mul(A, B, self.sizes[1], + res_params)) except (AttributeError, AssertionError): # fallback for sfloat etc. @library.for_range_opt(self.sizes[0]) @@ -4644,7 +4766,60 @@ def direct_mul(self, other, reduce=True, indices=None): self.sizes[0], *other.sizes, reduce=reduce, indices=indices) + def direct_mul_trans(self, other, reduce=True, indices=None): + """ + Matrix multiplication with the transpose of :py:obj:`other` + in the virtual machine. + + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication) + :return: Matrix as vector of relevant type (row-major) + + """ + assert len(self.sizes) == 2 + assert len(other.sizes) == 2 + if indices is None: + assert self.sizes[1] == other.sizes[1] + indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] + assert len(indices[1]) == len(indices[2]) + indices = list(indices) + indices[3] *= other.sizes[0] + return self.value_type.direct_matrix_mul( + self.address, other.address, None, self.sizes[1], 1, + reduce=reduce, indices=indices) + + def direct_trans_mul(self, other, reduce=True, indices=None): + """ + Matrix multiplication with the transpose of :py:obj:`self` + in the virtual machine. + + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication) + :return: Matrix as vector of relevant type (row-major) + + """ + assert len(self.sizes) == 2 + assert len(other.sizes) == 2 + if indices is None: + assert self.sizes[0] == other.sizes[0] + indices = [regint.inc(i) for i in self.sizes[::-1] + other.sizes] + assert len(indices[1]) == len(indices[2]) + indices = list(indices) + indices[1] *= self.sizes[1] + return self.value_type.direct_matrix_mul( + self.address, other.address, None, 1, other.sizes[1], + reduce=reduce, indices=indices) + def direct_mul_to_matrix(self, other): + """ Matrix multiplication in the virtual machine. + + :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` + :returns: :py:obj:`Matrix` + + """ res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) res.assign_vector(self.direct_mul(other)) return res @@ -4756,6 +4931,10 @@ def f(sizes): return [f(sizes[1:]) for i in range(sizes[0])] return f(self.sizes) + def __str__(self): + return '%s multi-array of lengths %s at %s' % (self.value_type, + self.sizes, self.address) + class MultiArray(SubMultiArray): """ Multidimensional array. """ def __init__(self, sizes, value_type, debug=None, address=None, alloc=True): @@ -4990,25 +5169,15 @@ def __repr__(self): return 'MemValue(%s,%d)' % (self.value_type, self.address) -class MemFloat(_mem): +class MemFloat(MemValue): def __init__(self, *args): - value = sfloat(*args) - self.v = MemValue(value.v) - self.p = MemValue(value.p) - self.z = MemValue(value.z) - self.s = MemValue(value.s) + super().__init__(sfloat(*args)) def write(self, *args): value = sfloat(*args) - self.v.write(value.v) - self.p.write(value.p) - self.z.write(value.z) - self.s.write(value.s) - - def read(self): - return sfloat(self.v, self.p, self.z, self.s) + super().write(value) -class MemFix(_mem): +class MemFix(MemValue): def __init__(self, *args): arg_type = type(*args) if arg_type == sfix: @@ -5017,22 +5186,10 @@ def __init__(self, *args): value = cfix(*args) else: raise CompilerError('MemFix init argument error') - self.reg_type = value.v.reg_type - self.v = MemValue(value.v) + super().__init__(value) def write(self, *args): - value = sfix(*args) - self.v.write(value.v) - - def reveal(self): - return cfix(self.v.reveal()) - - def read(self): - val = self.v.read() - if isinstance(val, sint): - return sfix(val) - else: - return cfix(val) + super().write(self.value_type(*args)) def getNamedTupleType(*names): class NamedTuple(object): diff --git a/Compiler/util.py b/Compiler/util.py index d7109cd3b..10f2693a2 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -174,8 +174,17 @@ def is_all_ones(x, n): else: return False -def max(x, y): - return if_else(x > y, x, y) +def max(x, y=None): + if y is None: + return tree_reduce(max, x) + else: + return if_else(x > y, x, y) + +def min(x, y=None): + if y is None: + return tree_reduce(min, x) + else: + return if_else(x < y, x, y) def long_one(x): try: diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index c794c26b7..c27c23e70 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -5,6 +5,7 @@ #include "ECDSA/P256Element.h" #include "Tools/mkpath.h" +#include "GC/TinierSecret.h" #include "Protocols/fake-stuff.hpp" #include "Protocols/Share.hpp" diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 8fb60178a..9534780a6 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -90,8 +90,9 @@ class Offline_Check_Error: public runtime_error runtime_error("Offline-Check-Error : " + m) {} }; class mac_fail: public bad_value - { virtual const char* what() const throw() - { return "MacCheck Failure"; } + { + public: + mac_fail(string msg = "MacCheck Failure") : bad_value(msg) {} }; class consistency_check_fail: public exception { virtual const char* what() const throw() diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index b917d030d..9f8525971 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -64,7 +64,7 @@ void BitAdder::add(vector >& res, int n_bits = summands.size(); for (size_t i = begin; i < end; i++) - res[i].resize(n_bits + 1); + res.at(i).resize(n_bits + 1); size_t n_items = end - begin; diff --git a/GC/FakeSecret.cpp b/GC/FakeSecret.cpp index 258324e14..d5629b63c 100644 --- a/GC/FakeSecret.cpp +++ b/GC/FakeSecret.cpp @@ -8,11 +8,12 @@ #include "GC/square64.h" #include "GC/Processor.hpp" +#include "GC/ShareSecret.hpp" +#include "Processor/Input.hpp" namespace GC { -SwitchableOutput FakeSecret::out; const int FakeSecret::default_length; void FakeSecret::load_clear(int n, const Integer& x) @@ -87,6 +88,14 @@ FakeSecret FakeSecret::input(int from, word input, int n_bits) return input; } +void FakeSecret::inputbvec(Processor& processor, + ProcessorBase& input_processor, const vector& args) +{ + Input input; + input.reset_all(*ShareThread::s().P); + processor.inputbvec(input, input_processor, args, 0); +} + void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y, bool repeat) { @@ -96,4 +105,19 @@ void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y, *this = BitVec(x & y).mask(n); } +void FakeSecret::my_input(Input& inputter, BitVec value, int n_bits) +{ + inputter.add_mine(value, n_bits); +} + +void FakeSecret::other_input(Input&, int, int) +{ + throw runtime_error("emulation is supposed to be lonely"); +} + +void FakeSecret::finalize_input(Input& inputter, int from, int n_bits) +{ + *this = inputter.finalize(from, n_bits); +} + } /* namespace GC */ diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 19f3198e1..40dda61a0 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -24,6 +24,8 @@ #include #include +class ProcessorBase; + namespace GC { @@ -53,7 +55,10 @@ class FakeSecret : public ShareInterface, public BitVec typedef FakeProtocol Protocol; typedef FakeInput Input; + typedef SwitchableOutput out_type; + static string type_string() { return "fake secret"; } + static string type_short() { return "emulB"; } static string phase_name() { return "Faking"; } static const int default_length = 64; @@ -62,7 +67,8 @@ class FakeSecret : public ShareInterface, public BitVec static const bool actual_inputs = true; - static SwitchableOutput out; + static const true_type invertible; + static const true_type characteristic_two; static DataFieldType field_type() { return DATA_GF2; } @@ -87,8 +93,8 @@ class FakeSecret : public ShareInterface, public BitVec template static void inputb(T& processor, ArithmeticProcessor&, const vector& args) { processor.input(args); } - template - static void inputbvec(T&, U&, const vector&) { throw not_implemented(); } + static void inputbvec(Processor& processor, + ProcessorBase& input_processor, const vector& args); template static void reveal_inst(T& processor, const vector& args) { processor.reveal(args); } @@ -136,6 +142,14 @@ class FakeSecret : public ShareInterface, public BitVec void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; } void invert(FakeSecret) { throw not_implemented(); } + + void input(istream&, bool) { throw not_implemented(); } + + bool operator<(FakeSecret) const { return false; } + + void my_input(Input& inputter, BitVec value, int n_bits); + void other_input(Input& inputter, int from, int n_bits = 1); + void finalize_input(Input& inputter, int from, int n_bits); }; } /* namespace GC */ diff --git a/GC/Machine.hpp b/GC/Machine.hpp index 03560b164..c2934e8e6 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -47,13 +47,6 @@ template void Machine::load_schedule(string progname) { BaseMachine::load_schedule(progname); - for (auto i : {1, 0, 0}) - { - int n; - inpf >> n; - if (n != i) - throw runtime_error("old schedule format not supported"); - } print_compiler(); } diff --git a/GC/NoShare.h b/GC/NoShare.h index 7cf1ec8d9..78a092dcf 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -7,6 +7,7 @@ #define GC_NOSHARE_H_ #include "Processor/DummyProtocol.h" +#include "BMR/Register.h" #include "Tools/SwitchableOutput.h" class InputArgs; @@ -41,6 +42,11 @@ class NoValue : public ValueInterface return 0; } + static string type_string() + { + return "no"; + } + static void fail() { throw runtime_error("VM does not support binary circuits"); @@ -93,8 +99,6 @@ class NoShare static const bool expensive_triples = false; static const bool is_real = false; - static SwitchableOutput out; - static MC* new_mc(mac_key_type) { return new MC; @@ -130,7 +134,7 @@ class NoShare NoValue::fail(); } - static void inputb(Processor&, ArithmeticProcessor&, const vector&) { fail(); } + static void inputb(Processor&, const ArithmeticProcessor&, const vector&) { fail(); } static void reveal_inst(Processor&, const vector&) { fail(); } static void xors(Processor&, const vector&) { fail(); } static void ands(Processor&, const vector&) { fail(); } @@ -139,6 +143,10 @@ class NoShare static void input(Processor&, InputArgs&) { fail(); } static void trans(Processor&, Integer, const vector&) { fail(); } + static void xors(Processor&, vector) { fail(); } + static void ands(Processor&, vector) { fail(); } + static void andrs(Processor&, vector) { fail(); } + static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; } NoShare() {} @@ -161,8 +169,8 @@ class NoShare void operator^=(NoShare) { fail(); } NoShare operator+(const NoShare&) const { fail(); return {}; } - NoShare operator-(NoShare) const { fail(); return 0; } - NoShare operator*(NoValue) const { fail(); return 0; } + NoShare operator-(const NoShare&) const { fail(); return {}; } + NoShare operator*(const NoValue&) const { fail(); return {}; } NoShare operator+(int) const { fail(); return {}; } NoShare operator&(int) const { fail(); return {}; } @@ -172,6 +180,8 @@ class NoShare NoShare get_bit(int) const { fail(); return {}; } void invert(int, NoShare) { fail(); } + + void input(istream&, bool) { fail(); } }; } /* namespace GC */ diff --git a/GC/Processor.h b/GC/Processor.h index 3703cac5c..25fcca90b 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -44,6 +44,8 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching Timer xor_timer; + typename T::out_type out; + Processor(Machine& machine); Processor(Memories& memories, Machine* machine = 0); ~Processor(); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index a91c24e0c..9b8c190f5 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -301,15 +301,15 @@ void Processor::print_reg(int reg, int n, int size) bigint output; for (int i = 0; i < size; i++) output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i); - T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # "; + out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # "; print_str(n); - T::out << endl << flush; + out << endl << flush; } template void Processor::print_reg_plain(Clear& value) { - T::out << hex << showbase << value << dec << flush; + out << hex << showbase << value << dec << flush; } template @@ -323,7 +323,7 @@ void Processor::print_reg_signed(unsigned n_bits, Integer reg) n_shift = sizeof(value.get()) * 8 - n_bits; if (n_shift > 63) n_shift = 0; - T::out << dec << (value.get() << n_shift >> n_shift) << flush; + out << dec << (value.get() << n_shift >> n_shift) << flush; } else { @@ -334,26 +334,26 @@ void Processor::print_reg_signed(unsigned n_bits, Integer reg) } if (tmp >= bigint(1) << (n_bits - 1)) tmp -= bigint(1) << n_bits; - T::out << dec << tmp << flush; + out << dec << tmp << flush; } } template void Processor::print_chr(int n) { - T::out << (char)n << flush; + out << (char)n << flush; } template void Processor::print_str(int n) { - T::out << string((char*)&n,sizeof(n)) << flush; + out << string((char*)&n,sizeof(n)) << flush; } template void Processor::print_float(const vector& args) { - bigint::output_float(T::out, + bigint::output_float(out, bigint::get_float(C[args[0]], C[args[1]], C[args[2]], C[args[3]]), C[args[4]]); } @@ -361,7 +361,7 @@ void Processor::print_float(const vector& args) template void Processor::print_float_prec(int n) { - T::out << setprecision(n); + out << setprecision(n); } } /* namespace GC */ diff --git a/GC/Rep4Secret.cpp b/GC/Rep4Secret.cpp new file mode 100644 index 000000000..f12c3534c --- /dev/null +++ b/GC/Rep4Secret.cpp @@ -0,0 +1,26 @@ +/* + * Rep4Secret.cpp + * + */ + +#ifndef GC_REP4SECRET_CPP_ +#define GC_REP4SECRET_CPP_ + +#include "Rep4Secret.h" + +#include "ShareSecret.hpp" +#include "ShareThread.hpp" +#include "Protocols/Rep4MC.hpp" + +namespace GC +{ + +void Rep4Secret::load_clear(int n, const Integer& x) +{ + this->check_length(n, x); + *this = constant(x, ShareThread::s().P->my_num()); +} + +} + +#endif /* GC_REP4SECRET_CPP_ */ diff --git a/GC/Rep4Secret.h b/GC/Rep4Secret.h new file mode 100644 index 000000000..f17ae1e37 --- /dev/null +++ b/GC/Rep4Secret.h @@ -0,0 +1,53 @@ +/* + * Rep4Secret.h + * + */ + +#ifndef GC_REP4SECRET_H_ +#define GC_REP4SECRET_H_ + +#include "ShareSecret.h" +#include "Processor/NoLivePrep.h" +#include "Protocols/Rep4MC.h" +#include "Protocols/Rep4Share.h" + +namespace GC +{ + +class Rep4Secret : public RepSecretBase +{ + typedef RepSecretBase super; + typedef Rep4Secret This; + +public: + typedef DummyLivePrep LivePrep; + typedef Rep4 Protocol; + typedef Rep4MC MC; + typedef MC MAC_Check; + typedef Rep4Input Input; + + static const bool expensive_triples = false; + + static MC* new_mc(typename super::mac_key_type) { return new MC; } + + static This constant(const typename super::clear& constant, int my_num, + typename super::mac_key_type = {}) + { + return Rep4Share::constant(constant, my_num); + } + + Rep4Secret() + { + } + template + Rep4Secret(const T& other) : + super(other) + { + } + + void load_clear(int n, const Integer& x); +}; + +} + +#endif /* GC_REP4SECRET_H_ */ diff --git a/GC/Secret.h b/GC/Secret.h index f8a11b492..a45f40bf1 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -62,13 +62,13 @@ class Secret typedef typename T::Input Input; + typedef typename T::out_type out_type; + static string type_string() { return "evaluation secret"; } static string phase_name() { return T::name(); } static int default_length; - static typename T::out_type out; - static const bool needs_ot = false; static const bool is_real = true; @@ -170,9 +170,6 @@ class Secret template int Secret::default_length = 64; -template -typename T::out_type Secret::out = T::out; - template inline ostream& operator<<(ostream& o, Secret& secret) { diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 04a7d3885..41752eab1 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -58,10 +58,11 @@ SemiPrep::~SemiPrep() void SemiPrep::buffer_bits() { - auto& thread = Thread::s(); - word r = thread.secure_prng.get_word(); + word r = secure_prng.get_word(); for (size_t i = 0; i < sizeof(word) * 8; i++) + { this->bits.push_back((r >> i) & 1); + } } size_t SemiPrep::data_sent() diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 243444fb7..d45a12985 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -23,6 +23,8 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer& thread); SemiPrep(DataPositions& usage, bool = true); diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index a2fd84d0a..926a74f33 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -80,8 +80,6 @@ ShareParty::ShareParty(int argc, const char** argv, int default_batch_size) : this->machine.more_comm_less_comp = opt.get("-c")->isSet; - T::out.activate(my_num == 0 or online_opts.interactive); - if (not this->machine.use_encryption and not T::dishonest_majority) insecure("unencrypted communication"); diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index a7c32e851..3a76dd233 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -44,8 +44,6 @@ class ShareSecret static const bool is_real = true; static const bool actual_inputs = true; - static SwitchableOutput out; - static void store_clear_in_dynamic(Memory& mem, const vector& accesses); @@ -83,21 +81,26 @@ class ShareSecret void other_input(T& inputter, int from, int n_bits = 1); template void finalize_input(T& inputter, int from, int n_bits); + + U& operator=(const U&); }; -template -class ReplicatedSecret : public FixedVec, public ShareSecret +template +class RepSecretBase : public FixedVec, public ShareSecret { - typedef FixedVec super; + typedef FixedVec super; + typedef RepSecretBase This; public: + typedef U part_type; + typedef U small_type; + typedef U whole_type; + typedef BitVec clear; typedef BitVec open_type; typedef BitVec mac_type; typedef BitVec mac_key_type; - typedef ReplicatedBase Protocol; - typedef NoShare bit_type; static const int N_BITS = clear::N_BITS; @@ -109,7 +112,7 @@ class ReplicatedSecret : public FixedVec, public ShareSecret static string type_string() { return "replicated secret"; } static string phase_name() { return "Replicated computation"; } - static const int default_length = 8 * sizeof(typename ReplicatedSecret::value_type); + static const int default_length = N_BITS; static int threshold(int) { @@ -124,9 +127,45 @@ class ReplicatedSecret : public FixedVec, public ShareSecret { } - static void read_or_generate_mac_key(string, const Names&, mac_key_type) {} + static void read_or_generate_mac_key(string, const Player&, mac_key_type) + { + } - static ReplicatedSecret constant(const clear& value, int my_num, mac_key_type) + RepSecretBase() + { + } + template + RepSecretBase(const T& other) : + super(other) + { + } + + void bitcom(Memory& S, const vector& regs); + void bitdec(Memory& S, const vector& regs) const; + + void xor_(int n, const This& x, const This& y) + { *this = x ^ y; (void)n; } + + This operator&(const Clear& other) + { return super::operator&(BitVec(other)); } + + This lsb() + { return *this & 1; } + + This get_bit(int i) + { return (*this >> i) & 1; } +}; + +template +class ReplicatedSecret : public RepSecretBase +{ + typedef RepSecretBase super; + +public: + typedef ReplicatedBase Protocol; + + static ReplicatedSecret constant(const typename super::clear& value, int my_num, + typename super::mac_key_type) { ReplicatedSecret res; if (my_num < 2) @@ -140,27 +179,43 @@ class ReplicatedSecret : public FixedVec, public ShareSecret void load_clear(int n, const Integer& x); - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; - BitVec local_mul(const ReplicatedSecret& other) const; - void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y) - { *this = x ^ y; (void)n; } - void reveal(size_t n_bits, Clear& x); +}; - ReplicatedSecret operator&(const Clear& other) - { return super::operator&(BitVec(other)); } +class SemiHonestRepPrep; - ReplicatedSecret lsb() - { return *this & 1; } +class SmallRepSecret : public FixedVec, 2> +{ + typedef FixedVec, 2> super; + typedef SmallRepSecret This; - ReplicatedSecret get_bit(int i) - { return (*this >> i) & 1; } -}; +public: + typedef ReplicatedMC MC; + typedef BitVec_ open_type; + typedef open_type clear; + typedef BitVec mac_key_type; -class SemiHonestRepPrep; + static MC* new_mc(mac_key_type) + { + return new MC; + } + + SmallRepSecret() + { + } + template + SmallRepSecret(const T& other) : + super(other) + { + } + + This lsb() const + { + return *this & 1; + } +}; class SemiHonestRepSecret : public ReplicatedSecret { @@ -176,7 +231,7 @@ class SemiHonestRepSecret : public ReplicatedSecret typedef ReplicatedInput Input; typedef SemiHonestRepSecret part_type; - typedef SemiHonestRepSecret small_type; + typedef SmallRepSecret small_type; typedef SemiHonestRepSecret whole_type; static const bool expensive_triples = false; diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index f9a453618..8edb1f4c0 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -25,14 +25,11 @@ namespace GC { -template -const int ReplicatedSecret::N_BITS; - -template -const int ReplicatedSecret::default_length; +template +const int RepSecretBase::N_BITS; -template -SwitchableOutput ShareSecret::out; +template +const int RepSecretBase::default_length; template void ShareSecret::check_length(int n, const Integer& x) @@ -59,16 +56,16 @@ void ReplicatedSecret::load_clear(int n, const Integer& x) *this = x; } -template -void ReplicatedSecret::bitcom(Memory& S, const vector& regs) +template +void RepSecretBase::bitcom(Memory& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) *this ^= (S[regs[i]] << i); } -template -void ReplicatedSecret::bitdec(Memory& S, const vector& regs) const +template +void RepSecretBase::bitdec(Memory& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) S[regs[i]] = (*this >> i) & 1; @@ -285,12 +282,11 @@ void ShareSecret::xors(Processor& processor, const vector& args) ShareThread::s().xors(processor, args); } -template -void ReplicatedSecret::trans(Processor& processor, +template +void RepSecretBase::trans(Processor& processor, int n_outputs, const vector& args) { - assert(length == 2); - for (int k = 0; k < 2; k++) + for (int k = 0; k < L; k++) { for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++) for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++) @@ -330,6 +326,14 @@ void ShareSecret::random_bit() *this = res; } +template +U& GC::ShareSecret::operator=(const U& other) +{ + U& real_this = static_cast(*this); + real_this = other; + return real_this; +} + } #endif diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 5dba99102..637873cf5 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -54,7 +54,13 @@ void Thread::run() P = new CryptoPlayer(N, thread_num << 16); else P = new PlainPlayer(N, thread_num << 16); - processor.open_input_file(N.my_num(), thread_num); + processor.open_input_file(N.my_num(), thread_num, + master.opts.cmd_private_input_file); + processor.out.activate(N.my_num() == 0 or master.opts.interactive); + processor.setup_redirection(P->my_num(), thread_num, master.opts); + if (processor.stdout_redirect_file.is_open()) + processor.out.redirect_to_file(processor.stdout_redirect_file); + done.push(0); pre_run(); diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 6270ad34a..c7a804276 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -38,6 +38,7 @@ class VectorSecret : public Secret typedef typename part_type::sacri_type sacri_type; typedef typename part_type::mac_type mac_type; + typedef typename part_type::mac_share_type mac_share_type; typedef BitDiagonal Rectangle; typedef typename T::super check_type; @@ -152,6 +153,11 @@ class VectorSecret : public Secret reg.output(s, human); } + void input(istream&, bool) + { + throw not_implemented(); + } + template void my_input(U& inputter, BitVec value, int n_bits) { diff --git a/GC/instructions.h b/GC/instructions.h index 9143d01da..f07e863e6 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -129,7 +129,7 @@ X(GLDMC, ) \ X(LDMS, ) \ X(LDMC, ) \ - X(PRINTINT, S0.out << I0) \ + X(PRINTINT, PROC.out << I0) \ X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \ X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \ X(RUN_TAPE, MACH->run_tapes(EXTRA)) \ diff --git a/Machines/Player-Online.cpp b/Machines/Player-Online.cpp index e9599cfbc..e93724eb4 100644 --- a/Machines/Player-Online.cpp +++ b/Machines/Player-Online.cpp @@ -6,7 +6,7 @@ #include "Processor/config.h" #include "Protocols/Share.h" #include "GC/TinierSecret.h" -#include "Math/gfp.h" +#include "Math/gfp.hpp" #include "Player-Online.hpp" diff --git a/Machines/RepRing.hpp b/Machines/RepRing.hpp index cec263e58..f508b25a2 100644 --- a/Machines/RepRing.hpp +++ b/Machines/RepRing.hpp @@ -1,2 +1,3 @@ #include "Rep.hpp" #include "Protocols/Spdz2kPrep.hpp" +#include "Protocols/RepRingOnlyEdabitPrep.hpp" diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 24377bc95..32a9b6f18 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -16,16 +16,14 @@ #include "Protocols/ReplicatedPrep.hpp" #include "Protocols/FakeShare.hpp" -SwitchableOutput GC::NoShare::out; - int main(int argc, const char** argv) { - assert(argc > 1); OnlineOptions online_opts; - Names N(0, 9999, vector({"localhost"})); + Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector({"localhost"})); ez::ezOptionParser opt; RingOptions ring_opts(opt, argc, argv); opt.parse(argc, argv); + opt.syntax = string(argv[0]) + " "; string progname; if (opt.firstArgs.size() > 1) progname = *opt.firstArgs.at(1); @@ -41,7 +39,16 @@ int main(int argc, const char** argv) exit(1); } - switch (ring_opts.R) +#ifdef ROUND_NEAREST_IN_EMULATION + cerr << "Using nearest rounding instead of probabilistic truncation" << endl; +#else +#ifdef RISKY_TRUNCATION_IN_EMULATION + cerr << "Using risky truncation" << endl; +#endif +#endif + + int R = ring_opts.ring_size_from_opts_or_schedule(progname); + switch (R) { case 64: Machine>, FakeShare>(0, N, progname, @@ -53,7 +60,27 @@ int main(int argc, const char** argv) online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, online_opts.live_prep, online_opts).run(); break; + case 256: + Machine>, FakeShare>(0, N, progname, + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, + online_opts.live_prep, online_opts).run(); + break; + case 192: + Machine>, FakeShare>(0, N, progname, + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, + online_opts.live_prep, online_opts).run(); + break; + case 384: + Machine>, FakeShare>(0, N, progname, + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, + online_opts.live_prep, online_opts).run(); + break; + case 512: + Machine>, FakeShare>(0, N, progname, + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, + online_opts.live_prep, online_opts).run(); + break; default: - cerr << "Not compiled for " << ring_opts.R << "-bit rings" << endl; + cerr << "Not compiled for " << R << "-bit rings" << endl; } } diff --git a/Machines/mascot-party.cpp b/Machines/mascot-party.cpp index ad60f2faf..75ff93120 100644 --- a/Machines/mascot-party.cpp +++ b/Machines/mascot-party.cpp @@ -1,6 +1,6 @@ #include "Player-Online.hpp" -#include "Math/gfp.h" +#include "Math/gfp.hpp" #include "GC/TinierSecret.h" int main(int argc, const char** argv) diff --git a/Machines/ps-rep-ring-party.cpp b/Machines/ps-rep-ring-party.cpp index df54f9cb5..c5cf7ebf4 100644 --- a/Machines/ps-rep-ring-party.cpp +++ b/Machines/ps-rep-ring-party.cpp @@ -13,14 +13,27 @@ int main(int argc, const char** argv) { ez::ezOptionParser opt; - RingOptions opts(opt, argc, argv); + RingOptions opts(opt, argc, argv, true); switch (opts.R) { case 64: - ReplicatedMachine, PostSacriRepFieldShare>( - argc, argv, opt); - break; - case 72: + switch (opts.S) + { + case 40: + ReplicatedMachine, + PostSacriRepFieldShare>(argc, argv, opt); + break; + case 64: + ReplicatedMachine, + PostSacriRepFieldShare>(argc, argv, opt); + break; + default: + cerr << "Security parameter " << opts.S << " not implemented" + << endl; + exit(1); + } + break; + case 72: ReplicatedMachine, PostSacriRepFieldShare>( argc, argv, opt); break; diff --git a/Machines/rep4-ring-party.cpp b/Machines/rep4-ring-party.cpp new file mode 100644 index 000000000..aa0c5d9c6 --- /dev/null +++ b/Machines/rep4-ring-party.cpp @@ -0,0 +1,38 @@ +/* + * rep4-party.cpp + * + */ + +#include "Protocols/Rep4Share2k.h" +#include "Protocols/Rep4Share.h" +#include "Protocols/Rep4MC.h" +#include "Protocols/ReplicatedMachine.h" +#include "Math/Z2k.h" +#include "Math/gf2n.h" +#include "Tools/ezOptionParser.h" +#include "GC/Rep4Secret.h" +#include "Processor/RingOptions.h" + +#include "Protocols/RepRingOnlyEdabitPrep.hpp" +#include "Protocols/ReplicatedMachine.hpp" +#include "Protocols/Rep4Input.hpp" +#include "Protocols/Rep4Prep.hpp" +#include "Protocols/Rep4MC.hpp" +#include "Protocols/Rep4.hpp" +#include "GC/BitAdder.hpp" +#include "Math/Z2k.hpp" +#include "Rep.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + RingOptions ring_opts(opt, argc, argv); + switch (ring_opts.R) + { +#define X(R) case R: ReplicatedMachine, Rep4Share>(argc, argv, opt, 4); break; + X(64) X(80) X(88) + default: + cerr << ring_opts.R << "-bit computation not implemented" << endl; + exit(1); + } +} diff --git a/Machines/replicated-ring-party.cpp b/Machines/replicated-ring-party.cpp index e55d6657a..434d5338d 100644 --- a/Machines/replicated-ring-party.cpp +++ b/Machines/replicated-ring-party.cpp @@ -24,6 +24,10 @@ int main(int argc, const char** argv) ReplicatedMachine, Rep3Share>(argc, argv, "replicated-ring", opt); break; + case 128: + ReplicatedMachine, Rep3Share>(argc, argv, + "replicated-ring", opt); + break; default: throw runtime_error(to_string(opts.R) + "-bit computation not implemented"); } diff --git a/Machines/semi2k-party.cpp b/Machines/semi2k-party.cpp index f6dffa277..24a77ce8f 100644 --- a/Machines/semi2k-party.cpp +++ b/Machines/semi2k-party.cpp @@ -12,6 +12,7 @@ #include "Player-Online.hpp" #include "Semi.hpp" #include "GC/ShareSecret.hpp" +#include "Protocols/RepRingOnlyEdabitPrep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index 929fbe7f5..2a92b5103 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -11,6 +11,7 @@ #include "Networking/Server.h" #include "Player-Online.hpp" +#include "Math/Z2k.hpp" int main(int argc, const char** argv) { diff --git a/Machines/sy-rep-field-party.cpp b/Machines/sy-rep-field-party.cpp new file mode 100644 index 000000000..a84aa10da --- /dev/null +++ b/Machines/sy-rep-field-party.cpp @@ -0,0 +1,41 @@ +/* + * sy-rep-field-party.cpp + * + */ + +#include "Protocols/SpdzWiseShare.h" +#include "Protocols/MaliciousRep3Share.h" +#include "Protocols/ReplicatedMachine.h" +#include "Protocols/MAC_Check.h" +#include "Protocols/SpdzWiseMC.h" +#include "Protocols/SpdzWisePrep.h" +#include "Protocols/SpdzWiseInput.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Tools/ezOptionParser.h" +#include "Processor/NoLivePrep.h" +#include "GC/MaliciousCcdSecret.h" + +#include "Protocols/ReplicatedMachine.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Share.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Protocols/SpdzWise.hpp" +#include "Protocols/SpdzWisePrep.hpp" +#include "Protocols/SpdzWiseInput.hpp" +#include "Protocols/SpdzWiseShare.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/RepPrep.hpp" +#include "GC/ThreadMaster.hpp" +#include "Math/gfp.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opts; + ReplicatedMachine>, + SpdzWiseShare>>(argc, argv, opts); +} diff --git a/Machines/sy-rep-ring-party.cpp b/Machines/sy-rep-ring-party.cpp new file mode 100644 index 000000000..444df7558 --- /dev/null +++ b/Machines/sy-rep-ring-party.cpp @@ -0,0 +1,68 @@ +/* + * sy-rep-ring-party.cpp + * + */ + +#include "Protocols/ReplicatedMachine.h" +#include "Protocols/SpdzWiseRingShare.h" +#include "Protocols/MaliciousRep3Share.h" +#include "Protocols/SpdzWiseMC.h" +#include "Protocols/SpdzWiseRingPrep.h" +#include "Protocols/SpdzWiseInput.h" +#include "Protocols/MalRepRingPrep.h" +#include "Processor/RingOptions.h" +#include "GC/MaliciousCcdSecret.h" + +#include "Protocols/ReplicatedMachine.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Share.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Protocols/SpdzWise.hpp" +#include "Protocols/SpdzWiseRing.hpp" +#include "Protocols/SpdzWisePrep.hpp" +#include "Protocols/SpdzWiseInput.hpp" +#include "Protocols/SpdzWiseShare.hpp" +#include "Protocols/PostSacrifice.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/MaliciousRepPrep.hpp" +#include "Protocols/RepRingOnlyEdabitPrep.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/RepPrep.hpp" +#include "GC/ThreadMaster.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + RingOptions opts(opt, argc, argv, true); + switch (opts.R) + { + case 64: + switch (opts.S) + { + case 40: + ReplicatedMachine, + SpdzWiseShare>>(argc, argv, opt); + break; + case 64: + ReplicatedMachine, + SpdzWiseShare>>(argc, argv, opt); + break; + default: + cerr << "Security parameter " << opts.S << " not implemented" + << endl; + exit(1); + } + break; + case 72: + ReplicatedMachine, + SpdzWiseShare>>(argc, argv, opt); + break; + default: + throw runtime_error( + to_string(opts.R) + "-bit computation not implemented"); + } +} diff --git a/Machines/sy-shamir-party.cpp b/Machines/sy-shamir-party.cpp new file mode 100644 index 000000000..fb765779d --- /dev/null +++ b/Machines/sy-shamir-party.cpp @@ -0,0 +1,33 @@ +/* + * sy-shamir-party.cpp + * + */ + +#include "ShamirMachine.h" +#include "Protocols/ReplicatedMachine.h" +#include "Protocols/SpdzWiseShare.h" +#include "Protocols/MaliciousShamirShare.h" +#include "Protocols/SpdzWiseMC.h" +#include "Protocols/SpdzWiseInput.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "GC/CcdSecret.h" +#include "GC/MaliciousCcdSecret.h" + +#include "Protocols/Share.hpp" +#include "Protocols/SpdzWise.hpp" +#include "Protocols/SpdzWisePrep.hpp" +#include "Protocols/SpdzWiseInput.hpp" +#include "Protocols/SpdzWiseShare.hpp" +#include "Machines/ShamirMachine.hpp" + +int main(int argc, const char** argv) +{ + auto& opts = ShamirOptions::singleton; + ez::ezOptionParser opt; + opts = {opt, argc, argv}; + ReplicatedMachine>, + SpdzWiseShare>>( + argc, argv, + { }, opt, opts.nparties); +} diff --git a/Makefile b/Makefile index ef3d7d48a..fcf18561c 100644 --- a/Makefile +++ b/Makefile @@ -42,7 +42,7 @@ all: arithmetic binary gen_input online offline externalIO bmr doc doc: cd doc; $(MAKE) html -arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot +arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr ifeq ($(USE_NTL),1) @@ -77,7 +77,7 @@ overdrive: simple-offline.x pairwise-offline.x cnc-offline.x rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x -rep-ring: replicated-ring-party.x brain-party.x malicious-rep-ring-party.x ps-rep-ring-party.x Fake-Offline.x +rep-ring: replicated-ring-party.x brain-party.x malicious-rep-ring-party.x ps-rep-ring-party.x rep4-ring-party.x rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x @@ -98,10 +98,12 @@ endif shamir: shamir-party.x malicious-shamir-party.x galois-degree.x +sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x + ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) -$(LIBRELEASE): $(patsubst %.cpp,%.o,$(wildcard Protocols/*.cpp)) $(PROCESSOR) $(COMMON) $(BMR) $(FHEOFFLINE) $(GC) +$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMON) $(BMR) $(GC) $(AR) -csr $@ $^ static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) @@ -184,12 +186,18 @@ hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) +static/hemi-party.x: $(FHEOFFLINE) +static/soho-party.x: $(FHEOFFLINE) +static/cowgear-party.x: $(FHEOFFLINE) +static/chaigear-party.x: $(FHEOFFLINE) mascot-party.x: Machines/SPDZ.o $(OT) static/mascot-party.x: Machines/SPDZ.o Player-Online.x: Machines/SPDZ.o $(OT) mama-party.x: $(OT) ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o +sy-rep-ring-party.x: Protocols/MalRepRingOptions.o +rep4-ring-party.x: GC/Rep4Secret.o semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) diff --git a/Math/BitVec.h b/Math/BitVec.h index 7e0b779a5..fd867a092 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -25,6 +25,7 @@ class BitVec_ : public IntBase static const int n_bits = sizeof(T) * 8; static char type_char() { return 'B'; } + static string type_short() { return "B"; } static DataFieldType field_type() { return DATA_GF2; } static bool allows(Dtype dtype) { return dtype == DATA_TRIPLE or dtype == DATA_BIT; } @@ -59,7 +60,7 @@ class BitVec_ : public IntBase void pack(octetStream& os) const { os.store_int(this->a); } void unpack(octetStream& os) { this->a = os.get_int(); } - void pack(octetStream& os, int n) const { os.store_int(this->a, DIV_CEIL(n, 8)); } + void pack(octetStream& os, int n) const { os.store_int(mask(n).a, DIV_CEIL(n, 8)); } void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); } static BitVec_ unpack_new(octetStream& os, int n = n_bits) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 3be4d8545..63521f224 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -156,11 +156,6 @@ class FixedVec return true; } - bool operator!=(const FixedVec& other) const - { - return not equal(other); - } - bool is_zero() { return equal(0); @@ -170,6 +165,11 @@ class FixedVec return equal(1); } + bool operator!=(const FixedVec& other) const + { + return not equal(other); + } + FixedVecoperator+(const FixedVec& other) const { FixedVec res; @@ -291,6 +291,15 @@ class FixedVec return res; } + T lazy_sum() const + { + assert(L > 1); + T res = v[0].lazy_add(v[1]); + for (int i = 2; i < L; i++) + res = res.lazy_add(v[i]); + return res; + } + FixedVec extend_bit() const { FixedVec res; @@ -343,13 +352,21 @@ class FixedVec void output(ostream& s, bool human) const { - for (auto& x : v) - x.output(s, human); + if (human) + s << *this; + else + for (auto& x : v) + x.output(s, human); } void input(istream& s, bool human) { - for (auto& x : v) - x.input(s, human); + for (int i = 0; i < L; i++) + { + if (human and i != 0) + if (s.get() != ',') + throw runtime_error("cannot read vector"); + (*this)[i].input(s, human); + } } void pack(octetStream& os) const diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index 32cbb5b78..eb3024175 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -17,6 +17,8 @@ class ValueInterface public: static const int MAX_EDABITS = 0; + static const false_type characteristic_two; + template static void init(bool mont = true) { (void) mont; } static void init_default(int l) { (void) l; } diff --git a/Math/Z2k.h b/Math/Z2k.h index 80cfa81be..ba8cef5e7 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -62,13 +62,14 @@ class Z2 : public ValueInterface static int t() { return 0; } static char type_char() { return 'R'; } + static string type_short() { return "R"; } static string type_string() { return "Z2^" + to_string(int(N_BITS)); } static DataFieldType field_type() { return DATA_INT; } - static const bool invertible = false; + static const false_type invertible; - template + template static Z2 Mul(const Z2& x, const Z2& y); static void reqbl(int n); @@ -151,6 +152,9 @@ class Z2 : public ValueInterface void add(octetStream& os) { add(os.consume(size())); } + Z2 lazy_add(const Z2& x) const; + Z2 lazy_mul(const Z2& x) const; + Z2& invert(); void invert(const Z2& a) { *this = a; invert(); } @@ -279,10 +283,17 @@ class SignedZ2 : public Z2 template inline Z2 Z2::operator+(const Z2& other) const +{ + auto res = lazy_add(other); + res.normalize(); + return res; +} + +template +Z2 Z2::lazy_add(const Z2& other) const { Z2 res; mpn_add_fixed_n(res.a, a, other.a); - res.a[N_WORDS - 1] &= UPPER_MASK; return res; } @@ -332,12 +343,13 @@ Z2& Z2::operator>>=(int other) } template -template +template inline Z2 Z2::Mul(const Z2& x, const Z2& y) { Z2 res; mpn_mul_fixed_::N_WORDS, Z2::N_WORDS>(res.a, x.a, y.a); - res.a[N_WORDS - 1] &= UPPER_MASK; + if (not LAZY) + res.normalize(); return res; } @@ -348,6 +360,12 @@ inline Z2<(K > L) ? K : L> Z2::operator*(const Z2& other) const return Z2<(K > L) ? K : L>::Mul(*this, other); } +template +inline Z2 Z2::lazy_mul(const Z2& other) const +{ + return Z2::Mul(*this, other); +} + template Z2 Z2::operator<<(int i) const { diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 791c4af70..93864d828 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -13,6 +13,8 @@ template const int Z2::N_BITS; template const int Z2::N_BYTES; +template +const false_type Z2::invertible; template void Z2::reqbl(int n) diff --git a/Math/gf2n.h b/Math/gf2n.h index 751627b5c..0d1ca7cff 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -83,6 +83,7 @@ class gf2n_short : public ValueInterface static DataFieldType field_type() { return DATA_GF2N; } static char type_char() { return '2'; } + static string type_short() { return "2"; } static string type_string() { return "gf2n"; } static int size() { return sizeof(a); } @@ -94,8 +95,8 @@ class gf2n_short : public ValueInterface static bool allows(Dtype type) { (void) type; return true; } - static const bool invertible = true; - static const bool characteristic_two = true; + static const true_type invertible; + static const true_type characteristic_two; static gf2n_short cut(int128 x) { return x.get_lower(); } @@ -163,6 +164,9 @@ class gf2n_short : public ValueInterface // x * y when one of x,y is a bit void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; } + gf2n_short lazy_add(const gf2n_short& x) const { return *this + x; } + gf2n_short lazy_mul(const gf2n_short& x) const { return *this * x; } + gf2n_short operator+(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; } gf2n_short operator*(const gf2n_short& x) const { gf2n_short res; res.mul(*this, x); return res; } gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; } diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index c541ce2f8..92c37b33c 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -134,6 +134,7 @@ class gf2n_long : public ValueInterface static DataFieldType field_type() { return DATA_GF2N; } static char type_char() { return '2'; } + static string type_short() { return "2"; } static string type_string() { return "gf2n_long"; } static int size() { return sizeof(a); } @@ -144,8 +145,8 @@ class gf2n_long : public ValueInterface static bool allows(Dtype type) { (void) type; return true; } - static const bool invertible = true; - static const bool characteristic_two = true; + static const true_type invertible; + static const true_type characteristic_two; static gf2n_long cut(int128 x) { return x; } @@ -216,6 +217,9 @@ class gf2n_long : public ValueInterface // x * y when one of x,y is a bit void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; } + gf2n_long lazy_add(const gf2n_long& x) const { return *this + x; } + gf2n_long lazy_mul(const gf2n_long& x) const { return *this * x; } + gf2n_long operator+(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; } gf2n_long operator*(const gf2n_long& x) const { gf2n_long res; res.mul(*this, x); return res; } gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; } @@ -251,6 +255,8 @@ class gf2n_long : public ValueInterface gf2n_long& operator>>=(int i) { SHR(*this, i); return *this; } gf2n_long& operator<<=(int i) { SHL(*this, i); return *this; } + bool operator<(gf2n_long) const { return false; } + /* Crap RNG */ void randomize(PRNG& G, int n = -1); // compatibility with gfp diff --git a/Math/gfp.h b/Math/gfp.h index 48e620c18..9443fcc56 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -88,6 +88,7 @@ class gfp_ : public ValueInterface static DataFieldType field_type() { return DATA_INT; } static char type_char() { return 'p'; } + static string type_short() { return "p"; } static string type_string() { return "gfp"; } static int size() { return t() * sizeof(mp_limb_t); } @@ -100,8 +101,7 @@ class gfp_ : public ValueInterface static void specification(octetStream& os); - static const bool invertible = true; - static const bool characteristic_two = false; + static const true_type invertible; static gfp_ Mul(gfp_ a, gfp_ b) { return a * b; } @@ -184,6 +184,9 @@ class gfp_ : public ValueInterface void mul(const gfp_& x) { a.template mul(a,x.a,ZpD); } + gfp_ lazy_add(const gfp_& x) const { return *this + x; } + gfp_ lazy_mul(const gfp_& x) const { return *this * x; } + gfp_ operator+(const gfp_& x) const { gfp_ res; res.add(*this, x); return res; } gfp_ operator-(const gfp_& x) const { gfp_ res; res.sub(*this, x); return res; } gfp_ operator*(const gfp_& x) const { gfp_ res; res.mul(*this, x); return res; } @@ -266,7 +269,7 @@ class gfp_ : public ValueInterface // Convert representation to and from a bigint number friend void to_bigint(bigint& ans,const gfp_& x,bool reduce=true) - { to_bigint(ans,x.a,x.ZpD,reduce); } + { x.a.template to_bigint(ans, x.ZpD, reduce); } friend void to_gfp(gfp_& ans,const bigint& x) { to_modp(ans.a,x,ans.ZpD); } }; diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 246df321b..fb5843564 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -9,6 +9,9 @@ #include "Math/bigint.hpp" #include "Math/Setup.hpp" +template +const true_type gfp_::invertible; + template inline void gfp_::read_or_generate_setup(string dir, const OnlineOptions& opts) diff --git a/Math/modp.h b/Math/modp.h index 185e3b848..5ef419bed 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -74,6 +74,8 @@ class modp_ // Convert representation to and from a modp number void to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce=true) const; + template + void to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce=true) const; template void mul(const modp_& x, const modp_& y, const Zp_Data& ZpD); diff --git a/Math/modp.hpp b/Math/modp.hpp index 47fcbe125..0ac6f91de 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -113,6 +113,31 @@ void modp_::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const } +template +template +void modp_::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const +{ + assert(M == ZpD.t); + auto& x = *this; + mpz_ptr a = ans.get_mpz_t(); + if (a->_mp_alloc < M) + mpz_realloc(a, M); + if (ZpD.montgomery) + { + mp_limb_t one[M]; + inline_mpn_zero(one,M); + one[0]=1; + ZpD.Mont_Mult_(a->_mp_d,x.x,one); + } + else + { inline_mpn_copyi(a->_mp_d,x.x,M); } + a->_mp_size=M; + if (reduce) + while (a->_mp_size>=1 && (a->_mp_d)[a->_mp_size-1]==0) + { a->_mp_size--; } +} + + template void to_modp(modp_& ans,int x,const Zp_Data& ZpD) { diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index ad4695d0e..0b6dc6124 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -20,7 +20,9 @@ void ssl_error(string side, string pronoun, string other, string server) << " failed. Make sure " << pronoun << " have the necessary certificate (" << PREP_DIR << server << ".pem in the default configuration)," - << " and run `c_rehash ` on its location." << endl; + << " and run `c_rehash ` on its location." << endl + << "Also make sure that it's still valid. Certificates generated " + << "with `Scripts/setup-ssl.sh` expire after a month." << endl; } CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) : diff --git a/Networking/Player.h b/Networking/Player.h index cecf4e135..69d9f59dd 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -90,7 +90,15 @@ struct CommStats size_t data, rounds; Timer timer; CommStats() : data(0), rounds(0) {} - Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; } + Timer& add(const octetStream& os) + { +#ifdef VERBOSE_COMM + cout << "add " << os.get_length() << endl; +#endif + data += os.get_length(); + rounds++; + return timer; + } void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } CommStats& operator+=(const CommStats& other); CommStats& operator-=(const CommStats& other); diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 25b8beec7..73fb3a2b5 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -7,6 +7,7 @@ #include #include +#include using namespace std; BaseMachine* BaseMachine::singleton = 0; @@ -28,13 +29,14 @@ BaseMachine::BaseMachine() : nthreads(0) singleton = this; } -void BaseMachine::load_schedule(string progname) +void BaseMachine::load_schedule(string progname, bool load_bytecode) { this->progname = progname; string fname = "Programs/Schedules/" + progname + ".sch"; #ifdef DEBUG_FILES cerr << "Opening file " << fname << endl; #endif + ifstream inpf; inpf.open(fname); if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); } @@ -54,25 +56,35 @@ void BaseMachine::load_schedule(string progname) string threadname; for (int i=0; i> threadname; - string filename = "Programs/Bytecode/" + threadname + ".bc"; + if (load_bytecode) + { + string filename = "Programs/Bytecode/" + threadname + ".bc"; #ifdef DEBUG_FILES - cerr << "Loading program " << i << " from " << filename << endl; + cerr << "Loading program " << i << " from " << filename << endl; #endif - load_program(threadname, filename); + load_program(threadname, filename); + } } + + for (auto i : {1, 0, 0}) + { + int n; + inpf >> n; + if (n != i) + throw runtime_error("old schedule format not supported"); + } + + inpf.get(); + getline(inpf, compiler); + inpf.close(); } void BaseMachine::print_compiler() { - - char compiler[1000]; - inpf.get(); - inpf.getline(compiler, 1000); #ifdef VERBOSE - if (compiler[0] != 0) + if (compiler.size() != 0) cerr << "Compiler: " << compiler << endl; #endif - inpf.close(); } void BaseMachine::load_program(string threadname, string filename) @@ -112,3 +124,20 @@ string BaseMachine::memory_filename(string type_short, int my_number) { return PREP_DIR "Memory-" + type_short + "-P" + to_string(my_number); } + +int BaseMachine::ring_size_from_schedule(string progname) +{ + assert(not singleton); + BaseMachine machine; + singleton = 0; + machine.load_schedule(progname, false); + smatch m; + regex e("R ([0-9]+)"); + regex_search(machine.compiler, m, e); + if (m.size() > 1) + { + return stoi(m[1]); + } + else + return 0; +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 9ed54d45e..5b268a781 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -22,7 +22,7 @@ class BaseMachine std::map timer; - ifstream inpf; + string compiler; void print_timers(); @@ -43,10 +43,12 @@ class BaseMachine static string memory_filename(string type_short, int my_number); + static int ring_size_from_schedule(string progname); + BaseMachine(); virtual ~BaseMachine() {} - void load_schedule(string progname); + void load_schedule(string progname, bool load_bytecode = true); void print_compiler(); void time(); diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index 041c9297f..3e72cd1ba 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -14,13 +14,13 @@ const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 }; void DataPositions::set_num_players(int num_players) { - files.resize(N_DATA_FIELD_TYPE, vector(N_DTYPE)); - inputs.resize(num_players, vector(N_DATA_FIELD_TYPE)); + files = {}; + inputs.resize(num_players, {}); } void DataPositions::increase(const DataPositions& delta) { - inputs.resize(max(inputs.size(), delta.inputs.size()), vector(N_DATA_FIELD_TYPE)); + inputs.resize(max(inputs.size(), delta.inputs.size()), {}); for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) { for (unsigned int dtype = 0; dtype < N_DTYPE; dtype++) @@ -39,8 +39,7 @@ void DataPositions::increase(const DataPositions& delta) DataPositions& DataPositions::operator-=(const DataPositions& delta) { - inputs.resize(max(inputs.size(), delta.inputs.size()), - vector(N_DATA_FIELD_TYPE)); + inputs.resize(max(inputs.size(), delta.inputs.size()), {}); for (unsigned int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) { @@ -144,3 +143,25 @@ void DataPositions::process_line(long long items_used, const char* name, cerr << suffix << endl; } } + +bool DataPositions::empty() const +{ + for (auto& x : files) + for (auto& y : x) + if (y) + return false; + + for (auto& x : inputs) + for (auto& y : x) + if (y) + return false; + + for (auto& x : extended) + if (not x.empty()) + return false; + + if (not edabits.empty()) + return false; + + return true; +} diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index b71704921..791b2ca9d 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -10,11 +10,14 @@ #include "Processor/InputTuple.h" #include "Tools/Lock.h" #include "Networking/Player.h" +#include "Protocols/edabit.h" #include #include using namespace std; +template class dabit; + class DataTag { int t[4]; @@ -50,9 +53,9 @@ class DataPositions static const char* field_names[N_DATA_FIELD_TYPE]; static const int tuple_size[N_DTYPE]; - vector< vector > files; - vector< vector > inputs; - map extended[N_DATA_FIELD_TYPE]; + array, N_DATA_FIELD_TYPE> files; + vector< array > inputs; + array, N_DATA_FIELD_TYPE> extended; map, long long> edabits; DataPositions(int num_players = 0) { set_num_players(num_players); } @@ -63,6 +66,7 @@ class DataPositions DataPositions& operator-=(const DataPositions& delta); DataPositions operator-(const DataPositions& delta) const; void print_cost() const; + bool empty() const; }; template class Processor; @@ -73,9 +77,12 @@ template class SubProcessor; template class Preprocessing { +protected: DataPositions& usage; -protected: + map, vector>> edabits; + map, edabitvec> my_edabits; + void count(Dtype dtype) { usage.files[T::field_type()][dtype]++; } void count(DataTag tag, int n = 1) { usage.extended[T::field_type()][tag] += n; } void count_input(int player) { usage.inputs[player][T::field_type()]++; } @@ -117,10 +124,12 @@ class Preprocessing virtual array get_triple(int n_bits); virtual T get_bit(); - virtual void get_dabit(T&, typename T::bit_type&) { throw runtime_error("no daBit"); } - virtual void get_edabits(bool, size_t, T*, vector&, - const vector&) - { throw runtime_error("no edaBit"); } + virtual void get_dabit(T&, typename T::bit_type&); + virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); } + virtual void get_edabits(bool strict, size_t size, T* a, + vector& Sb, const vector& regs); + virtual void get_edabit_no_count(bool, int n_bits, edabit& eb); + virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); } virtual void push_triples(const vector>&) { throw runtime_error("no pushing"); } @@ -145,10 +154,17 @@ class Sub_Data_Files : public Preprocessing vector> input_buffers; BufferOwner, RefInputTuple> my_input_buffers; map > extended; + BufferOwner, dabit> dabit_buffer; + map edabit_buffers; int my_num,num_players; const string prep_data_dir; + int thread_num; + + Sub_Data_Files* part; + + void buffer_edabits_with_queues(bool stric, int n_bits); public: static string get_suffix(int thread_num); @@ -202,6 +218,9 @@ class Sub_Data_Files : public Preprocessing void setup_extended(const DataTag& tag, int tuple_size = 0); void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size); + void get_dabit_no_count(T& a, typename T::bit_type& b); + + Preprocessing& get_part(); }; template diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index ea456fdad..ae28e8c4c 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -3,6 +3,8 @@ #include "Processor/Data_Files.h" #include "Processor/Processor.h" +#include "Protocols/dabit.h" +#include "Math/Setup.h" template Lock Sub_Data_Files::tuple_lengths_lock; @@ -54,7 +56,8 @@ template Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir, DataPositions& usage, int thread_num) : Preprocessing(usage), - my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir) + my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir), + thread_num(thread_num), part(0) { #ifdef DEBUG_FILES cerr << "Setting up Data_Files in: " << prep_data_dir << endl; @@ -72,6 +75,11 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, } } + sprintf(filename, (prep_data_dir + "%s-%s-P%d%s").c_str(), + DataPositions::dtype_names[DATA_DABIT], (T::type_short()).c_str(), my_num, + suffix.c_str()); + dabit_buffer.setup(filename, 1, DataPositions::dtype_names[DATA_DABIT]); + input_buffers.resize(num_players); for (int i=0; i::~Sub_Data_Files() for (auto it = extended.begin(); it != extended.end(); it++) it->second.close(); + dabit_buffer.close(); + for (auto& x: edabit_buffers) + { + x.second->close(); + delete x.second; + } + if (part != 0) + delete part; } template @@ -236,4 +252,48 @@ void Sub_Data_Files::get_no_count(vector& S, DataTag tag, const vector +void Sub_Data_Files::get_dabit_no_count(T& a, typename T::bit_type& b) +{ + dabit tmp; + dabit_buffer.input(tmp); + a = tmp.first; + b = tmp.second; +} + +template +void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits) +{ +#ifndef INSECURE + throw runtime_error("no secure implementation of reading edaBits from files"); +#endif + if (edabit_buffers.find(n_bits) == edabit_buffers.end()) + { + string filename = prep_data_dir + "edaBits-" + to_string(n_bits) + "-P" + + to_string(my_num); + ifstream* f = new ifstream(filename); + if (f->fail()) + throw runtime_error("cannot open " + filename); + edabit_buffers[n_bits] = f; + } + auto& buffer = *edabit_buffers[n_bits]; + if (buffer.peek() == EOF) + buffer.seekg(0); + edabitvec eb; + eb.input(n_bits, buffer); + this->edabits[{strict, n_bits}].push_back(eb); + if (buffer.fail()) + throw runtime_error("error reading edaBits"); +} + +template +Preprocessing& Sub_Data_Files::get_part() +{ + if (part == 0) + part = new Sub_Data_Files(my_num, num_players, + get_prep_sub_dir(num_players), this->usage, + thread_num); + return *part; +} + #endif diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 125ee2b60..c6ac9b544 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -45,6 +45,9 @@ class DummyMC : public MAC_Check_Base { throw not_implemented(); } + void CheckFor(const typename T::open_type&, const vector&, const Player&) + { + } DummyMC& get_part_MC() { @@ -56,6 +59,11 @@ class DummyMC : public MAC_Check_Base throw not_implemented(); return {}; } + + int number() + { + return 0; + } }; template @@ -63,12 +71,17 @@ class DummyProtocol : public ProtocolBase { public: Player& P; + int counter; static int get_n_relevant_players() { throw not_implemented(); } + static void multiply(vector, vector>, int, int, SubProcessor) + { + } + DummyProtocol(Player& P) : P(P) { @@ -91,6 +104,9 @@ class DummyProtocol : public ProtocolBase throw not_implemented(); return {}; } + void check() + { + } }; template @@ -170,6 +186,10 @@ class NotImplementedInput { (void) proc, (void) MC; } + template + NotImplementedInput(const T&, const U&, const W&) + { + } NotImplementedInput(Player& P) { (void) P; @@ -200,6 +220,12 @@ class NotImplementedInput (void) proc, (void) regs; throw not_implemented(); } + static void input_mixed(SubProcessor, vector, int, int) + { + } + static void raw_input(SubProcessor, vector, int) + { + } void reset_all(Player& P) { (void) P; @@ -248,7 +274,7 @@ class NotImplementedOutput (void) player, (void) target, (void) source; throw not_implemented(); } - void stop(int player, int source) + void stop(int player, int source, int) { (void) player, (void) source; } diff --git a/Processor/Input.h b/Processor/Input.h index ba10becbb..0558e7969 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -20,9 +20,9 @@ class InputBase { typedef typename T::clear clear; +protected: Player* P; -protected: Buffer buffer; Timer timer; @@ -42,6 +42,7 @@ class InputBase static void finalize(SubProcessor& Proc, int player, const int* params, int size); InputBase(ArithmeticProcessor* proc = 0); + InputBase(SubProcessor* proc); virtual ~InputBase(); virtual void reset(int player) = 0; @@ -56,7 +57,7 @@ class InputBase virtual T finalize_mine() = 0; virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; - T finalize(int player, int n_bits = -1); + virtual T finalize(int player, int n_bits = -1); void raw_input(SubProcessor& proc, const vector& args, int size); }; diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 34ac2f644..be5269eec 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -23,6 +23,12 @@ InputBase::InputBase(ArithmeticProcessor* proc) : buffer.setup(&proc->private_input, -1, proc->private_input_filename); } +template +InputBase::InputBase(SubProcessor* proc) : + InputBase(proc ? proc->Proc : 0) +{ +} + template Input::Input(SubProcessor& proc) : Input(proc, proc.MC) @@ -92,7 +98,7 @@ void Input::add_mine(const open_type& input, int n_bits) prep.get_input(share, rr, player); t = input - rr; t.pack(this->os[player]); - share += T::constant(t, 0, MC.get_alphai()); + share += T::constant(t, player, MC.get_alphai()); this->values_input++; } @@ -190,7 +196,7 @@ void Input::finalize_other(int player, T& target, (void) n_bits; target = shares[player].next(); t.unpack(o); - target += T::constant(t, 1, MC.get_alphai()); + target += T::constant(t, P.my_num(), MC.get_alphai()); } template diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 64e7b8f2e..bd8fdb97b 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -330,6 +330,7 @@ struct TempVars { class BaseInstruction { friend class Program; + template friend class RepRingOnlyEdabitPrep; protected: int opcode; // The code diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 080ad5b8b..05aba1420 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -1217,7 +1217,7 @@ inline void Instruction::execute(Processor& Proc) const Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); break; case RANDOMS: - Procp.protocol.randoms_inst(Procp, *this); + Procp.protocol.randoms_inst(Procp.get_S(), *this); return; case INPUTMASKREG: Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2])); @@ -1314,14 +1314,10 @@ inline void Instruction::execute(Processor& Proc) const break; case SHLC: to_bigint(Proc.temp.aa,Proc.read_Cp(r[2])); - if (Proc.temp.aa > 63) - throw runtime_error("too much left shift"); Proc.get_Cp_ref(r[0]).SHL(Proc.read_Cp(r[1]),Proc.temp.aa); break; case SHRC: to_bigint(Proc.temp.aa,Proc.read_Cp(r[2])); - if (Proc.temp.aa > 63) - throw runtime_error("too much right shift"); Proc.get_Cp_ref(r[0]).SHR(Proc.read_Cp(r[1]),Proc.temp.aa); break; case SHLCI: @@ -1337,8 +1333,8 @@ inline void Instruction::execute(Processor& Proc) const Proc.get_C2_ref(r[0]).SHR(Proc.read_C2(r[1]),n); break; case SHRSI: - Proc.get_Sp_ref(r[0]) = Proc.read_Sp(r[1]) >> n; - break; + sint::shrsi(Procp, *this); + return; case GBITDEC: for (int j = 0; j < size; j++) { diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 5fcd2920e..e4fc0f0e0 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -39,35 +39,40 @@ Machine::Machine(int my_number, Names& playerNames, sint::clear::read_or_generate_setup(prep_dir_prefix(), opts); sint::bit_type::mac_key_type::init_field(); + // Initialize gf2n_short for CCD + sint::bit_type::part_type::open_type::init_field(); + // make directory for outputs if necessary mkdir_p(PREP_DIR); + Player* P; + if (use_encryption) + P = new CryptoPlayer(N, 0xF00); + else + P = new PlainPlayer(N, 0xF00); + if (opts.live_prep) { - auto P = new PlainPlayer(N, 0xF00); sint::LivePrep::basic_setup(*P); - delete P; } - sint::read_or_generate_mac_key(prep_dir_prefix(), N, alphapi); - sgf2n::read_or_generate_mac_key(prep_dir_prefix(), N, alpha2i); + sint::read_or_generate_mac_key(prep_dir_prefix(), *P, alphapi); + sgf2n::read_or_generate_mac_key(prep_dir_prefix(), *P, alpha2i); sint::bit_type::part_type::read_or_generate_mac_key( - prep_dir_prefix(), N, alphabi); + prep_dir_prefix(), *P, alphabi); #ifdef DEBUG_MAC cerr << "MAC Key p = " << alphapi << endl; cerr << "MAC Key 2 = " << alpha2i << endl; #endif - // deactivate output if necessary - sint::bit_type::out.activate(my_number == 0 or opts.interactive); - // for OT-based preprocessing sint::clear::next::template init(false); // Initialize the global memory if (memtype.compare("old")==0) { + ifstream inpf; inpf.open(memory_filename(), ios::in | ios::binary); if (inpf.fail()) { throw file_error(memory_filename()); } inpf >> M2 >> Mp >> Mi; @@ -90,16 +95,12 @@ Machine::Machine(int my_number, Names& playerNames, if (live_prep and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot)) { - Player* P; - if (use_encryption) - P = new CryptoPlayer(playerNames, 0xF000); - else - P = new PlainPlayer(playerNames, 0xF000); for (int i = 0; i < nthreads; i++) ot_setups.push_back({ *P, true }); - delete P; } + delete P; + /* Set up the threads */ tinfo.resize(nthreads); threads.resize(nthreads); @@ -190,7 +191,7 @@ void Machine::fill_buffers(int thread_number, int tape_number, } catch (bad_cast& e) { -#ifdef VERBOSE +#ifdef VERBOSE_CENTRAL cerr << "Problem with central preprocessing" << endl; #endif } @@ -210,7 +211,7 @@ void Machine::fill_buffers(int thread_number, int tape_number, } catch (bad_cast& e) { -#ifdef VERBOSE +#ifdef VERBOSE_CENTRAL cerr << "Problem with central bit triple preprocessing: " << e.what() << endl; #endif } @@ -231,12 +232,14 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, //printf("Running line %d\n",exec); if (progs[tape_number].usage_unknown()) { +#ifndef INSECURE if (not opts.live_prep) { cerr << "Internally called tape " << tape_number << " has unknown offline data usage" << endl; throw invalid_program(); } +#endif return DataPositions(N.num_players()); } else @@ -263,9 +266,6 @@ void Machine::run() proc_timer.start(); timer[0].start(); - // legacy - int _; - inpf >> _ >> _ >> _; // run main tape pos.increase(run_tape(0, 0, 0)); join_tape(0); diff --git a/Processor/NoLivePrep.h b/Processor/NoLivePrep.h index 2d2703c84..fe1d5d118 100644 --- a/Processor/NoLivePrep.h +++ b/Processor/NoLivePrep.h @@ -16,6 +16,13 @@ template class NoLivePrep : public Sub_Data_Files { public: + static void basic_setup(Player&) + { + } + static void teardown() + { + } + NoLivePrep(SubProcessor* proc, DataPositions& usage) : Sub_Data_Files(0, 0, "", usage, 0) { (void) proc; diff --git a/Processor/NoProtocol.h b/Processor/NoProtocol.h new file mode 100644 index 000000000..036be1d70 --- /dev/null +++ b/Processor/NoProtocol.h @@ -0,0 +1,42 @@ +/* + * NoProtocol.h + * + */ + +#ifndef PROCESSOR_NOPROTOCOL_H_ +#define PROCESSOR_NOPROTOCOL_H_ + +#include "Protocols/Replicated.h" + +template +class NoProtocol : public ProtocolBase +{ +public: + NoProtocol(Player&) + { + + } + + void init_mul(SubProcessor*) + { + throw not_implemented(); + } + typename T::clear prepare_mul(const T&, const T&, int n = -1) + { + (void) n; + throw not_implemented(); + } + void exchange() + { + throw not_implemented(); + } + T finalize_mul(int n = -1) + { + (void) n; + throw not_implemented(); + } +}; + + + +#endif /* PROCESSOR_NOPROTOCOL_H_ */ diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index bf4e8074b..5a2cad196 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -240,19 +240,6 @@ void thread_info::Sub_Main_Func() job.pos.increase(Proc.DataF.get_usage()); } - //double elapsed = timeval_diff(&startv, &endv); - //printf("Thread time = %f seconds\n",elapsed/1000000); - //printf("\texec = %d\n",exec); exec++; - //printf("\tMC2.number = %d\n",MC2.number()); - //printf("\tMCp.number = %d\n",MCp.number()); - - // MACCheck - MC2->Check(P); - MCp->Check(P); - //printf("\tMAC checked\n"); - P.Check_Broadcast(); - //printf("\tBroadcast checked\n"); - #ifdef DEBUG_THREADS printf("\tSignalling I have finished\n"); #endif @@ -269,6 +256,7 @@ void thread_info::Sub_Main_Func() // MACCheck MC2->Check(P); MCp->Check(P); + Proc.share_thread.MC->Check(P); //cout << num << " : Checking broadcast" << endl; P.Check_Broadcast(); diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 563619f9a..8e833e3df 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -17,6 +17,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) live_prep = true; batch_size = 10000; memtype = "empty"; + bits_from_squares = false; direct = false; bucket_size = 3; cmd_private_input_file = "Player-Data/Input"; @@ -138,6 +139,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-m", // Flag token. "--memory" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Compute random bits from squares", // Help description. + "-Q", // Flag token. + "--bits-from-squares" // Flag token. + ); opt.add( "", // Default. 0, // Required? @@ -174,6 +184,7 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, live_prep = opt.get("-L")->isSet; opt.get("-b")->getInt(batch_size); opt.get("--memory")->getString(memtype); + bits_from_squares = opt.isSet("-Q"); opt.get("-IF")->getString(cmd_private_input_file); opt.get("-OF")->getString(cmd_private_output_file); diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 8fe301952..40925031a 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -22,6 +22,7 @@ class OnlineOptions std::string progname; int batch_size; std::string memtype; + bool bits_from_squares; bool direct; int bucket_size; std::string cmd_private_input_file; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 57527dfad..2c1a63540 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -52,6 +52,11 @@ SubProcessor::~SubProcessor() if (bit_prep.data_sent()) cerr << "Sent for global bit preprocessing threads: " << bit_prep.data_sent() * 1e-6 << " MB" << endl; + if (not bit_usage.empty()) + { + cerr << "Mixed-circuit preprocessing cost:" << endl; + bit_usage.print_cost(); + } #endif } @@ -82,13 +87,16 @@ Processor::Processor(int thread_num,Player& P, secure_prng.ReSeed(); shared_prng.SeedGlobally(P); - out.activate(P.my_num() == 0 or machine.opts.interactive); + // only output on party 0 if not interactive + bool output = P.my_num() == 0 or machine.opts.interactive; + out.activate(output); + Procb.out.activate(output); + setup_redirection(P.my_num(), thread_num, opts); - if (!machine.opts.cmd_private_output_file.empty()) + if (stdout_redirect_file.is_open()) { - const string stdout_filename = get_parameterized_filename(P.my_num(), thread_num, opts.cmd_private_output_file); - stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out); out.redirect_to_file(stdout_redirect_file); + Procb.out.redirect_to_file(stdout_redirect_file); } } diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.cpp new file mode 100644 index 000000000..3bde79b5f --- /dev/null +++ b/Processor/ProcessorBase.cpp @@ -0,0 +1,17 @@ +/* + * ProcessorBase.cpp + * + */ + +#include "ProcessorBase.hpp" + +void ProcessorBase::setup_redirection(int my_num, int thread_num, + OnlineOptions& opts) +{ + if (not opts.cmd_private_output_file.empty()) + { + const string stdout_filename = get_parameterized_filename(my_num, + thread_num, opts.cmd_private_output_file); + stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out); + } +} diff --git a/Processor/ProcessorBase.h b/Processor/ProcessorBase.h index f00ac3357..79cb7ea8f 100644 --- a/Processor/ProcessorBase.h +++ b/Processor/ProcessorBase.h @@ -12,6 +12,7 @@ using namespace std; #include "Tools/ExecutionStats.h" +#include "OnlineOptions.h" class ProcessorBase { @@ -30,6 +31,8 @@ class ProcessorBase public: ExecutionStats stats; + ofstream stdout_redirect_file; + void pushi(long x) { stacki.push(x); } void popi(long& x) { x = stacki.top(); stacki.pop(); } @@ -50,6 +53,8 @@ class ProcessorBase T get_input(bool interactive, const int* params); template T get_input(istream& is, const string& input_filename, const int* params); + + void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts); }; #endif /* PROCESSOR_PROCESSORBASE_H_ */ diff --git a/Processor/RingOptions.cpp b/Processor/RingOptions.cpp index 1da103095..d59a709a8 100644 --- a/Processor/RingOptions.cpp +++ b/Processor/RingOptions.cpp @@ -4,11 +4,13 @@ */ #include "RingOptions.h" +#include "BaseMachine.h" #include using namespace std; -RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) +RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, + bool security) { opt.add( "64", // Default. @@ -19,8 +21,37 @@ RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) "-R", // Flag token. "--ring" // Flag token. ); + if (security) + opt.add( + "40", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Security parameter (default: 40)", // Help description. + "-S", // Flag token. + "--security" // Flag token. + ); opt.parse(argc, argv); opt.get("-R")->getInt(R); + if (security) + opt.get("-S")->getInt(S); + else + S = -1; + R_is_set = opt.isSet("-R"); opt.resetArgs(); - cerr << "Trying to run " << R << "-bit computation" << endl; + if (R_is_set) + cerr << "Trying to run " << R << "-bit computation" << endl; + if (security) + cerr << "Using security parameter " << S << endl; +} + +int RingOptions::ring_size_from_opts_or_schedule(string progname) +{ + if (R_is_set) + return R; + int r = BaseMachine::ring_size_from_schedule(progname); + if (r == 0) + r = R; + cerr << "Trying to run " << r << "-bit computation" << endl; + return r; } diff --git a/Processor/RingOptions.h b/Processor/RingOptions.h index 4a34e88a1..899c7021a 100644 --- a/Processor/RingOptions.h +++ b/Processor/RingOptions.h @@ -7,13 +7,21 @@ #define PROCESSOR_RINGOPTIONS_H_ #include "Tools/ezOptionParser.h" +#include +using namespace std; class RingOptions { + bool R_is_set; + public: int R; + int S; + + RingOptions(ez::ezOptionParser& opt, int argc, const char** argv, + bool security = false); - RingOptions(ez::ezOptionParser& opt, int argc, const char** argv); + int ring_size_from_opts_or_schedule(string progname); }; #endif /* PROCESSOR_RINGOPTIONS_H_ */ diff --git a/Processor/TruncPrTuple.h b/Processor/TruncPrTuple.h new file mode 100644 index 000000000..06a96845f --- /dev/null +++ b/Processor/TruncPrTuple.h @@ -0,0 +1,76 @@ +/* + * TruncPrTuple.h + * + */ + +#ifndef PROCESSOR_TRUNCPRTUPLE_H_ +#define PROCESSOR_TRUNCPRTUPLE_H_ + +#include +#include +using namespace std; + +template +class TruncPrTuple +{ +public: + int dest_base; + int source_base; + int k; + int m; + int n_shift; + + TruncPrTuple(const vector& regs, size_t base) + { + dest_base = regs[base]; + source_base = regs[base + 1]; + k = regs[base + 2]; + m = regs[base + 3]; + n_shift = T::N_BITS - 1 - k; + assert(m < k); + assert(0 < k); + assert(m < T::N_BITS); + } + + T upper(T mask) + { + return (mask << (n_shift + 1)) >> (n_shift + m + 1); + } + + T msb(T mask) + { + return (mask << (n_shift)) >> (T::N_BITS - 1); + } + +}; + +template +class TruncPrTupleWithGap : public TruncPrTuple +{ +public: + TruncPrTupleWithGap(const vector& regs, size_t base) : + TruncPrTuple(regs, base) + { + } + + T upper(T mask) + { + if (big_gap()) + return mask >> this->m; + else + return TruncPrTuple::upper(mask); + } + + T msb(T mask) + { + assert(not big_gap()); + return TruncPrTuple::msb(mask); + } + + bool big_gap() + { + return this->k <= T::N_BITS - 40; + } +}; + +#endif /* PROCESSOR_TRUNCPRTUPLE_H_ */ diff --git a/Processor/instructions.h b/Processor/instructions.h index ca5167c95..fb15007a0 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -55,6 +55,9 @@ X(MULCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \ typename sint::clear op2 = int(n), \ *dest++ = *op1++ * op2) \ + X(MULSI, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ + typename sint::clear op2 = int(n), \ + *dest++ = *op1++ * op2) \ X(SHRCI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]], \ *dest++ = *op1++ >> n) \ X(TRIPLE, auto a = &Procp.get_S()[r[0]]; auto b = &Procp.get_S()[r[1]]; \ diff --git a/Programs/Source/bio.mpc b/Programs/Source/bio.mpc index 7c99ee3f3..b51a8d194 100644 --- a/Programs/Source/bio.mpc +++ b/Programs/Source/bio.mpc @@ -21,7 +21,7 @@ def match(db_entry, sample): from Compiler import util if n_threads is None: - util.tree_reduce(lambda x, y: x.min(y), (match(db[i], sample) for i in range(n))) + res = util.tree_reduce(lambda x, y: x.min(y), (match(db[i], sample) for i in range(n))) else: tmp = sint.Array(n_threads) @@ -31,4 +31,6 @@ else: (match(db[base + i], sample) for i in range(size))) - util.tree_reduce(lambda x, y: x.min(y), tmp) + res = util.tree_reduce(lambda x, y: x.min(y), tmp) + +print_ln('result: %s', res.reveal()) diff --git a/Programs/Source/mnist_49.mpc b/Programs/Source/mnist_49.mpc new file mode 100644 index 000000000..1ad2eaddf --- /dev/null +++ b/Programs/Source/mnist_49.mpc @@ -0,0 +1,69 @@ +import ml +import math +import re +import util + +program.options_from_args() +sfix.set_precision_from_args(program) + +n_examples = 11791 +n_test = 1991 +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +N = n_examples +batch_size = 128 + +assert batch_size <= N + +try: + ml.set_n_threads(int(program.args[2])) +except: + pass + +n_inner = 128 + +n_dense_layers = None +for arg in program.args: + m = re.match('(.*)dense', arg) + if m: + n_dense_layers = int(m.group(1)) + +if n_dense_layers == 1: + layers = [ml.Dense(N, n_features, 1, activation='id')] +elif n_dense_layers > 1: + layers = [ml.Dense(N, n_features, n_inner, activation='relu')] + for i in range(n_dense_layers - 2): + layers += [ml.Dense(N, n_inner, n_inner, activation='relu')] + layers += [ml.Dense(N, n_inner, 1, activation='id')] +else: + raise CompilerError('number of dense layers not specified') + +layers += [ml.Output.from_args(N, program)] + +Y = sint.Array(n_test) +X = sfix.Matrix(n_test, n_features) + +if not ('no_acc' in program.args and 'no_loss' in program.args): + layers[-1].Y.input_from(0) + layers[0].X.input_from(0) + Y.input_from(0) + X.input_from(0) + +sgd = ml.SGD(layers, 1) + +if 'no_out' in program.args: + del sgd.layers[-1] + +if 'forward' in program.args: + sgd.forward(batch=regint.Array(batch_size)) +elif 'backward' in program.args: + sgd.backward(batch=regint.Array(batch_size)) +elif 'update' in program.args: + sgd.update(0, batch=regint.Array(batch_size)) +else: + sgd.run_by_args(program, n_epochs, batch_size, X, Y) diff --git a/Programs/Source/mnist_A.mpc b/Programs/Source/mnist_A.mpc new file mode 100644 index 000000000..0bd1c0a99 --- /dev/null +++ b/Programs/Source/mnist_A.mpc @@ -0,0 +1,99 @@ +import ml +import math + +#ml.report_progress = True + +program.options_from_args() + +approx = 3 + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 100 +elif 'debug' in program.args: + N = 10 + n_test = 10 +elif 'gisette' in program.args: + print('Compiling for 4/9') + N = 11791 + n_test = 1991 +else: + N = 12665 + n_test = 2115 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +assert batch_size <= N + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +if 'debug' in program.args: + n_inner = 10 + n_features = 10 +else: + n_inner = 128 + +if 'norelu' in program.args: + activation = 'id' +else: + activation = 'relu' + +layers = [ml.Dense(N, n_features, n_inner, activation=activation), + ml.Dense(N, n_inner, n_inner, activation=activation), + ml.Dense(N, n_inner, 1), + ml.Output(N, approx=approx)] + +if '2dense' in program.args: + del layers[1] + +layers[-1].Y.input_from(0) +layers[0].X.input_from(0) + +Y = sint.Array(n_test) +X = sfix.Matrix(n_test, n_features) +Y.input_from(0) +X.input_from(0) + +sgd = ml.SGD(layers, 10, report_loss=True) +sgd.reset() + +@for_range(int(math.ceil(n_epochs / 10))) +def _(i): + start_timer(1) + sgd.run(batch_size) + stop_timer(1) + + def get_correct(Y, n): + n_correct = regint(0) + for i in range(n): + n_correct += (Y[i].reveal() > 0).bit_xor( + layers[-2].Y[i][0][0][0].reveal() < 0) + return n_correct + + sgd.forward(N) + + n_correct = get_correct(layers[-1].Y, N) + print_ln('train_acc: %s (%s/%s)', cfix(n_correct) / N, n_correct, N) + + training_address = layers[0].X.address + layers[0].X.address = X.address + sgd.forward(n_test) + layers[0].X.address = training_address + + n_correct = get_correct(Y, n_test) + print_ln('acc: %s (%s/%s)', cfix(n_correct) / n_test, n_correct, n_test) diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc new file mode 100644 index 000000000..b1e249595 --- /dev/null +++ b/Programs/Source/mnist_full_A.mpc @@ -0,0 +1,109 @@ +import ml +import math +import re +import util + +#ml.report_progress = True + +program.options_from_args() + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 100 +elif 'debug' in program.args: + N = 100 + n_test = 100 +else: + N = 60000 + n_test = 10000 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +assert batch_size <= N + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +n_inner = 128 + +if 'norelu' in program.args: + activation = 'id' +else: + activation = 'relu' + +if 'nearest' in program.args: + sfix.round_nearest = True + +if 'double' in program.args: + sfix.set_precision(32, 63) + cfix.set_precision(32, 63) +elif 'triple' in program.args: + sfix.set_precision(48, 91) + cfix.set_precision(48, 91) +elif 'quadruple' in program.args: + sfix.set_precision(64, 127) + cfix.set_precision(64, 127) +elif 'sextuple' in program.args: + sfix.set_precision(96, 191) + cfix.set_precision(96, 191) +elif 'octuple' in program.args: + sfix.set_precision(128, 255) + cfix.set_precision(128, 255) + +assert sfix.f * 4 == int(program.options.ring) + +debug_ml = ('debug_ml' in program.args) * 2 ** (sfix.f / 2) + +if '1dense' in program.args: + layers = [ml.Dense(N, n_features, 10, debug=debug_ml)] +else: + layers = [ml.Dense(N, n_features, n_inner, activation=activation, debug=debug_ml), + ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml), + ml.Dense(N, n_inner, 10, debug=debug_ml)] + +layers += [ml.MultiOutput.from_args(program, N, 10)] + +layers[-1].cheaper_loss = 'mse' in program.args + +if '2dense' in program.args: + del layers[1] + +layers[-1].Y.input_from(0) +layers[0].X.input_from(0) + +Y = sint.Matrix(n_test, 10) +X = sfix.Matrix(n_test, n_features) +Y.input_from(0) +X.input_from(0) + +if 'always_acc' in program.args: + n_part_epochs = 1 +else: + n_part_epochs = 10 + +sgd = ml.SGD(layers, n_part_epochs, report_loss=True, debug=debug_ml) +#sgd.print_update_average = True +sgd.print_losses = 'print_losses' in program.args + +if 'faster' in program.args: + sgd.gamma = MemValue(cfix(.1)) + +if 'slower' in program.args: + sgd.gamma = MemValue(cfix(.001)) + +sgd.run_by_args(program, int(math.ceil(n_epochs / n_part_epochs)), batch_size, + X, Y) diff --git a/Programs/Source/mnist_logreg.mpc b/Programs/Source/mnist_logreg.mpc new file mode 100644 index 000000000..f7d77bd59 --- /dev/null +++ b/Programs/Source/mnist_logreg.mpc @@ -0,0 +1,59 @@ +import ml + +program.options_from_args() + +approx = 3 + +if 'gisette' in program.args: + print('Compiling for 4/9') + N = 11791 + n_test = 1991 +else: + N = 12665 + n_test = 2115 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +layers = [ml.Dense(N, n_features, 1), + ml.Output(N, approx=approx)] + +layers[1].Y.input_from(0) +layers[0].X.input_from(0) + +Y = sint.Array(n_test) +X = sfix.Matrix(n_test, n_features) +Y.input_from(0) +X.input_from(0) + +sgd = ml.SGD(layers, n_epochs, report_loss=True) +sgd.reset() + +start_timer(1) +sgd.run(batch_size) +stop_timer(1) + +layers[0].X.assign(X) +sgd.forward(n_test) + +n_correct = cfix(0) + +for i in range(n_test): + n_correct += Y[i].reveal().bit_xor(layers[0].Y[i][0][0][0].reveal() < 0) + +print_ln('acc: %s (%s/%s)', n_correct / n_test, n_correct, n_test) diff --git a/Programs/Source/regression.mpc b/Programs/Source/regression.mpc index 333aebe4e..413f3c59d 100644 --- a/Programs/Source/regression.mpc +++ b/Programs/Source/regression.mpc @@ -37,6 +37,12 @@ if len(program.args) > 2: n_normal = 49 n_features = 17814 +if 'mnist' in program.args: + print('Compiling for MNIST') + n_examples = 2115 + n_normal = 980 + n_features = 28 ** 2 + n_pos = n_examples - n_normal n_epochs = 1 if len(program.args) > 1: diff --git a/Protocols/FakeInput.h b/Protocols/FakeInput.h index 3c232520a..ea3d118ea 100644 --- a/Protocols/FakeInput.h +++ b/Protocols/FakeInput.h @@ -15,6 +15,10 @@ class FakeInput : public InputBase PointerVector results; public: + FakeInput() + { + } + FakeInput(SubProcessor&, typename T::MAC_Check&) { } diff --git a/Protocols/FakePrep.h b/Protocols/FakePrep.h index 277aa9f24..587274f7f 100644 --- a/Protocols/FakePrep.h +++ b/Protocols/FakePrep.h @@ -77,6 +77,12 @@ class FakePrep : public BufferPrep a = bit; b = bit; } + + void get_one_no_count(Dtype dtype, T& a) + { + assert(dtype == DATA_BIT); + a = G.get_uchar() & 1; + } }; #endif /* PROTOCOLS_FAKEPREP_H_ */ diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 006dd7539..cce26b222 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -7,6 +7,7 @@ #define PROTOCOLS_FAKEPROTOCOL_H_ #include "Replicated.h" +#include "Math/Z2k.h" template class FakeProtocol : public ProtocolBase @@ -14,6 +15,10 @@ class FakeProtocol : public ProtocolBase PointerVector results; SeededPRNG G; + T dot_prod; + + T trunc_max; + public: Player& P; @@ -21,6 +26,27 @@ class FakeProtocol : public ProtocolBase { } +#ifdef VERBOSE + ~FakeProtocol() + { + output_trunc_max<0>(T::invertible); + } + + template + void output_trunc_max(false_type) + { + if (trunc_max != T()) + cerr << "Maximum bit length in truncation: " + << (bigint(typename T::clear(trunc_max)).numBits() + 1) + << " (" << trunc_max << ")" << endl; + } + + template + void output_trunc_max(true_type) + { + } +#endif + void init_mul(SubProcessor*) { results.clear(); @@ -41,6 +67,28 @@ class FakeProtocol : public ProtocolBase return results.next(); } + void init_dotprod(SubProcessor* proc) + { + init_mul(proc); + dot_prod = {}; + } + + void prepare_dotprod(const T& x, const T& y) + { + dot_prod += x * y; + } + + void next_dotprod() + { + results.push_back(dot_prod); + dot_prod = 0; + } + + T finalize_dotprod(int) + { + return finalize_mul(); + } + void randoms(T& res, int n_bits) { res.randomize_part(G, n_bits); @@ -52,11 +100,63 @@ class FakeProtocol : public ProtocolBase } void trunc_pr(const vector& regs, int size, SubProcessor& proc) + { + trunc_pr<0>(regs, size, proc, T::characteristic_two); + } + + template + void trunc_pr(const vector&, int, SubProcessor&, true_type) + { + throw not_implemented(); + } + + template + void trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type) { for (size_t i = 0; i < regs.size(); i += 4) for (int l = 0; l < size; l++) - proc.get_S_ref(regs[i] + l) = proc.get_S_ref(regs[i + 1] + l) - >> regs[i + 3]; + { + auto& res = proc.get_S_ref(regs[i] + l); + auto& source = proc.get_S_ref(regs[i + 1] + l); + T tmp = source - (T(1) << regs[i + 2] - 1); + tmp = tmp < T() ? (T() - tmp) : tmp; + trunc_max = max(trunc_max, tmp); +#ifdef CHECK_BOUNDS_IN_TRUNC_PR_EMULATION + auto test = (source >> (regs[i + 2])); + if (test != 0) + { + cerr << typename T::clear(source) << " has more than " + << regs[i + 2] + << " bits in " << regs[i + 3] + << "-bit truncation (test value " + << typename T::clear(test) << ")" << endl; + throw runtime_error("trunc_pr overflow"); + } +#endif + int n_shift = regs[i + 3]; +#ifdef ROUND_NEAREST_IN_EMULATION + res = source >> n_shift; + if (n_shift > 0) + { + bool overflow = T(source >> (n_shift - 1)).get_bit(0); + res += overflow; + } +#else +#ifdef RISKY_TRUNCATION_IN_EMULATION + T r; + r.randomize(G); + + if (source.negative()) + res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + else + res = ((source + r) >> n_shift) - (r >> n_shift); +#else + T r; + r.randomize_part(G, n_shift - 1); + res = (source + r) >> n_shift; +#endif +#endif + } } }; diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index cefe73784..ab0765ccb 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -25,7 +25,8 @@ class MAC_Check_Base public: int values_opened; - MAC_Check_Base() : values_opened(0) {} + MAC_Check_Base(const typename T::mac_key_type::Scalar& mac_key = { }) : + alphai(mac_key), values_opened(0) {} virtual ~MAC_Check_Base() {} virtual void Check(const Player& P) { (void)P; } diff --git a/Protocols/MalRepRingPrep.h b/Protocols/MalRepRingPrep.h index 7e30152b0..2ccb32211 100644 --- a/Protocols/MalRepRingPrep.h +++ b/Protocols/MalRepRingPrep.h @@ -36,19 +36,12 @@ class RingOnlyBitsFromSquaresPrep : public virtual BufferPrep void buffer_bits(); }; -// extra class to avoid recursion template -class MalRepRingPrepWithBits: public virtual MaliciousRingPrep, - public virtual MalRepRingPrep, +class SimplerMalRepRingPrep : public virtual MalRepRingPrep, public virtual RingOnlyBitsFromSquaresPrep { public: - MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage); - - void set_protocol(typename T::Protocol& protocol) - { - MaliciousRingPrep::set_protocol(protocol); - } + SimplerMalRepRingPrep(SubProcessor* proc, DataPositions& usage); void buffer_triples() { @@ -72,4 +65,27 @@ class MalRepRingPrepWithBits: public virtual MaliciousRingPrep, } }; +template +class MalRepRingPrepWithBits: public virtual MaliciousRingPrep, + public virtual SimplerMalRepRingPrep +{ +public: + MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage); + + void set_protocol(typename T::Protocol& protocol) + { + MaliciousRingPrep::set_protocol(protocol); + } + + void buffer_squares() + { + MalRepRingPrep::buffer_squares(); + } + + void buffer_bits() + { + RingOnlyBitsFromSquaresPrep::buffer_bits(); + }; +}; + #endif /* PROTOCOLS_MALREPRINGPREP_H_ */ diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index ca7f3f70d..790c45e99 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -27,13 +27,22 @@ RingOnlyBitsFromSquaresPrep::RingOnlyBitsFromSquaresPrep(SubProcessor*, { } +template +SimplerMalRepRingPrep::SimplerMalRepRingPrep(SubProcessor* proc, + DataPositions& usage) : + BufferPrep(usage), MalRepRingPrep(proc, usage), + RingOnlyBitsFromSquaresPrep(proc, usage) +{ +} + template MalRepRingPrepWithBits::MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), MaliciousRingPrep(proc, usage), MalRepRingPrep(proc, usage), - RingOnlyBitsFromSquaresPrep(proc, usage) + RingOnlyBitsFromSquaresPrep(proc, usage), + SimplerMalRepRingPrep(proc, usage) { } @@ -54,6 +63,7 @@ void MalRepRingPrep::buffer_squares() MaliciousRepPrep prep(_); assert(this->proc != 0); prep.init_honest(this->proc->P); + prep.buffer_size = this->buffer_size; prep.buffer_squares(); for (auto& x : prep.squares) this->squares.push_back({{x[0], x[1]}}); @@ -68,6 +78,7 @@ void MalRepRingPrep::simple_buffer_triples() MaliciousRepPrep prep(_); assert(this->proc != 0); prep.init_honest(this->proc->P); + prep.buffer_size = this->buffer_size; prep.buffer_triples(); for (auto& x : prep.triples) this->triples.push_back({{x[0], x[1], x[2]}}); @@ -222,7 +233,7 @@ void RingOnlyBitsFromSquaresPrep::buffer_bits() typename BitShare::SquarePrep prep(0, usage); SubProcessor bit_proc(MC, prep, proc->P); prep.set_proc(&bit_proc); - bits_from_square_in_ring(this->bits, OnlineOptions::singleton.batch_size, &prep); + bits_from_square_in_ring(this->bits, this->buffer_size, &prep); } template diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index def702d13..1967994d0 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -11,6 +11,7 @@ template class HashMaliciousRepMC; template class Beaver; template class MaliciousRepPrepWithBits; +template class MaliciousRepPO; template class MaliciousRepPrep; namespace GC @@ -22,6 +23,7 @@ template class MaliciousRep3Share : public Rep3Share { typedef Rep3Share super; + typedef MaliciousRep3Share This; public: typedef Beaver> Protocol; @@ -29,11 +31,13 @@ class MaliciousRep3Share : public Rep3Share typedef MAC_Check Direct_MC; typedef ReplicatedInput> Input; typedef ::PrivateOutput> PrivateOutput; + typedef MaliciousRepPO PO; typedef Rep3Share Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; typedef MaliciousRep3Share prep_type; typedef T random_type; + typedef This Scalar; typedef GC::MaliciousRepSecret bit_type; diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 5339a65c2..e64db21ad 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -144,7 +144,7 @@ void HashMaliciousRepMC::Check(const Player& P) P.Broadcast_Receive(os); for (int i = 0; i < P.num_players(); i++) if (os[i] != os[P.my_num()]) - throw mac_fail(); + throw mac_fail("check hash mismatch"); } } diff --git a/Protocols/MaliciousRepPO.h b/Protocols/MaliciousRepPO.h new file mode 100644 index 000000000..62d4b1783 --- /dev/null +++ b/Protocols/MaliciousRepPO.h @@ -0,0 +1,27 @@ +/* + * MaliciousRepPO.h + * + */ + +#ifndef PROTOCOLS_MALICIOUSREPPO_H_ +#define PROTOCOLS_MALICIOUSREPPO_H_ + +#include "Networking/Player.h" + +template +class MaliciousRepPO +{ + Player& P; + octetStream to_send; + octetStream to_receive[2]; + +public: + MaliciousRepPO(Player& P); + + void prepare_sending(const T& secret, int player); + void send(int player); + void receive(); + typename T::clear finalize(const T& secret); +}; + +#endif /* PROTOCOLS_MALICIOUSREPPO_H_ */ diff --git a/Protocols/MaliciousRepPO.hpp b/Protocols/MaliciousRepPO.hpp new file mode 100644 index 000000000..d0786b60c --- /dev/null +++ b/Protocols/MaliciousRepPO.hpp @@ -0,0 +1,44 @@ +/* + * MaliciousRepPO.cpp + * + */ + +#include "MaliciousRepPO.h" + +#include + +template +MaliciousRepPO::MaliciousRepPO(Player& P) : P(P) +{ + assert(P.num_players() == 3); +} + +template +void MaliciousRepPO::prepare_sending(const T& secret, int player) +{ + secret[2 - P.get_offset(player)].pack(to_send); +} + +template +void MaliciousRepPO::send(int player) +{ + if (P.get_offset(player) == 2) + P.send_to(player, to_send, true); + else + P.send_to(player, to_send.hash(), true); +} + +template +void MaliciousRepPO::receive() +{ + for (int i = 0; i < 2; i++) + P.receive_relative(i + 1, to_receive[i]); + if (to_receive[0].hash() != to_receive[1]) + throw mac_fail("mismatch in private output"); +} + +template +typename T::clear MaliciousRepPO::finalize(const T& secret) +{ + return secret.sum() + to_receive[0].template get(); +} diff --git a/Protocols/MaliciousRepPrep.h b/Protocols/MaliciousRepPrep.h index 7783671f0..c4f8c0dc5 100644 --- a/Protocols/MaliciousRepPrep.h +++ b/Protocols/MaliciousRepPrep.h @@ -17,7 +17,7 @@ template class MalRepRingShare; template class PostSacriRepRingShare; template -void sacrifice(const vector& check_triples, Player& P); +void sacrifice(const vector>& check_triples, Player& P); template class MaliciousRepPrep : public virtual BufferPrep diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index ed32f3b97..f77bba457 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -63,7 +63,7 @@ void MaliciousRepPrep::buffer_triples() { assert(T::open_type::length() >= 40); auto& triples = this->triples; - auto buffer_size = OnlineOptions::singleton.batch_size; + auto buffer_size = this->buffer_size; clear_tmp(); assert(honest_proc != 0); Player& P = honest_proc->P; @@ -122,11 +122,12 @@ template void MaliciousRepPrep::buffer_squares() { auto& squares = this->squares; - auto buffer_size = OnlineOptions::singleton.batch_size; + auto buffer_size = this->buffer_size; clear_tmp(); assert(honest_proc); Player& P = honest_proc->P; squares.clear(); + honest_prep.buffer_size = buffer_size; for (int i = 0; i < buffer_size; i++) { T a, b; diff --git a/Protocols/MaliciousShamirMC.h b/Protocols/MaliciousShamirMC.h index 9c7e607ef..a6b59fae2 100644 --- a/Protocols/MaliciousShamirMC.h +++ b/Protocols/MaliciousShamirMC.h @@ -11,6 +11,8 @@ template class MaliciousShamirMC : public ShamirMC { + typedef typename T::open_type open_type; + vector> reconstructions; vector shares; @@ -34,6 +36,8 @@ class MaliciousShamirMC : public ShamirMC void init_open(const Player& P, int n = 0); typename T::open_type finalize_open(); + + typename T::open_type reconstruct(const vector& shares); }; #endif /* PROTOCOLS_MALICIOUSSHAMIRMC_H_ */ diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index 92963a2d3..0309852e8 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -38,16 +38,24 @@ typename T::open_type MaliciousShamirMC::finalize_open() shares.resize(2 * threshold + 1); for (size_t j = 0; j < shares.size(); j++) shares[j].unpack((*this->os)[j]); + return reconstruct(shares); +} + +template +typename T::open_type MaliciousShamirMC::reconstruct( + const vector& shares) +{ + int threshold = ShamirMachine::s().threshold; typename T::open_type value = 0; for (int j = 0; j < threshold + 1; j++) value += shares[j] * reconstructions[threshold + 1][j]; - for (int j = threshold + 2; j <= 2 * threshold + 1; j++) + for (size_t j = threshold + 2; j <= shares.size(); j++) { typename T::open_type check = 0; - for (int k = 0; k < j; k++) + for (size_t k = 0; k < j; k++) check += shares[k] * reconstructions[j][k]; if (check != value) - throw mac_fail(); + throw mac_fail("inconsistent Shamir secret sharing"); } return value; } diff --git a/Protocols/MaliciousShamirPO.h b/Protocols/MaliciousShamirPO.h new file mode 100644 index 000000000..65003d108 --- /dev/null +++ b/Protocols/MaliciousShamirPO.h @@ -0,0 +1,29 @@ +/* + * MaliciousShamirPO.h + * + */ + +#ifndef PROTOCOLS_MALICIOUSSHAMIRPO_H_ +#define PROTOCOLS_MALICIOUSSHAMIRPO_H_ + +template +class MaliciousShamirPO +{ + Player& P; + + octetStream to_send; + vector to_receive; + + vector shares; + MaliciousShamirMC MC; + +public: + MaliciousShamirPO(Player& P); + + void prepare_sending(const T& secret, int player); + void send(int player); + void receive(); + typename T::clear finalize(const T& secret); +}; + +#endif /* PROTOCOLS_MALICIOUSSHAMIRPO_H_ */ diff --git a/Protocols/MaliciousShamirPO.hpp b/Protocols/MaliciousShamirPO.hpp new file mode 100644 index 000000000..7a21f3847 --- /dev/null +++ b/Protocols/MaliciousShamirPO.hpp @@ -0,0 +1,47 @@ +/* + * MaliciousShamirPO.cpp + * + */ + +#include "MaliciousShamirPO.h" + +template +MaliciousShamirPO::MaliciousShamirPO(Player& P) : + P(P), shares(P.num_players()) +{ +} + +template +void MaliciousShamirPO::prepare_sending(const T& secret, int) +{ + secret.pack(to_send); +} + +template +void MaliciousShamirPO::send(int player) +{ + P.send_to(player, to_send, true); +} + +template +void MaliciousShamirPO::receive() +{ + to_receive.resize(P.num_players()); + for (int i = 0; i < P.num_players(); i++) + if (i != P.my_num()) + P.receive_player(i, to_receive[i], true); +} + +template +typename T::clear MaliciousShamirPO::finalize(const T& secret) +{ + for (int i = 0; i < P.num_players(); i++) + { + if (i == P.my_num()) + shares[i] = secret; + else + shares[i].unpack(to_receive[i]); + } + + return MC.reconstruct(shares); +} diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index 6f5404db9..56086fc78 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -12,6 +12,7 @@ template class MaliciousRepPrepWithBits; template class MaliciousRepPrep; +template class MaliciousShamirPO; namespace GC { @@ -29,6 +30,7 @@ class MaliciousShamirShare : public ShamirShare typedef MAC_Check Direct_MC; typedef ShamirInput Input; typedef ::PrivateOutput PrivateOutput; + typedef MaliciousShamirPO PO; typedef ShamirShare Honest; typedef MaliciousRepPrepWithBits LivePrep; typedef MaliciousRepPrep TriplePrep; diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index 5eadcbf99..71be6097c 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -48,7 +48,12 @@ class MamaShare : public Share_, FixedVec, N>> return "Mama" + to_string(N); } - static void read_or_generate_mac_key(string, Names&, mac_key_type& key) + static string type_short() + { + return string(1, T::type_char()); + } + + static void read_or_generate_mac_key(string, Player&, mac_key_type& key) { SeededPRNG G; key.randomize(G); diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 440ec19e3..017901a70 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -16,15 +16,95 @@ template class ReplicatedPrep; template class ReplicatedRingPrep; template class PrivateOutput; -template -class Rep3Share : public FixedVec, public ShareInterface +template +class RepShare : public FixedVec, public ShareInterface { + typedef RepShare This; + typedef FixedVec super; + public: typedef T clear; typedef T open_type; typedef T mac_type; typedef T mac_key_type; + const static bool needs_ot = false; + const static bool dishonest_majority = false; + const static bool expensive = false; + + static int threshold(int) + { + return 1; + } + + RepShare() + { + } + template + RepShare(const U& other) : + super(other) + { + } + + void add(const This& x, const This& y) + { + *this = x + y; + } + void sub(const This& x, const This& y) + { + *this = x - y; + } + + template + void add(const U& S, const clear aa, int my_num, + const T& alphai) + { + (void)alphai; + *this = S + S.constant(aa, my_num); + } + template + void sub(const U& S, const clear& aa, int my_num, + const T& alphai) + { + (void)alphai; + *this = S - S.constant(aa, my_num); + } + template + void sub(const clear& aa, const U& S, int my_num, + const T& alphai) + { + (void)alphai; + *this = S.constant(aa, my_num) - S; + } + + void mul_by_bit(const This& x, const T& y) + { + (void) x, (void) y; + throw runtime_error("multiplication by bit not implemented"); + } + + void pack(octetStream& os, bool full = true) const + { + if (full) + FixedVec::pack(os); + else + (*this)[0].pack(os); + } + void unpack(octetStream& os, bool full = true) + { + assert(full); + FixedVec::unpack(os); + } +}; + +template +class Rep3Share : public RepShare +{ + typedef RepShare super; + +public: + typedef T clear; + typedef Replicated Protocol; typedef ReplicatedMC MAC_Check; typedef MAC_Check Direct_MC; @@ -34,11 +114,14 @@ class Rep3Share : public FixedVec, public ShareInterface typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; + typedef Rep3Share Scalar; + typedef GC::SemiHonestRepSecret bit_type; const static bool needs_ot = false; const static bool dishonest_majority = false; const static bool expensive = false; + const static bool variable_players = false; static string type_short() { @@ -48,10 +131,9 @@ class Rep3Share : public FixedVec, public ShareInterface { return "replicated " + T::type_string(); } - - static int threshold(int) + static char type_char() { - return 1; + return T::type_char(); } static Rep3Share constant(T value, int my_num, const T& alphai = {}) @@ -63,9 +145,9 @@ class Rep3Share : public FixedVec, public ShareInterface { } template - Rep3Share(const FixedVec& other) + Rep3Share(const U& other) : + super(other) { - FixedVec::operator=(other); } Rep3Share(T value, int my_num, const T& alphai = {}) @@ -85,59 +167,11 @@ class Rep3Share : public FixedVec, public ShareInterface FixedVec::assign(buffer); } - void add(const Rep3Share& x, const Rep3Share& y) - { - *this = x + y; - } - void sub(const Rep3Share& x, const Rep3Share& y) - { - *this = x - y; - } - - void add(const Rep3Share& S, const clear aa, int my_num, - const T& alphai) - { - (void)alphai; - *this = S + Rep3Share(aa, my_num); - } - void sub(const Rep3Share& S, const clear& aa, int my_num, - const T& alphai) - { - (void)alphai; - *this = S - Rep3Share(aa, my_num); - } - void sub(const clear& aa, const Rep3Share& S, int my_num, - const T& alphai) - { - (void)alphai; - *this = Rep3Share(aa, my_num) - S; - } - clear local_mul(const Rep3Share& other) const { - T a, b; - a.mul((*this)[0], other.sum()); - b.mul((*this)[1], other[0]); - return a + b; - } - - void mul_by_bit(const Rep3Share& x, const T& y) - { - (void) x, (void) y; - throw runtime_error("multiplication by bit not implemented"); - } - - void pack(octetStream& os, bool full = true) const - { - if (full) - FixedVec::pack(os); - else - (*this)[0].pack(os); - } - void unpack(octetStream& os, bool full = true) - { - assert(full); - FixedVec::unpack(os); + auto a = (*this)[0].lazy_mul(other.lazy_sum()); + auto b = (*this)[1].lazy_mul(other[0]); + return a.lazy_add(b); } }; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index d3545eb73..7c2bc8666 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -17,6 +17,7 @@ template class Rep3Share2 : public Rep3Share> { typedef Z2 T; + typedef Rep3Share2 This; public: typedef Replicated Protocol; @@ -119,6 +120,17 @@ class Rep3Share2 : public Rep3Share> } } } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; #endif /* PROTOCOLS_REP3SHARE2K_H_ */ diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h new file mode 100644 index 000000000..1be903057 --- /dev/null +++ b/Protocols/Rep4.h @@ -0,0 +1,71 @@ +/* + * Rep4.h + * + */ + +#ifndef PROTOCOLS_REP4_H_ +#define PROTOCOLS_REP4_H_ + +#include "Replicated.h" + +template +class Rep4 : public ProtocolBase +{ + friend class Rep4RingPrep; + + typedef typename T::open_type open_type; + + array os; + array, 4> send_hashes, receive_hashes; + + array, 5> add_shares; + vector bit_lengths; + + class ResTuple + { + public: + T res; + open_type r; + }; + + PointerVector results; + + int my_num; + + array get_addshares(const T& x, const T& y); + + void reset_joint_input(int n_inputs); + void prepare_joint_input(int sender, int backup, int receiver, + int outsider, vector& inputs); + void finalize_joint_input(int sender, int backup, int receiver, + int outsider); + + int get_player(int offset); + +public: + array, 3> rep_prngs; + Player& P; + + Rep4(Player& P); + + void init_mul(SubProcessor* proc = 0); + void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); + typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void exchange(); + T finalize_mul(int n = -1); + void check(); + + void init_dotprod(SubProcessor* proc); + void prepare_dotprod(const T& x, const T& y); + void next_dotprod(); + T finalize_dotprod(int length); + + T get_random(); + void randoms(T& res, int n_bits); + + void trunc_pr(const vector& regs, int size, SubProcessor& proc); + + int get_n_relevant_players() { return 2; } +}; + +#endif /* PROTOCOLS_REP4_H_ */ diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp new file mode 100644 index 000000000..54673ef2f --- /dev/null +++ b/Protocols/Rep4.hpp @@ -0,0 +1,398 @@ +/* + * Rep4.cpp + * + */ + +#include "Rep4.h" +#include "Processor/TruncPrTuple.h" + +template +Rep4::Rep4(Player& P) : + my_num(P.my_num()), P(P) +{ + assert(P.num_players() == 4); + + rep_prngs[0].ReSeed(); + for (int i = 1; i < 3; i++) + { + octetStream os; + os.append(rep_prngs[0].get_seed(), SEED_SIZE); + P.pass_around(os, -i); + rep_prngs[i].SetSeed(os.get_data()); + } +} + +template +void Rep4::init_mul(SubProcessor*) +{ + for (auto& x : add_shares) + x.clear(); + bit_lengths.clear(); +} + +template +void Rep4::init_mul(Preprocessing&, typename T::MAC_Check&) +{ + init_mul(); +} + +template +void Rep4::reset_joint_input(int n_inputs) +{ + results.clear(); + results.resize(n_inputs); + bit_lengths.clear(); + bit_lengths.resize(n_inputs, -1); +} + +template +void Rep4::prepare_joint_input(int sender, int backup, int receiver, + int outsider, vector& inputs) +{ + if (P.my_num() != receiver) + { + int index = P.get_offset(receiver) - 1; + for (auto& x : results) + { + x.r = rep_prngs[index].get(); + x.res[index] += x.r; + } + + if (P.my_num() == sender or P.my_num() == backup) + { + int offset = P.get_offset(outsider) - 1; + for (size_t i = 0; i < results.size(); i++) + { + auto& input = inputs[i]; + input -= results[i].r; + results[i].res[offset] += input; + } + } + } + + if (P.my_num() == backup) + { + send_hashes[sender][receiver].update(inputs); + } + + if (sender == P.my_num()) + { + assert(inputs.size() == bit_lengths.size()); + switch (P.get_offset(backup)) + { + case 2: + for (size_t i = 0; i < inputs.size(); i++) + inputs[i].pack(os[1], bit_lengths[i]); + break; + case 1: + for (size_t i = 0; i < inputs.size(); i++) + inputs[i].pack(os[0], bit_lengths[i]); + break; + default: + throw not_implemented(); + } + } +} + +template +void Rep4::finalize_joint_input(int sender, int backup, int receiver, + int) +{ + if (P.my_num() == receiver) + { + assert(results.size() == bit_lengths.size()); + T res; + switch (P.get_offset(backup)) + { + case 2: + receive_hashes[sender][backup].update(os[0].get_data_ptr(), + results.size() * open_type::size()); + for (size_t i = 0; i < results.size(); i++) + { + auto& x = results[i]; + res[2].unpack(os[0], bit_lengths[i]); + x.res[2] += res[2]; + } + break; + default: + receive_hashes[sender][backup].update(os[1].get_data_ptr(), + results.size() * open_type::size()); + for (size_t i = 0; i < results.size(); i++) + { + auto& x = results[i]; + res[1].unpack(os[1], bit_lengths[i]); + x.res[1] += res[1]; + } + break; + } + } +} + +template +int Rep4::get_player(int offset) +{ + return (my_num + offset) & 3; +} + +template +typename T::clear Rep4::prepare_mul(const T& x, const T& y, int n_bits) +{ + auto a = get_addshares(x, y); + for (int i = 0; i < 5; i++) + add_shares[i].push_back(a[i]); + bit_lengths.push_back(n_bits); + return {}; +} + +template +array Rep4::get_addshares(const T& x, const T& y) +{ + array res; + for (int i = 0; i < 2; i++) + res[get_player(i - 1)] = + (x[i] + x[i + 1]) * y[i] + x[i] * y[i + 1]; + res[4] = x[0] * y[2] + x[2] * y[0]; + return res; +} + +template +void Rep4::init_dotprod(SubProcessor*) +{ + init_mul(); + next_dotprod(); +} + +template +void Rep4::prepare_dotprod(const T& x, const T& y) +{ + auto a = get_addshares(x, y); + for (int i = 0; i < 5; i++) + add_shares[i].back() += a[i]; +} + +template +void Rep4::next_dotprod() +{ + for (auto& a : add_shares) + a.push_back({}); + bit_lengths.push_back(-1); +} + +template +void Rep4::exchange() +{ + for (auto& o : os) + o.reset_write_head(); + auto& a = add_shares; + results.clear(); + results.resize(a[4].size()); + prepare_joint_input(0, 1, 3, 2, a[0]); + prepare_joint_input(1, 2, 0, 3, a[1]); + prepare_joint_input(2, 3, 1, 0, a[2]); + prepare_joint_input(3, 0, 2, 1, a[3]); + prepare_joint_input(0, 2, 3, 1, a[4]); + prepare_joint_input(1, 3, 2, 0, a[4]); + P.pass_around(os[0], -1); + if (P.my_num() < 2) + P.send_to(3 - P.my_num(), os[1], true); + else + P.receive_player(3 - P.my_num(), os[1], true); + finalize_joint_input(0, 1, 3, 2); + finalize_joint_input(1, 2, 0, 3); + finalize_joint_input(2, 3, 1, 0); + finalize_joint_input(3, 0, 2, 1); + finalize_joint_input(0, 2, 3, 1); + finalize_joint_input(1, 3, 2, 0); +} + +template +T Rep4::finalize_mul(int) +{ + this->counter++; + return results.next().res; +} + +template +T Rep4::finalize_dotprod(int) +{ + this->counter++; + return finalize_mul(); +} + +template +void Rep4::check() +{ + for (int i = 1; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + octetStream os; + send_hashes[j][P.get_player(i)].final(os); + P.pass_around(os, i); + if (receive_hashes[j][P.get_player(-i)].final() != os) + throw runtime_error( + "hash mismatch for sender " + to_string(j) + + " and backup " + to_string(P.get_player(-i))); + } + } +} + +template +T Rep4::get_random() +{ + T res; + for (int i = 0; i < 3; i++) + res[i].randomize(rep_prngs[i]); + return res; +} + +template +void Rep4::randoms(T& res, int n_bits) +{ + for (int i = 0; i < 3; i++) + res[i].randomize_part(rep_prngs[i], n_bits); +} + +template +void Rep4::trunc_pr(const vector& regs, int size, + SubProcessor& proc) +{ + assert(regs.size() % 4 == 0); + typedef typename T::open_type open_type; + + vector> infos; + for (size_t i = 0; i < regs.size(); i += 4) + infos.push_back({regs, i}); + + PointerVector rs(size * infos.size()); + for (int i = 2; i < 4; i++) + { + int index = P.get_offset(i) - 1; + if (index >= 0) + for (auto& r : rs) + r[index].randomize(rep_prngs[index]); + } + + vector cs; + cs.reserve(rs.size()); + for (auto& info : infos) + { + for (int j = 0; j < size; j++) + cs.push_back(proc.get_S_ref(info.source_base + j) + rs.next()); + } + + octetStream c_os; + vector eval_inputs; + if (P.my_num() < 2) + { + if (P.my_num() == 0) + for (auto& c : cs) + (c[1] + c[2]).pack(c_os); + else + for (auto& c : cs) + (c[1] + c[0]).pack(c_os); + P.send_to(2 + P.my_num(), c_os, true); + P.send_to(3 - P.my_num(), c_os.hash(), true); + } + else + { + P.receive_player(P.my_num() - 2, c_os, true); + octetStream hash; + P.receive_player(3 - P.my_num(), hash, true); + if (hash != c_os.hash()) + throw runtime_error("hash mismatch in joint message passing"); + PointerVector open_cs; + if (P.my_num() == 2) + for (auto& c : cs) + open_cs.push_back(c_os.get() + c[1] + c[2]); + else + for (auto& c : cs) + open_cs.push_back(c_os.get() + c[1] + c[0]); + for (auto& info : infos) + for (int j = 0; j < size; j++) + { + auto c = open_cs.next(); + auto c_prime = info.upper(c); + if (not info.big_gap()) + { + auto c_msb = info.msb(c); + eval_inputs.push_back(c_msb); + } + eval_inputs.push_back(c_prime); + } + } + + PointerVector inputs; + bool generate = proc.P.my_num() < 2; + if (generate) + { + inputs.reserve(2 * rs.size()); + rs.reset(); + for (auto& info : infos) + for (int j = 0; j < size; j++) + { + auto& r = rs.next(); + if (not info.big_gap()) + inputs.push_back(info.msb(r.sum())); + inputs.push_back(info.upper(r.sum())); + } + } + + for (auto& o : os) + o.clear(); + size_t n_inputs = max(inputs.size(), eval_inputs.size()); + reset_joint_input(n_inputs); + prepare_joint_input(0, 1, 3, 2, inputs); + if (P.my_num() == 0) + P.send_to(3, os[0], true); + else if (P.my_num() == 3) + P.receive_player(0, os[0], true); + finalize_joint_input(0, 1, 3, 2); + PointerVector gen_results; + for (auto& x : results) + gen_results.push_back(x.res); + + for (auto& o : os) + o.clear(); + reset_joint_input(n_inputs); + prepare_joint_input(2, 3, 1, 0, eval_inputs); + if (P.my_num() == 2) + P.send_to(1, os[0], true); + else if (P.my_num() == 1) + P.receive_player(2, os[0], true); + finalize_joint_input(2, 3, 1, 0); + PointerVector eval_results; + for (auto& x : results) + eval_results.push_back(x.res); + + init_mul(); + for (auto& info : infos) + for (int j = 0; j < size; j++) + { + if (not info.big_gap()) + prepare_mul(gen_results.next(), eval_results.next()); + gen_results.next(); + eval_results.next(); + } + + if (not add_shares[0].empty()) + exchange(); + + eval_results.reset(); + gen_results.reset(); + + for (auto& info : infos) + for (int j = 0; j < size; j++) + { + if (info.big_gap()) + proc.get_S_ref(info.dest_base + j) = eval_results.next() + - gen_results.next(); + else + { + auto b = gen_results.next() + eval_results.next() + - 2 * finalize_mul(); + proc.get_S_ref(info.dest_base + j) = eval_results.next() + - gen_results.next() + (b << (info.k - info.m)); + } + } +} diff --git a/Protocols/Rep4Input.h b/Protocols/Rep4Input.h new file mode 100644 index 000000000..11da428fc --- /dev/null +++ b/Protocols/Rep4Input.h @@ -0,0 +1,38 @@ +/* + * Rep4Input.h + * + */ + +#ifndef PROTOCOLS_REP4INPUT_H_ +#define PROTOCOLS_REP4INPUT_H_ + +#include "ReplicatedInput.h" + +template +class Rep4Input : public InputBase +{ + Rep4 protocol; + Player& P; + + octetStream to_send; + array to_receive; + + array, 4> results; + +public: + Rep4Input(SubProcessor& proc, MAC_Check_Base&); + Rep4Input(MAC_Check_Base&, Preprocessing&, Player& P); + + void reset(int player); + + void add_mine(const typename T::open_type& input, int n_bits = -1); + void add_other(int player); + + void send_mine(); + void exchange(); + + T finalize_mine(); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); +}; + +#endif /* PROTOCOLS_REP4INPUT_H_ */ diff --git a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp new file mode 100644 index 000000000..bb72888be --- /dev/null +++ b/Protocols/Rep4Input.hpp @@ -0,0 +1,102 @@ +/* + * Rep4Input.cpp + * + */ + +#include "Rep4Input.h" + +template +Rep4Input::Rep4Input(SubProcessor& proc, MAC_Check_Base&) : + InputBase(&proc), protocol(proc.P), P(proc.P) +{ + assert(P.num_players() == 4); +} + +template +Rep4Input::Rep4Input(MAC_Check_Base&, Preprocessing&, Player& P) : + protocol(P), P(P) +{ +} + +template +void Rep4Input::reset(int player) +{ + if (player == P.my_num()) + to_send.reset_write_head(); +} + +template +void Rep4Input::add_mine(const typename T::open_type& input, int) +{ + auto& prot = protocol; + T res; + res[0] = prot.rep_prngs[0].get(); + res[2] = prot.rep_prngs[2].get(); + res[1] = input - res[0] - res[2]; + res[1].pack(to_send); + results[P.my_num()].push_back(res); +} + +template +void Rep4Input::add_other(int player) +{ + auto& prot = protocol; + T res; + switch (P.get_offset(player)) + { + case 1: + case 3: + res[1] = prot.rep_prngs[1].get(); + break; + case 2: + for (int i = 0; i < 3; i += 2) + res[i] = prot.rep_prngs[i].get(); + break; + default: + throw out_of_range("wrong player number"); + } + results[player].push_back(res); +} + +template +void Rep4Input::send_mine() +{ + throw not_implemented(); +} + +template +void Rep4Input::exchange() +{ + P.pass_around(to_send, to_receive[0], -1); + P.pass_around(to_send, to_receive[1], 1); + octetStream os[2][2]; + for (int i = 0; i < 2; i++) + { + os[0][i] = to_receive[i].hash(); + P.pass_around(os[0][i], os[1][i], 2); + } + for (int i = 0; i < 2; i++) + if (os[0][i] != os[1][1 - i]) + throw mac_fail(); +} + +template +T Rep4Input::finalize_mine() +{ + return results[P.my_num()].next(); +} + +template +void Rep4Input::finalize_other(int player, T& target, octetStream&, int) +{ + target = results[player].next(); + switch (P.get_offset(player)) + { + case 1: + target[2].unpack(to_receive[0]); + break; + case 3: + target[0].unpack(to_receive[1]); + break; + } +} diff --git a/Protocols/Rep4MC.h b/Protocols/Rep4MC.h new file mode 100644 index 000000000..6f35ef6b6 --- /dev/null +++ b/Protocols/Rep4MC.h @@ -0,0 +1,30 @@ +/* + * Rep4MC.h + * + */ + +#ifndef PROTOCOLS_REP4MC_H_ +#define PROTOCOLS_REP4MC_H_ + +#include "MAC_Check_Base.h" + +template +class Rep4MC : public MAC_Check_Base +{ + Hash check_hash, receive_hash; + +public: + Rep4MC(typename T::mac_key_type = {}, int = 0, int = 0) + { + } + + void exchange(const Player& P); + void Check(const Player& P); + + Rep4MC& get_part_MC() + { + return *this; + } +}; + +#endif /* PROTOCOLS_REP4MC_H_ */ diff --git a/Protocols/Rep4MC.hpp b/Protocols/Rep4MC.hpp new file mode 100644 index 000000000..22dceff11 --- /dev/null +++ b/Protocols/Rep4MC.hpp @@ -0,0 +1,44 @@ +/* + * Rep4MC.hpp + * + */ + +#ifndef PROTOCOLS_REP4MC_HPP_ +#define PROTOCOLS_REP4MC_HPP_ + +#include "Rep4MC.h" + +template +void Rep4MC::exchange(const Player& P) +{ + octetStream right, tmp; + for (auto& secret : this->secrets) + { + secret[0].pack(right); + secret[2].pack(tmp); + } + check_hash.update(tmp); + P.pass_around(right, 1); + this->values.resize(this->secrets.size()); + for (size_t i = 0; i < this->secrets.size(); i++) + { + typename T::open_type a, b; + a.unpack(right); + this->values[i] = this->secrets[i].sum() + a; + } + receive_hash.update(right); +} + +template +void Rep4MC::Check(const Player& P) +{ + octetStream left; + check_hash.final(left); + P.pass_around(left, -1); + octetStream os; + receive_hash.final(os); + if (os != left) + throw mac_fail(); +} + +#endif /* PROTOCOLS_REP4MC_HPP_ */ diff --git a/Protocols/Rep4Prep.h b/Protocols/Rep4Prep.h new file mode 100644 index 000000000..d33df2440 --- /dev/null +++ b/Protocols/Rep4Prep.h @@ -0,0 +1,63 @@ +/* + * Rep4Prep.h + * + */ + +#ifndef PROTOCOLS_REP4PREP_H_ +#define PROTOCOLS_REP4PREP_H_ + +#include "MaliciousRingPrep.hpp" +#include "MalRepRingPrep.h" +#include "RepRingOnlyEdabitPrep.h" + +template +class Rep4RingPrep : public MaliciousRingPrep +{ + void buffer_triples(); + void buffer_squares(); + void buffer_bits(); + void buffer_inputs(int player); + +public: + Rep4RingPrep(SubProcessor* proc, DataPositions& usage); +}; + +template +class Rep4Prep : public Rep4RingPrep +{ + void buffer_inverses(); + +public: + Rep4Prep(SubProcessor* proc, DataPositions& usage); +}; + +template +class Rep4RingOnlyPrep : public virtual Rep4RingPrep, + public virtual RepRingOnlyEdabitPrep +{ + void buffer_edabits(int n_bits, ThreadQueues* queues) + { + RepRingOnlyEdabitPrep::buffer_edabits(n_bits, queues); + } + + void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues) + { + BufferPrep::buffer_edabits(strict, n_bits, queues); + } + + void buffer_sedabits(int n_bits, ThreadQueues*) + { + this->buffer_sedabits_from_edabits(n_bits); + } + +public: + Rep4RingOnlyPrep(SubProcessor* proc, DataPositions& usage); + + void get_dabit_no_count(T& a, typename T::bit_type& b) + { + this->get_one_no_count(DATA_BIT, a); + b = a & 1; + } +}; + +#endif /* PROTOCOLS_REP4PREP_H_ */ diff --git a/Protocols/Rep4Prep.hpp b/Protocols/Rep4Prep.hpp new file mode 100644 index 000000000..abd28a08c --- /dev/null +++ b/Protocols/Rep4Prep.hpp @@ -0,0 +1,133 @@ +/* + * Rep4Prep.hpp + * + */ + +#ifndef PROTOCOLS_REP4PREP_HPP_ +#define PROTOCOLS_REP4PREP_HPP_ + +#include "Rep4Prep.h" + +template +Rep4RingPrep::Rep4RingPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + RingPrep(proc, usage), MaliciousRingPrep(proc, usage) +{ +} + +template +Rep4Prep::Rep4Prep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + RingPrep(proc, usage), Rep4RingPrep(proc, usage) +{ +} + +template +Rep4RingOnlyPrep::Rep4RingOnlyPrep(SubProcessor* proc, + DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + RingPrep(proc, usage), Rep4RingPrep(proc, usage), + RepRingOnlyEdabitPrep(proc, usage) +{ +} + +template +void Rep4RingPrep::buffer_inputs(int player) +{ + auto prot = this->protocol; + assert(prot != 0); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T res; + for (int j = 0; j < 3; j++) + if (prot->P.get_offset(player - j) != 1) + res[j].randomize(prot->rep_prngs[j]); + this->inputs[player].push_back({res, res.sum()}); + } +} + +template +void Rep4RingPrep::buffer_triples() +{ + generate_triples(this->triples, OnlineOptions::singleton.batch_size, + this->protocol); +} + +template +void Rep4RingPrep::buffer_squares() +{ + generate_squares(this->squares, OnlineOptions::singleton.batch_size, + this->protocol, this->proc); +} + +template +void Rep4RingPrep::buffer_bits() +{ + assert(this->proc != 0); + SeededPRNG G; + octetStream os; + Player& P = this->proc->P; + if (P.my_num() % 2 == 0) + { + os.append(G.get_seed(), SEED_SIZE); + P.send_relative(1, os); + } + else + { + P.receive_relative(-1, os); + G.SetSeed(os.consume(SEED_SIZE)); + } + + auto& protocol = this->proc->protocol; + + protocol.init_mul(); + vector bits; + int batch_size = OnlineOptions::singleton.batch_size; + bits.reserve(batch_size); + for (int i = 0; i < batch_size; i++) + bits.push_back(G.get_bit()); + + protocol.init_mul(); + for (auto& o : protocol.os) + o.reset_write_head(); + protocol.reset_joint_input(batch_size); + protocol.prepare_joint_input(0, 1, 3, 2, bits); + if (P.my_num() == 0) + P.send_relative(-1, protocol.os[0]); + if (P.my_num() == 3) + P.receive_relative(1, protocol.os[0]); + protocol.finalize_joint_input(0, 1, 3, 2); + auto a = protocol.results; + + protocol.init_mul(); + for (auto& o : protocol.os) + o.reset_write_head(); + protocol.reset_joint_input(batch_size); + protocol.prepare_joint_input(2, 3, 1, 0, bits); + if (P.my_num() == 2) + P.send_relative(-1, protocol.os[0]); + if (P.my_num() == 1) + P.receive_relative(1, protocol.os[0]); + protocol.finalize_joint_input(2, 3, 1, 0); + auto b = protocol.results; + + auto results = protocol.results; + protocol.init_mul(); + for (int i = 0; i < batch_size; i++) + protocol.prepare_mul(a[i].res, b[i].res); + protocol.exchange(); + typedef typename T::clear clear; + clear two = clear(1) + clear(1); + for (int i = 0; i < batch_size; i++) + this->bits.push_back( + a[i].res + b[i].res - two * protocol.finalize_mul()); +} + +template +void Rep4Prep::buffer_inverses() +{ + assert(this->proc != 0); + ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); +} + +#endif /* PROTOCOLS_REP4PREP_HPP_ */ diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h new file mode 100644 index 000000000..d9e6eccee --- /dev/null +++ b/Protocols/Rep4Share.h @@ -0,0 +1,72 @@ +/* + * Rep4Share.h + * + */ + +#ifndef PROTOCOLS_REP4SHARE_H_ +#define PROTOCOLS_REP4SHARE_H_ + +#include "Rep3Share.h" +#include "Processor/NoLivePrep.h" + +template class Rep4MC; +template class Rep4; +template class Rep4Prep; +template class Rep4Input; + +namespace GC +{ +class Rep4Secret; +} + +template +class Rep4Share : public RepShare +{ + typedef Rep4Share This; + typedef RepShare super; + +public: + typedef T clear; + + typedef Rep4 Protocol; + typedef Rep4MC MAC_Check; + typedef MAC_Check Direct_MC; + typedef Rep4Input Input; + typedef ::PrivateOutput PrivateOutput; + typedef Rep4Prep LivePrep; + typedef LivePrep SquarePrep; + + typedef GC::Rep4Secret bit_type; + + static string type_short() + { + return "R4" + string(1, T::type_char()); + } + + static This constant(clear value, int my_num, typename super::mac_key_type = {}) + { + This res; + if (my_num != 0) + res[3 - my_num] = value; + return res; + } + + Rep4Share() + { + } + Rep4Share(const FixedVec& other) : super(other) + { + } + + void assign(clear value, int my_num, clear = {}) + { + *this = constant(value, my_num); + } + void assign(const char* buffer) + { + super::assign(buffer); + } + +}; + +#endif /* PROTOCOLS_REP4SHARE_H_ */ diff --git a/Protocols/Rep4Share2k.h b/Protocols/Rep4Share2k.h new file mode 100644 index 000000000..9e27c3ea6 --- /dev/null +++ b/Protocols/Rep4Share2k.h @@ -0,0 +1,80 @@ +/* + * Rep4Share.h + * + */ + +#ifndef PROTOCOLS_REP4SHARE2K_H_ +#define PROTOCOLS_REP4SHARE2K_H_ + +#include "Rep4Share.h" +#include "Processor/NoLivePrep.h" +#include "Processor/DummyProtocol.h" +#include "GC/square64.h" + +template class Rep4MC; +template class Rep4Input; +template class Rep4RingOnlyPrep; + +template +class Rep4Share2 : public Rep4Share> +{ + typedef Rep4Share2 This; + typedef Rep4Share> super; + +public: + typedef SignedZ2 clear; + typedef Rep4Share> SquareToBitShare; + + typedef Rep4 Protocol; + typedef Rep4MC MAC_Check; + typedef MAC_Check Direct_MC; + typedef Rep4Input Input; + typedef ::PrivateOutput PrivateOutput; + typedef Rep4RingOnlyPrep LivePrep; + + Rep4Share2() + { + } + template + Rep4Share2(const FixedVec& other) : super(other) + { + } + + template + static void split(vector& dest, const vector& regs, + int n_bits, const Rep4Share2* source, int n_inputs, Player& P) + { + int my_num = P.my_num(); + assert(n_bits <= 64); + assert(regs.size() / n_bits == 4); + int unit = GC::Clear::N_BITS; + for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) + { + int start = k * unit; + int m = min(unit, n_inputs - start); + + for (int i = 0; i < n_bits; i++) + dest.at(regs.at(4 * i + my_num) + k) = {}; + + for (int i = 0; i < 3; i++) + { + square64 square; + + for (int j = 0; j < m; j++) + square.rows[j] = Integer(source[j + start][i]).get(); + + square.transpose(m, n_bits); + + for (int j = 0; j < n_bits; j++) + { + auto &dest_reg = dest.at( + regs.at(4 * j + ((my_num + i + 1) % 4)) + k); + dest_reg = {}; + dest_reg[i] = square.rows[j]; + } + } + } + } +}; + +#endif /* PROTOCOLS_REP4SHARE2K_H_ */ diff --git a/Protocols/RepRingOnlyEdabitPrep.h b/Protocols/RepRingOnlyEdabitPrep.h new file mode 100644 index 000000000..b4a4b894b --- /dev/null +++ b/Protocols/RepRingOnlyEdabitPrep.h @@ -0,0 +1,24 @@ +/* + * RepRingOnlyEdabitPrep.h + * + */ + +#ifndef PROTOCOLS_REPRINGONLYEDABITPREP_H_ +#define PROTOCOLS_REPRINGONLYEDABITPREP_H_ + +#include "ReplicatedPrep.h" + +template +class RepRingOnlyEdabitPrep : public virtual BufferPrep +{ +protected: + void buffer_edabits(int n_bits, ThreadQueues*); + +public: + RepRingOnlyEdabitPrep(SubProcessor*, DataPositions& usage) : + BufferPrep(usage) + { + } +}; + +#endif /* PROTOCOLS_REPRINGONLYEDABITPREP_H_ */ diff --git a/Protocols/RepRingOnlyEdabitPrep.hpp b/Protocols/RepRingOnlyEdabitPrep.hpp new file mode 100644 index 000000000..78721e0d9 --- /dev/null +++ b/Protocols/RepRingOnlyEdabitPrep.hpp @@ -0,0 +1,52 @@ +/* + * RepRingOnlyEdabitPrep.cpp + * + */ + +#include "RepRingOnlyEdabitPrep.h" +#include "GC/BitAdder.h" +#include "Processor/Instruction.h" + +template +void RepRingOnlyEdabitPrep::buffer_edabits(int n_bits, ThreadQueues*) +{ + assert(this->proc); + int dl = T::bit_type::default_length; + int buffer_size = DIV_CEIL(this->buffer_size, dl) * dl; + vector wholes; + wholes.resize(buffer_size); + Instruction inst; + inst.r[0] = 0; + inst.n = n_bits; + inst.size = buffer_size; + this->proc->protocol.randoms_inst(wholes, inst); + + auto& P = this->proc->P; + vector regs(P.num_players() * n_bits); + for (size_t i = 0; i < regs.size(); i++) + regs[i] = i * buffer_size / dl; + typedef typename T::bit_type bt; + vector bits(n_bits * P.num_players() * buffer_size); + T::split(bits, regs, n_bits, wholes.data(), wholes.size(), this->proc->P); + + BitAdder bit_adder; + vector>> summands; + for (int i = 0; i < n_bits; i++) + { + summands.push_back({}); + auto& x = summands.back(); + for (int j = 0; j < P.num_players(); j++) + { + x.push_back({}); + auto& y = x.back(); + for (int k = 0; k < buffer_size / dl; k++) + y.push_back(bits.at(k + buffer_size / dl * (j + P.num_players() * i))); + } + } + vector> sums(buffer_size / dl); + auto &party = GC::ShareThread::s(); + SubProcessor bit_proc(party.MC->get_part_MC(), this->proc->bit_prep, P); + bit_adder.multi_add(sums, summands, 0, buffer_size / dl, bit_proc, dl, 0); + + this->push_edabits(this->edabits[{false, n_bits}], wholes, sums, buffer_size); +} diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 5da6c9e9d..4b4e76845 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -44,6 +44,11 @@ class ReplicatedBase template class ProtocolBase { + virtual void buffer_random() { not_implemented(); } + +protected: + vector random; + public: typedef T share_type; @@ -59,6 +64,8 @@ class ProtocolBase void multiply(vector& products, vector>& multiplicands, int begin, int end, SubProcessor& proc); + T mul(const T& x, const T& y); + virtual void init_mul(SubProcessor* proc) = 0; virtual typename T::clear prepare_mul(const T& x, const T& y, int n = -1) = 0; virtual void exchange() = 0; @@ -69,11 +76,13 @@ class ProtocolBase void next_dotprod() {} T finalize_dotprod(int length); + virtual T get_random(); + virtual void trunc_pr(const vector& regs, int size, SubProcessor& proc) { (void) regs, (void) size; (void) proc; throw runtime_error("trunc_pr not implemented"); } virtual void randoms(T&, int) { throw runtime_error("randoms not implemented"); } - virtual void randoms_inst(SubProcessor&, const Instruction&); + virtual void randoms_inst(vector&, const Instruction&); virtual void start_exchange() { exchange(); } virtual void stop_exchange() {} @@ -115,7 +124,8 @@ class Replicated : public ReplicatedBase, public ProtocolBase void prepare_reshare(const typename T::clear& share, int n = -1); - void init_dotprod(SubProcessor* proc); + void init_dotprod(SubProcessor*) { init_mul(); } + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 7262367b5..d209a1c5e 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -8,6 +8,7 @@ #include "Replicated.h" #include "Processor/Processor.h" +#include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" #include "SemiShare.h" @@ -64,7 +65,7 @@ inline ReplicatedBase ReplicatedBase::branch() template ProtocolBase::~ProtocolBase() { -#ifdef VERBOSE +#ifdef VERBOSE_COUNT if (counter) cerr << "Number of " << T::type_string() << " multiplications: " << counter << endl; #endif @@ -90,7 +91,7 @@ void ProtocolBase::multiply(vector& products, vector >& multiplicands, int begin, int end, SubProcessor& proc) { -#ifdef VERBOSE +#ifdef VERBOSE_CENTRAL fprintf(stderr, "multiply from %d to %d in %d\n", begin, end, BaseMachine::thread_num); #endif @@ -103,6 +104,15 @@ void ProtocolBase::multiply(vector& products, products[i] = finalize_mul(); } +template +T ProtocolBase::mul(const T& x, const T& y) +{ + init_mul(0); + prepare_mul(x, y); + exchange(); + return finalize_mul(); +} + template T ProtocolBase::finalize_dotprod(int length) { @@ -113,6 +123,17 @@ T ProtocolBase::finalize_dotprod(int length) return res; } +template +T ProtocolBase::get_random() +{ + if (random.empty()) + buffer_random(); + + auto res = random.back(); + random.pop_back(); + return res; +} + template void Replicated::init_mul(SubProcessor* proc) { @@ -161,7 +182,8 @@ inline void Replicated::prepare_reshare(const typename T::clear& share, template void Replicated::exchange() { - P.pass_around(os[0], os[1], 1); + if (os[0].get_length() > 0) + P.pass_around(os[0], os[1], 1); } template @@ -179,6 +201,7 @@ void Replicated::stop_exchange() template inline T Replicated::finalize_mul(int n) { + this->counter++; T result; result[0] = add_shares.next(); result[1].unpack(os[1], n); @@ -186,21 +209,22 @@ inline T Replicated::finalize_mul(int n) } template -inline void Replicated::init_dotprod(SubProcessor* proc) +inline void Replicated::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share.assign_zero(); } template inline void Replicated::prepare_dotprod(const T& x, const T& y) { - dotprod_share += x.local_mul(y); + dotprod_share = dotprod_share.lazy_add(x.local_mul(y)); } template inline void Replicated::next_dotprod() { + dotprod_share.normalize(); prepare_reshare(dotprod_share); dotprod_share.assign_zero(); } @@ -209,7 +233,6 @@ template inline T Replicated::finalize_dotprod(int length) { (void) length; - this->counter++; return finalize_mul(); } @@ -223,12 +246,12 @@ T Replicated::get_random() } template -void ProtocolBase::randoms_inst(SubProcessor& proc, +void ProtocolBase::randoms_inst(vector& S, const Instruction& instruction) { for (int j = 0; j < instruction.get_size(); j++) { - auto& res = proc.get_S_ref(instruction.get_r(0) + j); + auto& res = S[instruction.get_r(0) + j]; randoms(res, instruction.get_n()); } } @@ -255,23 +278,17 @@ void trunc_pr(const vector& regs, int size, octetStream os[2]; for (size_t i = 0; i < regs.size(); i += 4) { - int k = regs[i + 2]; - int m = regs[i + 3]; - int n_shift = value_type::N_BITS - 1 - k; - assert(m < k); - assert(0 < k); - assert(m < value_type::N_BITS); + TruncPrTuple info(regs, i); for (int l = 0; l < size; l++) { auto& res = proc.get_S_ref(regs[i] + l); auto& G = proc.Proc->secure_prng; auto mask = G.template get(); - auto unmask = (mask << (n_shift + 1)) >> (n_shift + m + 1); + auto unmask = info.upper(mask); T shares[4]; shares[0].randomize_to_sum(mask, G); shares[1].randomize_to_sum(unmask, G); - shares[2].randomize_to_sum( - (mask << (n_shift)) >> (value_type::N_BITS - 1), G); + shares[2].randomize_to_sum(info.msb(mask), G); res.randomize(G); shares[3] = res; for (int i = 0; i < 2; i++) diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index f38a9588b..d1e588334 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -66,6 +66,8 @@ class ReplicatedInput : public PrepLessInput PrepLessInput(proc), proc(proc), P(P), protocol(P) { assert(T::length == 2); + InputBase::P = &P; + InputBase::os.resize(P.num_players()); expect.resize(P.num_players()); } diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 3daedf0f6..0e826a19d 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -14,6 +14,7 @@ template void ReplicatedInput::reset(int player) { + InputBase::reset(player); assert(P.num_players() == 3); if (player == P.my_num()) { diff --git a/Protocols/ReplicatedMachine.h b/Protocols/ReplicatedMachine.h index 2da565fe1..823cc032c 100644 --- a/Protocols/ReplicatedMachine.h +++ b/Protocols/ReplicatedMachine.h @@ -9,6 +9,9 @@ #include using namespace std; +#include "Tools/ezOptionParser.h" +#include "Processor/OnlineOptions.h" + template class ReplicatedMachine { diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index e5c66e886..7fb454728 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -20,6 +20,9 @@ template void buffer_inverses(vector>& inverses, Preprocessing& prep, MAC_Check_Base& MC, Player& P); +template +void bits_from_random(vector& bits, typename T::Protocol& protocol); + template class BufferPrep : public Preprocessing { @@ -33,8 +36,6 @@ class BufferPrep : public Preprocessing vector>> inputs; vector> dabits; - map, vector>> edabits; - map, edabitvec> my_edabits; int n_bit_rounds; @@ -62,6 +63,9 @@ class BufferPrep : public Preprocessing virtual void buffer_personal_dabits(int) { throw runtime_error("no personal daBits"); } + void push_edabits(vector>& edabits, + const vector& sums, const vector>& bits, + int buffer_size); public: typedef T share_type; @@ -85,11 +89,7 @@ class BufferPrep : public Preprocessing T get_random_from_inputs(int nplayers); - virtual void get_dabit(T& a, typename T::bit_type& b); virtual void get_dabit_no_count(T& a, typename T::bit_type& b); - virtual void get_edabits(bool strict, size_t size, T* a, - vector& Sb, const vector& regs); - virtual void get_edabit_no_count(bool strict, int n_bits, edabit& a); void push_triples(const vector>& triples) { this->triples.insert(this->triples.end(), triples.begin(), triples.end()); } @@ -146,6 +146,8 @@ class RingPrep : public virtual BitPrep virtual void buffer_sedabits_from_edabits(int); + void sanitize(vector>& edabits, int n_bits); + public: RingPrep(SubProcessor* proc, DataPositions& usage); virtual ~RingPrep() {} diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index b35d8cec4..533c133e5 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -18,7 +18,6 @@ #include "ShuffleSacrifice.hpp" #include "GC/ShareThread.hpp" #include "GC/BitAdder.hpp" -#include "Processor/Processor.hpp" template BufferPrep::BufferPrep(DataPositions& usage) : @@ -33,20 +32,22 @@ BufferPrep::~BufferPrep() { #ifdef VERBOSE if (n_bit_rounds > 0) - cerr << n_bit_rounds << " rounds of random bit generation" << endl; + cerr << n_bit_rounds << " rounds of random " << T::type_string() + << " bit generation" << endl; #define X(KIND) \ if (KIND.size()) \ - cerr << "\t" << KIND.size() << " " #KIND " left" << endl; + cerr << "\t" << KIND.size() << " " #KIND " of " << T::type_string() \ + << " left" << endl; X(triples) X(squares) X(inverses) X(bits) X(dabits) #undef X - for (auto& x : edabits) + for (auto& x : this->edabits) { if (not x.second.empty()) { cerr << "\t~" << x.second.size() * x.second[0].size(); - if (x.first.first) + if (not x.first.first) cerr << " loose"; cerr << " edaBits of size " << x.first.second << " left" << endl; } @@ -100,8 +101,15 @@ template void generate_triples(vector>& triples, int n_triples, U* protocol, int n_bits = -1) { - triples.resize(n_triples); protocol->init_mul(); + generate_triples_initialized(triples, n_triples, protocol, n_bits); +} + +template +void generate_triples_initialized(vector>& triples, int n_triples, + U* protocol, int n_bits = -1) +{ + triples.resize(n_triples); for (size_t i = 0; i < triples.size(); i++) { auto& triple = triples[i]; @@ -136,7 +144,7 @@ template void BitPrep::buffer_squares() { auto proc = this->proc; - auto buffer_size = OnlineOptions::singleton.batch_size; + auto buffer_size = this->buffer_size; assert(proc != 0); vector a_plus_b(buffer_size), as(buffer_size), cs(buffer_size); T b; @@ -154,11 +162,16 @@ void BitPrep::buffer_squares() template void ReplicatedRingPrep::buffer_squares() { - auto protocol = this->protocol; - auto proc = this->proc; + generate_squares(this->squares, this->buffer_size, + this->protocol, this->proc); +} + +template +void generate_squares(vector>& squares, int n_squares, + U* protocol, SubProcessor* proc) +{ assert(protocol != 0); - auto& squares = this->squares; - squares.resize(OnlineOptions::singleton.batch_size); + squares.resize(n_squares); protocol->init_mul(proc); for (size_t i = 0; i < squares.size(); i++) { @@ -461,7 +474,7 @@ void buffer_bits_from_players(vector>& player_bits, auto& P = protocol.P; int n_relevant_players = protocol.get_n_relevant_players(); player_bits.resize(n_relevant_players, vector(buffer_size)); - typename T::Input input(proc, proc.MC); + auto& input = proc.input; input.reset_all(P); for (int i = 0; i < n_relevant_players; i++) { @@ -659,6 +672,19 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& vector> bits; vector sums; buffer_edabits_without_check(n_bits, sums, bits, buffer_size); + this->push_edabits(edabits, sums, bits, buffer_size); + (void) stat; +#ifdef VERBOSE_PREP + cerr << "edaBit generation" << endl; + (proc->P.comm_stats - stat).print(true); +#endif +} + +template +void BufferPrep::push_edabits(vector>& edabits, + const vector& sums, const vector>& bits, + int buffer_size) +{ int unit = T::bit_type::part_type::default_length; edabits.reserve(edabits.size() + DIV_CEIL(buffer_size, unit)); for (int i = 0; i < buffer_size; i++) @@ -667,32 +693,22 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& edabits.push_back(bits.at(i / unit)); edabits.back().push_a(sums[i]); } - (void) stat; -#ifdef VERBOSE_PREP - cerr << "edaBit generation" << endl; - (proc->P.comm_stats - stat).print(true); -#endif } template void RingPrep::buffer_sedabits_from_edabits(int n_bits) { assert(this->proc != 0); - int buffer_size = OnlineOptions::singleton.batch_size; - vector> edabits; - for (int i = 0; i < buffer_size; i++) - { - edabits.push_back({}); - auto& x = edabits.back(); - this->get_edabit_no_count(false, n_bits, x); - } - sanitize(edabits, n_bits); - for (auto& x : edabits) + size_t buffer_size = OnlineOptions::singleton.batch_size; + auto& loose = this->edabits[{false, n_bits}]; + while (loose.size() < size_t(DIV_CEIL(buffer_size, edabitvec::MAX_SIZE))) + this->buffer_edabits(false, n_bits); + sanitize(loose, n_bits); + for (auto& x : loose) { - assert(x.second.size() >= (size_t)n_bits); - x.second.resize(n_bits); this->edabits[{true, n_bits}].push_back(x); } + loose.clear(); } template @@ -762,14 +778,69 @@ void RingPrep::sanitize(vector>& edabits, int n_bits, int player, delete &MCB; } +template +void RingPrep::sanitize(vector>& edabits, int n_bits) +{ + vector dabits; + typedef typename T::bit_type::part_type BT; + vector to_open; + for (auto& x : edabits) + { + for (size_t j = n_bits; j < x.b.size(); j++) + { + BT bits; + for (size_t i = 0; i < x.size(); i++) + { + T a; + typename T::bit_type b; + this->get_dabit_no_count(a, b); + dabits.push_back(a); + bits ^= BT(b) << i; + } + to_open.push_back(x.b[j] + bits); + } + } + vector opened; + auto& MCB = *BT::new_mc( + GC::ShareThread::s().MC->get_alphai()); + MCB.POpen(opened, to_open, this->proc->P); + auto dit = dabits.begin(); + auto oit = opened.begin(); + for (auto& x : edabits) + { + for (size_t j = n_bits; j < x.b.size(); j++) + { + auto masked = (*oit++); + for (size_t i = 0; i < x.size(); i++) + { + int masked_bit = masked.get_bit(i); + auto& mask = *dit++; + auto overflow = mask + + T::constant(masked_bit, this->proc->P.my_num(), + this->proc->MC.get_alphai()) + - mask * typename T::clear(masked_bit * 2); + x.a[i] -= overflow << j; + } + } + } + MCB.Check(this->proc->P); + delete &MCB; +} + template<> inline void SemiHonestRingPrep>::buffer_bits() { assert(protocol != 0); - for (int i = 0; i < DIV_CEIL(buffer_size, gf2n::degree()); i++) + bits_from_random(bits, *protocol); +} + +template +void bits_from_random(vector& bits, typename T::Protocol& protocol) +{ + while (bits.size() < (size_t)OnlineOptions::singleton.batch_size) { - Rep3Share share = protocol->get_random(); + Rep3Share share = protocol.get_random(); for (int j = 0; j < gf2n::degree(); j++) { bits.push_back(share & 1); @@ -829,7 +900,9 @@ template void BufferPrep::get_input_no_count(T& a, typename T::open_type& x, int i) { (void) a, (void) x, (void) i; - if (inputs.size() <= (size_t)i or inputs.at(i).empty()) + if (inputs.size() <= (size_t)i) + inputs.resize(i + 1); + if (inputs.at(i).empty()) buffer_inputs(i); a = inputs[i].back().share; x = inputs[i].back().value; @@ -861,16 +934,16 @@ void BufferPrep::get_personal_dabit(int player, T& a, typename T::bit_type& b } template -void BufferPrep::get_dabit(T& a, typename T::bit_type& b) +void Preprocessing::get_dabit(T& a, typename T::bit_type& b) { get_dabit_no_count(a, b); this->count(DATA_DABIT); } template -void BufferPrep::get_edabit_no_count(bool strict, int n_bits, edabit& a) +void Preprocessing::get_edabit_no_count(bool strict, int n_bits, edabit& a) { - auto& buffer = edabits[{strict, n_bits}]; + auto& buffer = this->edabits[{strict, n_bits}]; auto& my_edabit = my_edabits[{strict, n_bits}]; if (my_edabit.empty()) { @@ -892,7 +965,7 @@ void BufferPrep::buffer_edabits_with_queues(bool strict, int n_bits) } template -void BufferPrep::get_edabits(bool strict, size_t size, T* a, +void Preprocessing::get_edabits(bool strict, size_t size, T* a, vector& Sb, const vector& regs) { int n_bits = regs.size(); @@ -913,7 +986,7 @@ void BufferPrep::get_edabits(bool strict, size_t size, T* a, { for (size_t i = k * unit; i < min(size, (k + 1) * unit); i++) { - get_edabit_no_count(strict, n_bits, eb); + this->get_edabit_no_count(strict, n_bits, eb); a[i] = eb.first; for (int j = 0; j < n_bits; j++) { diff --git a/Protocols/Semi2k.h b/Protocols/Semi2k.h new file mode 100644 index 000000000..927bb6c17 --- /dev/null +++ b/Protocols/Semi2k.h @@ -0,0 +1,28 @@ +/* + * Semi2k.h + * + */ + +#ifndef PROTOCOLS_SEMI2K_H_ +#define PROTOCOLS_SEMI2K_H_ + +#include "SPDZ.h" + +template +class Semi2k : public SPDZ +{ + SeededPRNG G; + +public: + Semi2k(Player& P) : + SPDZ(P) + { + } + + void randoms(T& res, int n_bits) + { + res.randomize_part(G, n_bits); + } +}; + +#endif /* PROTOCOLS_SEMI2K_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index b28f6c826..bbb48534f 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -7,8 +7,11 @@ #define PROTOCOLS_SEMI2KSHARE_H_ #include "SemiShare.h" +#include "Semi2k.h" #include "OT/Rectangle.h" #include "GC/SemiSecret.h" +#include "GC/square64.h" +#include "Processor/Instruction.h" template class SemiPrep2k; @@ -24,7 +27,7 @@ class Semi2kShare : public SemiShare> typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef Semi2k Protocol; typedef SemiPrep2k LivePrep; typedef Semi2kShare prep_type; @@ -46,6 +49,53 @@ class Semi2kShare : public SemiShare> (void) alphai; assign(other, my_num); } + + template + static void split(vector& dest, const vector& regs, + int n_bits, const Semi2kShare* source, int n_inputs, Player& P) + { + int my_num = P.my_num(); + assert(n_bits <= 64); + int unit = GC::Clear::N_BITS; + for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) + { + int start = k * unit; + int m = min(unit, n_inputs - start); + int n = regs.size() / n_bits; + if (P.num_players() != n) + throw runtime_error( + to_string(n) + "-way split not working with " + + to_string(P.num_players()) + " parties"); + + for (int i = 0; i < n_bits; i++) + for (int j = 0; j < n; j++) + dest.at(regs.at(n * i + j) + k) = {}; + + square64 square; + + for (int j = 0; j < m; j++) + square.rows[j] = Integer(source[j + start]).get(); + + square.transpose(m, n_bits); + + for (int j = 0; j < n_bits; j++) + { + auto& dest_reg = dest.at(regs.at(n * j + my_num) + k); + dest_reg = square.rows[j]; + } + } + } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; #endif /* PROTOCOLS_SEMI2KSHARE_H_ */ diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h index 23769b133..727b5290d 100644 --- a/Protocols/SemiPrep2k.h +++ b/Protocols/SemiPrep2k.h @@ -7,17 +7,34 @@ #define PROTOCOLS_SEMIPREP2K_H_ #include "SemiPrep.h" +#include "RepRingOnlyEdabitPrep.h" template -class SemiPrep2k : public SemiPrep +class SemiPrep2k : public SemiPrep, public RepRingOnlyEdabitPrep { + void buffer_edabits(int n_bits, ThreadQueues* queues) + { + RepRingOnlyEdabitPrep::buffer_edabits(n_bits, queues); + } + + void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues) + { + BufferPrep::buffer_edabits(strict, n_bits, queues); + } + + void buffer_sedabits(int n_bits, ThreadQueues*) + { + this->buffer_sedabits_from_edabits(n_bits); + } + public: SemiPrep2k(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), OTPrep(proc, usage), RingPrep(proc, usage), SemiHonestRingPrep(proc, usage), - SemiPrep(proc, usage) + SemiPrep(proc, usage), + RepRingOnlyEdabitPrep(proc, usage) { } diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 43e2b5e99..d167cd835 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -32,6 +32,8 @@ class Shamir : public ProtocolBase vector random; + vector> hyper; + typename T::open_type dotprod_share; void buffer_random(); @@ -79,7 +81,7 @@ class Shamir : public ProtocolBase T finalize(int n_input_players); - void init_dotprod(SubProcessor* proc); + void init_dotprod(SubProcessor* proc = 0); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index ccdd83568..8ee0d2153 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -191,15 +191,40 @@ T Shamir::get_random() template void Shamir::buffer_random() { - Shamir shamir(P); - shamir.reset(); + if (hyper.empty()) + { + int n = P.num_players(); + for (int i = 0; i < n - threshold; i++) + { + hyper.push_back({}); + for (int j = 0; j < n; j++) + { + hyper.back().push_back({1}); + for (int k = 0; k < n; k++) + if (k != j) + hyper.back().back() *= U(n + i - k) / U(j - k); + } + } + } + + ShamirInput input(0, P); int buffer_size = OnlineOptions::singleton.batch_size; - if (P.my_num() <= threshold) - for (int i = 0; i < buffer_size; i++) - shamir.resharing->add_mine(secure_prng.get()); - shamir.exchange(); - for (int i = 0; i < buffer_size; i++) - random.push_back(shamir.finalize(threshold + 1)); + for (int i = 0; i < buffer_size; i += hyper.size()) + input.add_mine(secure_prng.get()); + input.exchange(); + vector inputs; + for (int i = 0; i < buffer_size; i += hyper.size()) + { + inputs.clear(); + for (int j = 0; j < P.num_players(); j++) + inputs.push_back(input.finalize(j)); + for (size_t j = 0; j < hyper.size(); j++) + { + random.push_back({}); + for (int k = 0; k < P.num_players(); k++) + random.back() += hyper[j][k] * inputs[k]; + } + } } #endif diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 70ae633b5..9acc0499c 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -21,9 +21,10 @@ class IndividualInput : public PrepLessInput IndividualInput(SubProcessor* proc, Player& P) : PrepLessInput(proc), P(P) { + this->reset_all(P); } IndividualInput(SubProcessor& proc) : - PrepLessInput(&proc), P(proc.P) + IndividualInput(&proc , proc.P) { } diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 5cff9bfe4..fb7f93377 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_SHAMIRMC_HPP_ +#define PROTOCOLS_SHAMIRMC_HPP_ + #include "ShamirMC.h" template @@ -114,3 +117,5 @@ typename T::open_type ShamirMC::finalize_open() return res; } + +#endif diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 6491ce335..949234ec9 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -28,6 +28,7 @@ class ShamirShare : public T, public ShareInterface typedef T mac_key_type; typedef void sacri_type; typedef GC::NoShare mac_type; + typedef GC::NoShare mac_share_type; typedef Shamir Protocol; typedef ShamirMC MAC_Check; diff --git a/Protocols/Share.h b/Protocols/Share.h index 3349685e6..69b57dedd 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -41,9 +41,11 @@ class Share_ : public ShareInterface public: + typedef T part_type; typedef V mac_key_type; typedef V mac_type; typedef T share_type; + typedef V mac_share_type; typedef typename T::open_type open_type; typedef typename T::clear clear; @@ -56,9 +58,6 @@ class Share_ : public ShareInterface static int size() { return T::size() + V::size(); } - static string type_short() - { return string(1, T::type_char()); } - static char type_char() { return T::type_char(); } @@ -69,7 +68,7 @@ class Share_ : public ShareInterface { return T::threshold(nplayers); } template - static void read_or_generate_mac_key(string directory, const Names& N, + static void read_or_generate_mac_key(string directory, const Player& P, U& key); static Share_ constant(const clear& aa, int my_num, const typename V::Scalar& alphai) @@ -129,7 +128,7 @@ class Share_ : public ShareInterface template Share_& operator*=(const U& x) { mul(*this, x); return *this; } - Share_ operator<<(int i) { return this->operator*(T(1) << i); } + Share_ operator<<(int i) { return this->operator*(clear(1) << i); } Share_& operator<<=(int i) { return *this = *this << i; } Share_ operator>>(int i) const { return {a >> i, mac >> i}; } @@ -182,6 +181,9 @@ class Share : public Share_, SemiShare> static const bool expensive = true; + static string type_short() + { return string(1, T::type_char()); } + static string type_string() { return "SPDZ " + T::type_string(); } diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index 834a1984d..58bb7085b 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -4,12 +4,12 @@ template template -void Share_::read_or_generate_mac_key(string directory, const Names& N, +void Share_::read_or_generate_mac_key(string directory, const Player& P, U& key) { try { - read_mac_key(directory, N, key); + read_mac_key(directory, P.N, key); } catch (mac_key_error&) { diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index a0b954759..9bb25c5d5 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -32,10 +32,14 @@ class ShareInterface static void split(vector, vector, int, T*, int, Player&) { throw runtime_error("split not implemented"); } + template + static void shrsi(T&, const Instruction&) + { throw runtime_error("shrsi not implemented"); } + static bool get_rec_factor(int, int) { return false; } template - static void read_or_generate_mac_key(const string&, const Names&, T&) {} + static void read_or_generate_mac_key(const string&, const Player&, T&) {} template static void generate_mac_key(T&, U&) {} diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index f54afef34..ed54e2ec8 100644 --- a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -12,14 +12,12 @@ using namespace std; #include "Tools/FixedVector.h" #include "edabit.h" +#include "dabit.h" class Player; template class LimitedPrep; -template -using dabit = pair; - template class ShuffleSacrifice { diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 86bac99af..00ee693eb 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -291,7 +291,7 @@ void ShuffleSacrifice::edabit_sacrifice_buckets(vector>& to_check, if (supply) { auto& triples = *(vector>*)supply; -#ifdef VERBOSE +#ifdef VERBOSE_EDA fprintf(stderr, "got %zu supplies\n", triples.size()); #endif if (player < 0) diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index d46a4f52e..9aa5df481 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -7,6 +7,7 @@ #define PROTOCOLS_SPDZ2KPREP_HPP_ #include "Spdz2kPrep.h" +#include "GC/BitAdder.h" #include "DabitSacrifice.hpp" #include "RingOnlyPrep.hpp" @@ -67,17 +68,16 @@ void MaliciousRingPrep::buffer_bits() RingPrep::buffer_bits_without_check(); assert(this->protocol != 0); auto& protocol = *this->protocol; - protocol.init_mul(this->proc); + protocol.init_dotprod(this->proc); auto one = T::constant(1, protocol.P.my_num(), this->proc->MC.get_alphai()); + GlobalPRNG G(protocol.P); for (auto& bit : this->bits) // one of the two is not a zero divisor, so if the product is zero, one of them is too - protocol.prepare_mul(one - bit, bit); + protocol.prepare_dotprod(one - bit, bit * G.get()); + protocol.next_dotprod(); protocol.exchange(); - vector checks; - checks.reserve(this->bits.size()); - for (size_t i = 0; i < this->bits.size(); i++) - checks.push_back(protocol.finalize_mul()); - this->proc->MC.CheckFor(0, checks, protocol.P); + this->proc->MC.CheckFor(0, {protocol.finalize_dotprod(this->bits.size())}, + protocol.P); } template @@ -101,6 +101,7 @@ void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep) auto bit_MC = &bit_proc->MC; vector squares, random_shares; auto one = BitShare::constant(1, bit_proc->P.my_num(), bit_MC->get_alphai()); + bit_prep->buffer_size = buffer_size; for (int i = 0; i < buffer_size; i++) { BitShare a, a2; diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h new file mode 100644 index 000000000..ebbd0352a --- /dev/null +++ b/Protocols/SpdzWise.h @@ -0,0 +1,58 @@ +/* + * SpdzWise.h + * + */ + +#ifndef PROTOCOLS_SPDZWISE_H_ +#define PROTOCOLS_SPDZWISE_H_ + +#include "Replicated.h" + +template class SpdzWiseInput; + +template +class SpdzWise : public ProtocolBase +{ + typedef typename T::part_type check_type; + + friend class SpdzWiseInput; + + typename T::part_type::Honest::Protocol internal, internal2; + + typename T::mac_key_type mac_key; + + vector results; + + vector coefficients; + + void buffer_random(); + + virtual void zero_check(check_type t); + +public: + Player& P; + + SpdzWise(Player& P); + virtual ~SpdzWise(); + + void init(SubProcessor* proc); + + void init_mul(SubProcessor* proc); + typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void exchange(); + T finalize_mul(int n = -1); + + void init_dotprod(SubProcessor*); + void prepare_dotprod(const T& x, const T& y); + void next_dotprod(); + T finalize_dotprod(int length); + + void add_to_check(const T& x); + void check(); + + int get_n_relevant_players() { return internal.get_n_relevant_players(); } + + void randoms_inst(vector& S, const Instruction& instruction); +}; + +#endif /* PROTOCOLS_SPDZWISE_H_ */ diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp new file mode 100644 index 000000000..bdf2c901e --- /dev/null +++ b/Protocols/SpdzWise.hpp @@ -0,0 +1,182 @@ +/* + * SpdzWise.cpp + * + */ + +#include "SpdzWise.h" + +template +SpdzWise::SpdzWise(Player& P) : + internal(P), internal2(P), P(P) +{ + results.reserve(OnlineOptions::singleton.batch_size); +} + +template +SpdzWise::~SpdzWise() +{ + check(); +} + +template +void SpdzWise::init(SubProcessor* proc) +{ + assert(proc != 0); + mac_key = proc->MC.get_alphai(); + if ((int) results.size() >= OnlineOptions::singleton.batch_size) + check(); +} + +template +void SpdzWise::init_mul(SubProcessor* proc) +{ + init(proc); + internal.init_mul(); + internal2.init_mul(); +} + +template +typename T::clear SpdzWise::prepare_mul(const T& x, const T& y, int) +{ + internal.prepare_mul(x.get_share(), y.get_share()); + internal.prepare_mul(x.get_mac(), y.get_share()); + return {}; +} + +template +T SpdzWise::finalize_mul(int) +{ + T res; + res.set_share(internal.finalize_mul()); + res.set_mac(internal.finalize_mul()); + results.push_back(res); + return res; +} + +template +void SpdzWise::exchange() +{ + internal.exchange(); + internal2.exchange(); +} + +template +void SpdzWise::init_dotprod(SubProcessor* proc) +{ + init(proc); + internal.init_dotprod(); + internal2.init_dotprod(); +} + +template +void SpdzWise::prepare_dotprod(const T& x, const T& y) +{ + internal.prepare_dotprod(x.get_share(), y.get_share()); + internal2.prepare_dotprod(x.get_mac(), y.get_share()); +} + +template +void SpdzWise::next_dotprod() +{ + internal.next_dotprod(); + internal2.next_dotprod(); +} + +template +T SpdzWise::finalize_dotprod(int length) +{ + T res; + res.set_share(internal.finalize_dotprod(length)); + res.set_mac(internal2.finalize_dotprod(length)); + results.push_back(res); + return res; +} + +template +void SpdzWise::add_to_check(const T& x) +{ + results.push_back(x); +} + +template +void SpdzWise::check() +{ + if (results.empty()) + return; + + internal.init_dotprod(); + coefficients.clear(); + + for (auto& res : results) + { + coefficients.push_back(internal.get_random()); + internal.prepare_dotprod(res.get_share(), coefficients.back()); + } + internal.next_dotprod(); + + for (size_t i = 0; i < results.size(); i++) + internal.prepare_dotprod(results[i].get_mac(), coefficients[i]); + internal.next_dotprod(); + + internal.exchange(); + auto w = internal.finalize_dotprod(results.size()); + auto u = internal.finalize_dotprod(results.size()); + auto t = u - internal.mul(mac_key, w); + zero_check(t); + results.clear(); +} + +template +void SpdzWise::zero_check(check_type t) +{ + auto r = internal.get_random(); + internal.init_mul(); + internal.prepare_mul(t, r); + internal.exchange(); + typename T::part_type::MAC_Check MC; + MC.CheckFor(0, {internal.finalize_mul()}, P); +} + +template +void SpdzWise::buffer_random() +{ + // proxy for initialization + assert(mac_key != 0); + int batch_size = OnlineOptions::singleton.batch_size; + vector rs; + rs.reserve(batch_size); + // cannot use member instance + typename T::part_type::Honest::Protocol internal(P); + internal.init_mul(); + for (int i = 0; i < batch_size; i++) + { + rs.push_back(internal.get_random()); + internal.prepare_mul(rs.back(), mac_key); + } + internal.exchange(); + for (int i = 0; i < batch_size; i++) + { + this->random.push_back({rs[i], internal.finalize_mul()}); + results.push_back(this->random.back()); + } +} + +template +void SpdzWise::randoms_inst(vector& S, + const Instruction& instruction) +{ + internal.init_mul(); + for (int i = 0; i < instruction.get_size(); i++) + { + typename T::share_type res; + internal.randoms(res, instruction.get_n()); + internal.prepare_mul(res, mac_key); + S[instruction.get_r(0) + i].set_share(res); + } + internal.exchange(); + for (int i = 0; i < instruction.get_size(); i++) + { + auto& res = S[instruction.get_r(0) + i]; + res.set_mac(internal.finalize_mul()); + } +} diff --git a/Protocols/SpdzWiseInput.h b/Protocols/SpdzWiseInput.h new file mode 100644 index 000000000..b97bfb07b --- /dev/null +++ b/Protocols/SpdzWiseInput.h @@ -0,0 +1,45 @@ +/* + * SpdzWiseInput.h + * + */ + +#ifndef PROTOCOLS_SPDZWISEINPUT_H_ +#define PROTOCOLS_SPDZWISEINPUT_H_ + +#include "ReplicatedInput.h" + +template +class SpdzWiseInput : public InputBase +{ + Player& P; + + typename T::part_type::Input part_input; + typename T::part_type::Honest::Protocol honest_mult; + + typename T::Protocol checker; + SubProcessor* proc; + + typename T::mac_key_type mac_key; + + vector counters; + vector> shares; + +public: + SpdzWiseInput(SubProcessor& proc, Player& P); + SpdzWiseInput(SubProcessor* proc, Player& P); + SpdzWiseInput(SubProcessor& proc, typename T::MAC_Check& MC); + + void reset(int player); + void add_mine(const typename T::open_type& input, int n_bits = -1); + void add_other(int player); + void send_mine(); + void exchange(); + T finalize(int player, int n_bits = -1); + T finalize_mine(); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); + + void start(int, int) { throw not_implemented(); } + void stop(int, const vector&) { throw not_implemented(); } +}; + +#endif /* PROTOCOLS_SPDZWISEINPUT_H_ */ diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp new file mode 100644 index 000000000..1ad2d5ef3 --- /dev/null +++ b/Protocols/SpdzWiseInput.hpp @@ -0,0 +1,98 @@ +/* + * SpdzWiseInput.cpp + * + */ + +#include "SpdzWiseInput.h" + +template +SpdzWiseInput::SpdzWiseInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P), part_input(0, P), honest_mult(P), checker(P), proc( + proc), counters(P.num_players()), shares(P.num_players()) +{ + assert(proc != 0); + mac_key = proc->MC.get_alphai(); +} + +template +SpdzWiseInput::SpdzWiseInput(SubProcessor& proc, Player& P) : + SpdzWiseInput(&proc, P) +{ +} + +template +SpdzWiseInput::SpdzWiseInput(SubProcessor& proc, typename T::MAC_Check&) : + SpdzWiseInput(&proc, proc.P) +{ +} + +template +void SpdzWiseInput::reset(int player) +{ + part_input.reset(player); + counters[player] = 0; +} + +template +void SpdzWiseInput::add_mine(const typename T::open_type& input, int n_bits) +{ + part_input.add_mine(input, n_bits); + counters[P.my_num()]++; +} + +template +void SpdzWiseInput::add_other(int player) +{ + part_input.add_other(player); + counters[player]++; +} + +template +void SpdzWiseInput::exchange() +{ + part_input.exchange(); + honest_mult.init_mul(); + for (int i = 0; i < P.num_players(); i++) + { + shares[i].clear(); + for (int j = 0; j < counters[i]; j++) + { + auto s = part_input.finalize(i); + shares[i].push_back({}); + shares[i].back().set_share(s); + honest_mult.prepare_mul(s, mac_key); + } + } + honest_mult.exchange(); + for (int i = 0; i < P.num_players(); i++) + for (int j = 0; j < counters[i]; j++) + { + shares[i][j].set_mac(honest_mult.finalize_mul()); + checker.results.push_back(shares[i][j]); + } + checker.init(proc); +} + +template +T SpdzWiseInput::finalize(int player, int) +{ + return shares[player].next(); +} + +template +void SpdzWiseInput::send_mine() +{ + throw runtime_error("use exchange()"); +} + +template +T SpdzWiseInput::finalize_mine() +{ + throw runtime_error("use finalize()"); +} + +template +void SpdzWiseInput::finalize_other(int, T&, octetStream&, int) +{ + throw runtime_error("use finalize()"); +} diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h new file mode 100644 index 000000000..9e953e730 --- /dev/null +++ b/Protocols/SpdzWiseMC.h @@ -0,0 +1,63 @@ +/* + * SpdzWiseMC.h + * + */ + +#ifndef PROTOCOLS_SPDZWISEMC_H_ +#define PROTOCOLS_SPDZWISEMC_H_ + +#include "MaliciousRepMC.h" + +template +class SpdzWiseMC : public MAC_Check_Base +{ + vector shares; + + void get_shares(const vector& S) + { + shares.clear(); + for (auto& share : S) + shares.push_back(share.get_share()); + } + +public: + typename T::open_part_type::MAC_Check inner_MC; + + SpdzWiseMC(typename T::mac_key_type mac_key, int = 0, int = 0, int = 0) : + MAC_Check_Base(mac_key) + { + } + SpdzWiseMC(typename T::mac_key_type mac_key, Names&, int) : + MAC_Check_Base(mac_key) + { + } + + void init_open(const Player& P, int n) + { + inner_MC.init_open(P, n); + } + void prepare_open(const T& secret) + { + inner_MC.prepare_open(secret.get_share()); + } + void exchange(const Player& P) + { + inner_MC.exchange(P); + } + typename T::open_type finalize_open() + { + return inner_MC.finalize_open(); + } + void Check(const Player& P) + { + inner_MC.Check(P); + } + void CheckFor(const typename T::open_type& value, const vector& S, + const Player& P) + { + get_shares(S); + inner_MC.CheckFor(value, shares, P); + } +}; + +#endif /* PROTOCOLS_SPDZWISEMC_H_ */ diff --git a/Protocols/SpdzWisePrep.h b/Protocols/SpdzWisePrep.h new file mode 100644 index 000000000..138aa0ca3 --- /dev/null +++ b/Protocols/SpdzWisePrep.h @@ -0,0 +1,31 @@ +/* + * SpdzWisePrep.h + * + */ + +#ifndef PROTOCOLS_SPDZWISEPREP_H_ +#define PROTOCOLS_SPDZWISEPREP_H_ + +#include "ReplicatedPrep.h" + +template +class SpdzWisePrep : public MaliciousRingPrep +{ + typedef MaliciousRingPrep super; + + void buffer_triples(); + void buffer_bits(); + void buffer_inverses(); + + void buffer_inputs(int player); + +public: + SpdzWisePrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), + BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousRingPrep(proc, usage) + { + } +}; + +#endif /* PROTOCOLS_SPDZWISEPREP_H_ */ diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp new file mode 100644 index 000000000..8d49a5c54 --- /dev/null +++ b/Protocols/SpdzWisePrep.hpp @@ -0,0 +1,134 @@ +/* + * SpdzWisePrep.cpp + * + */ + +#include "SpdzWisePrep.h" +#include "SpdzWiseRingPrep.h" +#include "SpdzWiseRingShare.h" +#include "MaliciousShamirShare.h" +#include "Math/gfp.h" + +#include "ReplicatedPrep.hpp" +#include "Spdz2kPrep.hpp" +#include "ShamirMC.hpp" +#include "MaliciousRepPO.hpp" +#include "MaliciousShamirPO.hpp" + +template +void SpdzWisePrep::buffer_triples() +{ + assert(this->protocol != 0); + assert(this->proc != 0); + this->protocol->init_mul(this->proc); + generate_triples_initialized(this->triples, + OnlineOptions::singleton.batch_size, this->protocol); +} + +template<> +inline +void SpdzWisePrep>>::buffer_bits() +{ + MaliciousRingPrep>>::buffer_bits(); +} + +template<> +void SpdzWisePrep>>::buffer_bits() +{ + typedef MaliciousRep3Share part_type; + vector bits; + typename part_type::Honest::Protocol protocol(this->protocol->P); + bits_from_random(bits, protocol); + protocol.init_mul(); + for (auto& bit : bits) + protocol.prepare_mul(bit, this->proc->MC.get_alphai()); + protocol.exchange(); + for (auto& bit : bits) + this->bits.push_back({bit, protocol.finalize_mul()}); +} + +template +void buffer_bits_from_squares_in_ring(vector>& bits, + SubProcessor>* proc) +{ + assert(proc != 0); + typedef SpdzWiseRingShare BitShare; + typename BitShare::MAC_Check MC(proc->MC.get_alphai()); + DataPositions usage; + SpdzWisePrep prep(0, usage); + SubProcessor bit_proc(MC, prep, proc->P, proc->Proc); + prep.set_proc(&bit_proc); + bits_from_square_in_ring(bits, OnlineOptions::singleton.batch_size, &prep); +} + +template +void SpdzWiseRingPrep::buffer_bits() +{ + if (OnlineOptions::singleton.bits_from_squares) + buffer_bits_from_squares_in_ring(this->bits, this->proc); + else + MaliciousRingPrep::buffer_bits(); +} + +template +void SpdzWisePrep::buffer_bits() +{ + throw not_implemented(); +} + +template<> +inline +void SpdzWisePrep>>::buffer_bits() +{ + buffer_bits_from_squares(*this); +} + +template<> +inline +void SpdzWisePrep>>::buffer_bits() +{ + super::buffer_bits(); +} + +template +void SpdzWisePrep::buffer_inverses() +{ + auto protocol = this->protocol; + assert(protocol != 0); + assert(this->proc != 0); + ::buffer_inverses(this->inverses, *this, this->proc->MC, protocol->P); +} + +template +void SpdzWisePrep::buffer_inputs(int player) +{ + assert(this->proc != 0); + assert(this->protocol != 0); + vector rs(OnlineOptions::singleton.batch_size); + auto& P = this->proc->P; + this->inputs.resize(P.num_players()); + this->protocol->init_mul(this->proc); + for (auto& r : rs) + { + r = this->protocol->get_random(); + } + + typename T::part_type::PO output(P); + if (player != P.my_num()) + { + for (auto& r : rs) + { + this->inputs[player].push_back({r, 0}); + output.prepare_sending(r.get_share(), player); + } + output.send(player); + } + else + { + output.receive(); + for (auto& r : rs) + { + this->inputs[player].push_back({r, output.finalize(r.get_share())}); + } + } +} diff --git a/Protocols/SpdzWiseRing.h b/Protocols/SpdzWiseRing.h new file mode 100644 index 000000000..7f83a9a24 --- /dev/null +++ b/Protocols/SpdzWiseRing.h @@ -0,0 +1,30 @@ +/* + * SpdzWiseRing.h + * + */ + +#ifndef PROTOCOLS_SPDZWISERING_H_ +#define PROTOCOLS_SPDZWISERING_H_ + +#include "SpdzWise.h" +#include "PostSacrifice.h" +#include "PostSacriRepRingShare.h" + +template +class SpdzWiseRing : public SpdzWise +{ + typedef typename T::part_type check_type; + typedef PostSacriRepRingShare zero_check_type; + + DataPositions zero_usage; + SimplerMalRepRingPrep zero_prep; + typename zero_check_type::MAC_Check zero_output; + SubProcessor zero_proc; + +public: + SpdzWiseRing(Player &P); + + void zero_check(check_type t); +}; + +#endif /* PROTOCOLS_SPDZWISERING_H_ */ diff --git a/Protocols/SpdzWiseRing.hpp b/Protocols/SpdzWiseRing.hpp new file mode 100644 index 000000000..30904c386 --- /dev/null +++ b/Protocols/SpdzWiseRing.hpp @@ -0,0 +1,51 @@ +/* + * SpdzWiseRing.cpp + * + */ + +#include "SpdzWiseRing.h" + +template +SpdzWiseRing::SpdzWiseRing(Player& P) : + SpdzWise(P), zero_prep(0, zero_usage), zero_proc(zero_output, + zero_prep, P) +{ +} + +template +void SpdzWiseRing::zero_check(check_type t) +{ + int l = T::LENGTH + T::SECURITY; + vector bit_masks(l); + zero_check_type masked = t; + zero_prep.buffer_size = l; + for (int i = 0; i < l; i++) + { + bit_masks[i] = zero_prep.get_bit(); + masked += bit_masks[i] << i; + } + auto& P = this->P; + auto opened = zero_output.open(masked, P); + vector bits(l); + for (int i = 0; i < l; i++) + { + auto b = opened.get_bit(i); + bits[i] = zero_check_type::constant(b, P.my_num()) + bits[i] + - 2 * b * bits[i]; + } + while(bits.size() > 1) + { + auto& protocol = zero_proc.protocol; + protocol.init_mul(&zero_proc); + for (int i = bits.size() - 2; i >= 0; i -= 2) + protocol.prepare_mul(bits[i], bits[i + 1]); + protocol.exchange(); + int n_mults = bits.size() / 2; + bits.resize(bits.size() % 2); + for (int i = 0; i < n_mults; i++) + bits.push_back(protocol.finalize_mul()); + } + zero_output.CheckFor(0, {bits[0]}, P); + zero_output.Check(P); + zero_proc.protocol.check(); +} diff --git a/Protocols/SpdzWiseRingPrep.h b/Protocols/SpdzWiseRingPrep.h new file mode 100644 index 000000000..136000911 --- /dev/null +++ b/Protocols/SpdzWiseRingPrep.h @@ -0,0 +1,48 @@ +/* + * SpdzWiseRingPrep.h + * + */ + +#ifndef PROTOCOLS_SPDZWISERINGPREP_H_ +#define PROTOCOLS_SPDZWISERINGPREP_H_ + +#include "SpdzWisePrep.h" +#include "RepRingOnlyEdabitPrep.h" + +template +class SpdzWiseRingPrep : public virtual SpdzWisePrep, + public virtual RepRingOnlyEdabitPrep +{ + void buffer_bits(); + + void buffer_edabits(int n_bits, ThreadQueues* queues) + { + RepRingOnlyEdabitPrep::buffer_edabits(n_bits, queues); + } + + void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues) + { + BufferPrep::buffer_edabits(strict, n_bits, queues); + } + + void buffer_sedabits(int n_bits, ThreadQueues*) + { + this->buffer_sedabits_from_edabits(n_bits); + } + +public: + SpdzWiseRingPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), + BitPrep(proc, usage), RingPrep(proc, usage), + SpdzWisePrep(proc, usage), RepRingOnlyEdabitPrep(proc, usage) + { + } + + void get_dabit_no_count(T& a, typename T::bit_type& b) + { + this->get_one_no_count(DATA_BIT, a); + b = a.get_share() & 1; + } +}; + +#endif /* PROTOCOLS_SPDZWISERINGPREP_H_ */ diff --git a/Protocols/SpdzWiseRingShare.h b/Protocols/SpdzWiseRingShare.h new file mode 100644 index 000000000..e49d5733b --- /dev/null +++ b/Protocols/SpdzWiseRingShare.h @@ -0,0 +1,88 @@ +/* + * SpdzWiseRingShare.h + * + */ + +#ifndef PROTOCOLS_SPDZWISERINGSHARE_H_ +#define PROTOCOLS_SPDZWISERINGSHARE_H_ + +#include "SpdzWiseShare.h" +#include "MaliciousRep3Share.h" +#include "Rep3Share2k.h" +#include "Math/Z2k.h" + +template class SpdzWiseRingPrep; +template class SpdzWiseRing; + +template +class SpdzWiseRingShare : public SpdzWiseShare>> +{ + typedef SpdzWiseRingShare This; + typedef SpdzWiseShare>> super; + +public: + typedef SignedZ2 clear; + typedef clear open_type; + typedef MaliciousRep3Share open_part_type; + + typedef SpdzWiseMC MAC_Check; + typedef MAC_Check Direct_MC; + + typedef SpdzWiseRing Protocol; + typedef SpdzWiseRingPrep LivePrep; + typedef SpdzWiseInput Input; + typedef ::PrivateOutput PrivateOutput; + + typedef GC::MaliciousRepSecret bit_type; + + static const int LENGTH = K; + static const int SECURITY = S; + + SpdzWiseRingShare() + { + } + + template + SpdzWiseRingShare(const T& other) : + super(other) + { + } + + template + SpdzWiseRingShare(const T &share, const U &mac) : + super(share, mac) + { + } + + template + static void split(vector& dest, const vector& regs, + int n_bits, const SpdzWiseRingShare* source, int n_inputs, Player& P) + { + vector> shares(n_inputs); + for (int i = 0; i < n_inputs; i++) + shares[i] = source[i].get_share(); + Rep3Share2::split(dest, regs, n_bits, shares.data(), n_inputs, P); + } + + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + typename This::part_type::Honest::Protocol protocol(proc.P); + protocol.init_mul(); + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest.set_share(Rep3Share2(source.get_share()) >> inst.get_n()); + protocol.prepare_mul(dest.get_share(), proc.MC.get_alphai()); + } + protocol.exchange(); + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + dest.set_mac(protocol.finalize_mul()); + proc.protocol.add_to_check(dest); + } + } +}; + +#endif /* PROTOCOLS_SPDZWISERINGSHARE_H_ */ diff --git a/Protocols/SpdzWiseShare.h b/Protocols/SpdzWiseShare.h new file mode 100644 index 000000000..84c56e9b3 --- /dev/null +++ b/Protocols/SpdzWiseShare.h @@ -0,0 +1,74 @@ +/* + * SpdzWiseShare.h + * + */ + +#ifndef PROTOCOLS_SPDZWISESHARE_H_ +#define PROTOCOLS_SPDZWISESHARE_H_ + +#include "Share.h" +#include "SpdzWise.h" +#include "Processor/DummyProtocol.h" +#include "Processor/NoProtocol.h" + +template class NoLivePrep; +template class NotImplementedInput; +template class SpdzWiseMC; +template class SpdzWisePrep; +template class SpdzWiseInput; + +namespace GC +{ +class MaliciousRepSecret; +} + +template +class SpdzWiseShare : public Share_ +{ + typedef Share_ super; + +public: + typedef T open_part_type; + typedef typename T::clear clear; + typedef typename T::open_type open_type; + + typedef SpdzWiseMC MAC_Check; + typedef MAC_Check Direct_MC; + + typedef SpdzWise Protocol; + typedef SpdzWisePrep LivePrep; + typedef SpdzWiseInput Input; + typedef ::PrivateOutput PrivateOutput; + + typedef typename T::bit_type bit_type; + + static const bool expensive = true; + + static string type_short() + { + return "SY" + T::type_short(); + } + + static string type_string() + { + return "SPDZ-wise " + T::type_string(); + } + + static void read_or_generate_mac_key(string directory, Player& P, T& mac_key); + + SpdzWiseShare() + { + } + + SpdzWiseShare(const super& other) : + super(other) + { + } + + SpdzWiseShare(const T& share, const T& mac) : + super(share, mac) + { + } +}; + +#endif /* PROTOCOLS_SPDZWISESHARE_H_ */ diff --git a/Protocols/SpdzWiseShare.hpp b/Protocols/SpdzWiseShare.hpp new file mode 100644 index 000000000..60dccba50 --- /dev/null +++ b/Protocols/SpdzWiseShare.hpp @@ -0,0 +1,43 @@ +/* + * SpdzWiseShare.hpp + * + */ + +#ifndef PROTOCOLS_SPDZWISESHARE_HPP_ +#define PROTOCOLS_SPDZWISESHARE_HPP_ + +#include "SpdzWiseShare.h" + +#include "fake-stuff.hpp" + +template +void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& mac_key) +{ + try + { + read_mac_key(directory, P.N, mac_key); + } + catch (mac_key_error&) + { + SeededPRNG G; + mac_key.randomize(G); + } + + try + { + // validate MAC key + typename open_part_type::MAC_Check MC; + auto masked = typename T::Honest::Protocol(P).get_random() + mac_key; + MC.open(masked, P); + MC.Check(P); + } + catch (mac_fail&) + { +#ifdef VERBOSE + cerr << "Generating fresh MAC key for " << type_string() << endl; +#endif + mac_key = typename T::Honest::Protocol(P).get_random(); + } +} + +#endif /* PROTOCOLS_SPDZWISESHARE_HPP_ */ diff --git a/Protocols/dabit.h b/Protocols/dabit.h new file mode 100644 index 000000000..1228a033d --- /dev/null +++ b/Protocols/dabit.h @@ -0,0 +1,46 @@ +/* + * dabit.h + * + */ + +#ifndef PROTOCOLS_DABIT_H_ +#define PROTOCOLS_DABIT_H_ + +#include +using namespace std; + +template +class dabit : public pair +{ + typedef pair super; + +public: + typedef typename T::bit_type::part_type bit_type; + + static int size() + { + return T::size() + bit_type::size(); + } + + static string type_string() + { + return T::type_string(); + } + + dabit() + { + } + + dabit(const T& a, const bit_type& b) : + super(a, b) + { + } + + void assign(const char* buffer) + { + this->first.assign(buffer); + this->second.assign(buffer + T::size()); + } +}; + +#endif /* PROTOCOLS_DABIT_H_ */ diff --git a/Protocols/edabit.h b/Protocols/edabit.h index df78b7b4a..91a62dc1d 100644 --- a/Protocols/edabit.h +++ b/Protocols/edabit.h @@ -15,11 +15,14 @@ template class edabitvec { typedef FixedVector b_type; + typedef FixedVector a_type; - FixedVector a; +public: + static const int MAX_SIZE = a_type::MAX_SIZE; + + a_type a; b_type b; -public: edabitvec() { } @@ -92,6 +95,27 @@ class edabitvec } a.push_back(x.first); } + + void input(int length, ifstream& s) + { + char buffer[MAX_SIZE * T::size()]; + s.read(buffer, MAX_SIZE * T::size()); + for (int i = 0; i < MAX_SIZE; i++) + { + T x; + x.assign(buffer + i * T::size()); + a.push_back(x); + } + size_t bsize = T::bit_type::part_type::size(); + char bbuffer[length * bsize]; + s.read(bbuffer, length * bsize); + for (int i = 0; i < length; i++) + { + typename T::bit_type::part_type x; + x.assign(bbuffer + i * bsize); + b.push_back(x); + } + } }; #endif /* PROTOCOLS_EDABIT_H_ */ diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 5fb375274..b07c95ced 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -28,6 +28,8 @@ void write_mac_key(const string& directory, int player_num, int nplayers, T key) template void read_mac_key(const string& directory, int player_num, int nplayers, U& key); +template +void read_mac_key(const string& directory, const Names& N, U& key); template class Files @@ -55,9 +57,16 @@ class Files { delete[] outf; } - void output_shares(const typename T::clear& a) + template + void output_shares(const typename U::clear& a) + { + output_shares(a, key); + } + template + void output_shares(const typename U::clear& a, + const typename U::mac_type& key) { - vector Sa(N); + vector Sa(N); make_share(Sa,a,N,key,G); for (int j=0; j class Share; template class SemiShare; template class ShamirShare; +template class MaliciousShamirShare; template class FixedVec; template class Share_; +template class SpdzWiseShare; +template class MaliciousRep3Share; namespace GC { template class TinySecret; template class TinierSecret; +template class MaliciousCcdSecret; } template @@ -43,6 +49,42 @@ void make_share(Share_* Sa,const U& a,int N,const V& key,PRNG& G) Sa[N-1]=S; } +template +void make_share(SpdzWiseShare>* Sa,const U& a,int N,const V& key,PRNG& G) +{ + insecure("share generation", false); + assert (key[0] == key[1]); + auto mac = a * key[0]; + FixedVec shares, macs; + shares.randomize_to_sum(a, G); + macs.randomize_to_sum(mac, G); + + for (int i = 0; i < N; i++) + { + MaliciousRep3Share share, mac; + share[0] = shares[i]; + share[1] = shares[positive_modulo(i - 1, 3)]; + mac[0] = macs[i]; + mac[1] = macs[positive_modulo(i - 1, 3)]; + Sa[i].set_share(share); + Sa[i].set_mac(mac); + } +} + +template +void make_share(SpdzWiseShare>* Sa, const U& a, int N, + const V& key, PRNG& G) +{ + vector> shares(N), macs(N); + make_share(shares.data(), a, N, {}, G); + make_share(macs.data(), a * key, N, {}, G); + for (int i = 0; i < N; i++) + { + Sa[i].set_share(shares[i]); + Sa[i].set_mac(macs[i]); + } +} + template void make_vector_share(T* Sa,const U& a,int N,const V& key,PRNG& G) { @@ -70,10 +112,9 @@ void make_share(GC::TinierSecret* Sa, const U& a, int N, const V& key, PRNG& make_vector_share(Sa, a, N, key, G); } -template -void make_share(SemiShare* Sa,const T& a,int N,const T& key,PRNG& G) +template +void make_share(SemiShare* Sa,const T& a,int N,const U&,PRNG& G) { - (void) key; insecure("share generation", false); T x, S = a; for (int i=0; i* Sa,const T& a,int N,const T& key,PRNG& G) template void make_share(FixedVec* Sa, const V& a, int N, const U& key, PRNG& G); -template +template inline void make_share(vector& Sa, - const typename T::clear& a, int N, const typename T::mac_type& key, + const typename T::clear& a, int N, const U& key, PRNG& G) { Sa.resize(N); @@ -115,8 +156,26 @@ void make_share(FixedVec* Sa, const V& a, int N, const U& key, PRNG& G) } } -template -void make_share(ShamirShare* Sa, const T& a, int N, +template +void make_share(FixedVec* Sa, const V& a, int N, const U& key, PRNG& G) +{ + (void) key; + assert(N == 4); + insecure("share generation", false); + FixedVec add_shares; + add_shares.randomize_to_sum(a, G); + for (int i=0; i share; + share[0] = add_shares[(i + 0) % 4]; + share[1] = add_shares[(i + 1) % 4]; + share[2] = add_shares[(i + 2) % 4]; + Sa[i] = share; + } +} + +template +void make_share(ShamirShare* Sa, const V& a, int N, const typename ShamirShare::mac_type&, PRNG& G) { insecure("share generation", false); @@ -192,7 +251,7 @@ inline string mac_filename(string directory, int playerno) { if (directory.empty()) directory = "."; - return directory + "/Player-MAC-Keys-" + string(1, T::type_char()) + "-P" + return directory + "/Player-MAC-Keys-" + T::type_short() + "-P" + to_string(playerno); } @@ -265,45 +324,115 @@ void read_global_mac_key(const string& directory, int nparties, U& key) cout << "Final Keys : " << key << endl; } +template +T reconstruct(vector& shares) +{ + return sum(shares); +} + +template +T reconstruct(vector>& shares) +{ + T res; + for (auto& x : shares) + res += x[0]; + return res; +} + +template +T reconstruct(vector>& shares) +{ + T res; + for (size_t i = 0; i < shares.size(); i++) + res += Shamir>::get_rec_factor(i, shares.size()) * shares[i]; + return res; +} + +template +void make_mac_key_share(typename T::mac_share_type::open_type& key, + vector& key_shares, int nplayers, T) +{ + SeededPRNG G; + key.randomize(G); + make_share(key_shares.data(), key, nplayers, GC::NoShare(), G); + assert(not key_shares[0].is_zero()); +} + +template +void make_mac_key_share(Z2& key, + vector>>& key_shares, int nplayers, Spdz2kShare) +{ + SeededPRNG G; + key = {}; + key_shares.resize(nplayers); + for (int i = 0; i < nplayers; i++) + { + key_shares[i] = G.get>(); + key += key_shares[i]; + } + assert(not key.is_zero()); +} + template -void generate_mac_keys(typename T::mac_type::Scalar& key, +void generate_mac_keys(typename T::mac_share_type::open_type& key, int nplayers, string prep_data_prefix) { key.assign_zero(); int tmpN = 0; ifstream inpf; - SeededPRNG G; prep_data_prefix = get_prep_sub_dir(prep_data_prefix, nplayers); + bool generate = false; + vector key_shares(nplayers); for (int i = 0; i < nplayers; i++) { + auto& pp = key_shares[i]; stringstream filename; - filename - << mac_filename(prep_data_prefix, - i); + filename << mac_filename(prep_data_prefix, i); inpf.open(filename.str().c_str()); - typename T::mac_key_type::Scalar pp; if (inpf.fail()) { inpf.close(); - cout << "No MAC key share for player " << i << ", generating a fresh one\n"; - pp.randomize(G); - ofstream outf(filename.str().c_str()); - if (outf.fail()) - throw file_error(filename.str().c_str()); - outf << nplayers << " " << pp << endl; - outf.close(); - cout << "Written new MAC key share to " << filename.str() << endl; + cout << "No MAC key share for player " << i << ", generating a fresh ones\n"; + generate = true; + break; } else { inpf >> tmpN; // not needed here pp.input(inpf,true); inpf.close(); + if (pp.is_zero()) + { + generate = true; + break; + } } cout << " Key " << i << ": " << pp << endl; - key.add(pp); } + + key = reconstruct(key_shares); + + if (generate) + { + make_mac_key_share(key, key_shares, nplayers, T()); + + for (int i = 0; i < nplayers; i++) + { + auto& pp = key_shares[i]; + stringstream filename; + filename + << mac_filename(prep_data_prefix, i); + ofstream outf(filename.str().c_str()); + if (outf.fail()) + throw file_error(filename.str().c_str()); + outf << nplayers << " " << pp << endl; + outf.close(); + cout << "Written new MAC key share to " << filename.str() << endl; + cout << " Key " << i << ": " << pp << endl; + } + } + cout << "--------------\n"; cout << "Final Key: " << key << endl; } diff --git a/README.md b/README.md index b5c2087af..f11f29425 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ The following table lists all protocols that are fully supported. | Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | -| Malicious, honest majority | [Shamir / Rep3 / PS](#honest-majority) | [Brain / Rep3 / PS](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | #### Paper and Citation @@ -261,6 +261,19 @@ al.](https://eprint.iacr.org/2020/338). You can activate them by using `-Y` instead of `-X`. Note that this also activates classic daBits when useful. +##### Local share conversion + +This technique has been used by [Mohassel and +Rindal](https://eprint.iacr.org/2018/403) as well as [Araki et +al.](https://eprint.iacr.org/2018/762) It involves locally +converting an arithmetic share to a set of binary shares, from which the +binary equivalent to the arithmetic share is reconstructed using a +binary adder. This requires additive secret sharing over a ring +without any MACs. You can activate it by using `-Z ` with the +compiler where `n` is the number of parties for the standard variant +(3 or 4) and 2 for the special +variant by Mohassel and Rindal (available in Rep3 only). + #### Bristol Fashion circuits Bristol Fashion is the name of a description format of binary circuits @@ -335,6 +348,15 @@ https://github.com/mkskeller/EzPC/commit/2021be90d21dc26894be98f33cd10dd26769f47 [The reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.ml) contains further documentation on available layers. +### Emulation + +For arithmetic circuits modulo a power of two and binary circuits, you +can emulate the computation as follows: + +``` ./emulate.x ``` + +This runs the compiled bytecode in cleartext computation. + ## Dishonest majority Some full implementations require oblivious transfer, which is @@ -470,20 +492,25 @@ The following table shows all programs for honest-majority computation: | `brain-party.x` | Replicated | Mod 2^k | Y | 3 | `brain.sh` | | `ps-rep-ring-party.x` | Replicated | Mod 2^k | Y | 3 | `ps-rep-ring.sh` | | `malicious-rep-ring-party.x` | Replicated | Mod 2^k | Y | 3 | `mal-rep-ring.sh` | +| `sy-rep-ring-party.x` | SPDZ-wise replicated | Mod 2^k | Y | 3 | `sy-rep-ring.sh` | +| `rep4-ring-party.x` | Replicated | Mod 2^k | Y | 4 | `rep4-ring.sh` | | `replicated-bin-party.x` | Replicated | Binary | N | 3 | `replicated.sh` | | `malicious-rep-bin-party.x` | Replicated | Binary | Y | 3 | `mal-rep-bin.sh` | | `replicated-field-party.x` | Replicated | Mod prime | N | 3 | `rep-field.sh` | | `ps-rep-field-party.x` | Replicated | Mod prime | Y | 3 | `ps-rep-field.sh` | +| `sy-rep-field-party.x` | SPDZ-wise replicated | Mod prime | Y | 3 | `sy-rep-field.sh` | | `malicious-rep-field-party.x` | Replicated | Mod prime | Y | 3 | `mal-rep-field.sh` | | `shamir-party.x` | Shamir | Mod prime | N | 3 or more | `shamir.sh` | | `malicious-shamir-party.x` | Shamir | Mod prime | Y | 3 or more | `mal-shamir.sh` | +| `sy-shamir-party.x` | SPDZ-wise Shamir | Mod prime | Y | 3 or more | `mal-shamir.sh` | | `ccd-party.x` | CCD/Shamir | Binary | N | 3 or more | `ccd.sh` | | `malicious-cdd-party.x` | CCD/Shamir | Binary | Y | 3 or more | `mal-ccd.sh` | We use the "generate random triple optimistically/sacrifice/Beaver" methodology described by [Lindell and Nof](https://eprint.iacr.org/2017/816) to achieve malicious -security, except for the "PS" (post-sacrifice) protocols where the +security with plain replicated secret sharing, +except for the "PS" (post-sacrifice) protocols where the actual multiplication is executed optimistally and checked later as also described by Lindell and Nof. The implementations used by `brain-party.x`, @@ -492,7 +519,7 @@ and `ps-rep-ring-party.x` correspond to the protocols called DOS18 preprocessing (single), ABF+17 preprocessing, CDE+18 preprocessing, and postprocessing, respectively, by [Eerikson et al.](https://eprint.iacr.org/2019/164) -Otherwise, we use resharing by [Cramer et +We use resharing by [Cramer et al.](https://eprint.iacr.org/2000/037) for Shamir's secret sharing and the optimized approach by [Araki et al.](https://eprint.iacr.org/2016/768) for replicated secret sharing. @@ -500,6 +527,14 @@ The CCD protocols are named after the [historic paper](https://doi.org/10.1145/62212.62214) by Chaum, Crépeau, and Damgård, which introduced binary computation using Shamir secret sharing over extension fields of characteristic two. +SY/SPDZ-wise refers to the line of work started by [Chida et +al.](https://eprint.iacr.org/2018/570) for computation modulo a prime +and furthered by [Abspoel et al.](https://eprint.iacr.org/2019/1298) +for computation modulo a power of two. It involves sharing both a +secret value and information-theoretic tag similar to SPDZ but not +with additive secret sharing, hence the name. +Rep4 refers to the four-party protocol by [Dalskov et +al.](https://eprint.iacr.org/2020/1330). All protocols in this section require encrypted channels because the information received by the honest majority suffices the reconstruct @@ -678,23 +713,22 @@ Creating fake offline data for SPDZ2k requires to call `./Fake-Offline.x -Z -S ` -### Honest-majority three-party computation of binary circuits with malicious security +You will need to run `spdz2k-party.x -F` in order to use the data from storage. -Compile the virtual machines: +### Other protocols -`make -j 8 rep-bin` +Preprocessing data for the default parameters of most other protocols +can be produced as follows: -Generate preprocessing data: +`./Fake-Offline.x -e ` -`Scripts/setup-online.sh 3` +The `-e` command-line parameters accepts a list of integers seperated +by commas. -After compilating the mpc file, run as follows: - -`malicious-rep-bin-party.x [-I] -h -p <0/1/2> tutorial` - -When running locally, you can omit the host argument. As above, `-I` -activates interactive input, otherwise inputs are read from -`Player-Data/Input-P-0`. +You can then run the protocol with argument `-F`. Note that when +running on several hosts, you will need to distribute the data in +`Player-Data`. The preprocessing files contain `-P` +indicating which party will access it. ### BMR diff --git a/Scripts/emulate.sh b/Scripts/emulate.sh new file mode 100755 index 000000000..d38fc37f0 --- /dev/null +++ b/Scripts/emulate.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +test -e logs || mkdir logs +prog=${1%.sch} +prog=${prog##*/} +shift +./emulate.x $prog $* 2>&1 | tee -a logs/emulate-$prog diff --git a/Scripts/rep4-ring.sh b/Scripts/rep4-ring.sh new file mode 100755 index 000000000..73f71d104 --- /dev/null +++ b/Scripts/rep4-ring.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=4 + +. $HERE/run-common.sh + +run_player rep4-ring-party.x $* || exit 1 diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index c7a1e10b9..ee1346359 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -24,7 +24,7 @@ lldb_screen() } run_player() { - port=$((RANDOM%10000+10000)) + port=$((RANDOM%60000+10000)) bin=$1 shift if ! test -e $SPDZROOT/logs; then diff --git a/Scripts/sy-rep-field.sh b/Scripts/sy-rep-field.sh new file mode 100755 index 000000000..0b2d40c75 --- /dev/null +++ b/Scripts/sy-rep-field.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=3 + +. $HERE/run-common.sh + +run_player sy-rep-field-party.x $* || exit 1 diff --git a/Scripts/sy-rep-ring.sh b/Scripts/sy-rep-ring.sh new file mode 100755 index 000000000..a09fe5fe6 --- /dev/null +++ b/Scripts/sy-rep-ring.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=3 + +. $HERE/run-common.sh + +run_player sy-rep-ring-party.x $* || exit 1 diff --git a/Scripts/sy-shamir.sh b/Scripts/sy-shamir.sh new file mode 100755 index 000000000..98ae80ce0 --- /dev/null +++ b/Scripts/sy-shamir.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=3 + +. $HERE/run-common.sh + +run_player sy-shamir-party.x $* || exit 1 diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index 52376b68e..829e9aa44 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -echo MY_CFLAGS += -DINSECURE >> CONFIG.mine +echo SECURE = -DINSECURE >> CONFIG.mine touch ECDSA/Fake-ECDSA.cpp make -j4 ecdsa Fake-ECDSA.x secure.x diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 0c799d7bc..9923c30d5 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -46,13 +46,15 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py -R 64 $compile_opts tutorial - for i in ring semi2k brain mal-rep-ring ps-rep-ring spdz2k; do + for i in ring rep4-ring semi2k brain mal-rep-ring ps-rep-ring sy-rep-ring \ + spdz2k; do test_vm $i $run_opts done ./compile.py $compile_opts tutorial - for i in rep-field shamir mal-rep-field ps-rep-field mal-shamir hemi semi \ + for i in rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ + mal-shamir sy-shamir hemi semi \ soho cowgear mascot; do test_vm $i $run_opts done @@ -60,6 +62,15 @@ for dabit in ${dabit:-0 1 2}; do test_vm chaigear $run_opts -l 3 -c 2 done +./compile.py -R 64 -Z 3 tutorial +test_vm ring $run_opts + +./compile.py -R 64 -Z 4 tutorial +test_vm rep4-ring $run_opts + +./compile.py -R 64 -Z 2 tutorial +test_vm semi2k $run_opts + ./compile.py tutorial test_vm cowgear $run_opts -T diff --git a/Tools/Hash.cpp b/Tools/Hash.cpp index a1bd14a4d..680bec969 100644 --- a/Tools/Hash.cpp +++ b/Tools/Hash.cpp @@ -40,6 +40,7 @@ void Hash::final(octetStream& os) os.resize_precise(hash_length); os.reset_write_head(); final(os.append(hash_length)); + reset(); } octetStream Hash::final() diff --git a/Tools/Hash.h b/Tools/Hash.h index 74ee6d2e8..07722d642 100644 --- a/Tools/Hash.h +++ b/Tools/Hash.h @@ -2,8 +2,10 @@ #define _SHA1 #include +#include +using namespace std; -class octetStream; +#include "octetStream.h" class Hash { @@ -25,6 +27,20 @@ class Hash size += len; } void update(const octetStream& os); + template + void update(const T& x) + { + update(x.get_ptr(), x.size()); + } + template + void update(const vector& v) + { + octetStream tmp(v.size() * sizeof(T)); + for (auto& x : v) + x.pack(tmp); + update(tmp); + } + void final(unsigned char hashout[hash_length]) { crypto_generichash_final(state, hashout, crypto_generichash_BYTES); diff --git a/Tools/Phase.cpp b/Tools/Phase.cpp deleted file mode 100644 index 4880d3cb1..000000000 --- a/Tools/Phase.cpp +++ /dev/null @@ -1,8 +0,0 @@ -/* - * Phase.cpp - * - */ - -#include "BMR/Register.h" - -BlackHole Phase::out; diff --git a/Tools/PointerVector.h b/Tools/PointerVector.h index 12bb6453c..68f830f91 100644 --- a/Tools/PointerVector.h +++ b/Tools/PointerVector.h @@ -20,6 +20,10 @@ class PointerVector : public vector void clear() { vector::clear(); + reset(); + } + void reset() + { i = 0; } T& next() diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 6e95ff890..ac4e79f95 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -68,6 +68,7 @@ class octetStream size_t get_length() const { return len; } size_t get_max_length() const { return mxlen; } octet* get_data() const { return data; } + octet* get_data_ptr() const { return data + ptr; } bool done() const { return ptr == len; } bool empty() const { return len == 0; } diff --git a/Tools/random.h b/Tools/random.h index 9d6fc7052..33008dd94 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -135,6 +135,16 @@ class GlobalPRNG : public PRNG } }; +template +class ElementPRNG : public PRNG +{ +public: + T get() + { + return PRNG::get(); + } +}; + inline bool PRNG::get_bit() { if (n_cached_bits == 0) diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index a8124b3eb..f465c94d4 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -12,14 +12,19 @@ #include "Tools/ezOptionParser.h" #include "Exceptions/Exceptions.h" #include "GC/MaliciousRepSecret.h" +#include "GC/TinierSecret.h" +#include "GC/TinyMC.h" +#include "GC/SemiSecret.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" #include "Protocols/fake-stuff.hpp" +#include "Protocols/ReplicatedPrep.hpp" #include "Processor/Data_Files.hpp" #include "Math/Z2k.hpp" #include "Math/gfp.hpp" +#include "GC/Secret.hpp" #include #include diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index f96a294d2..c02adbf1d 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -10,12 +10,17 @@ #include "Protocols/PostSacriRepFieldShare.h" #include "Protocols/SemiShare.h" #include "Protocols/MaliciousShamirShare.h" +#include "Protocols/SpdzWiseRingShare.h" +#include "Protocols/SpdzWiseShare.h" +#include "Protocols/Rep4Share2k.h" #include "Protocols/fake-stuff.h" #include "Exceptions/Exceptions.h" #include "GC/MaliciousRepSecret.h" #include "GC/SemiSecret.h" #include "GC/TinySecret.h" #include "GC/TinierSecret.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/Rep4Secret.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" @@ -24,6 +29,7 @@ #include "Tools/benchmarking.h" #include "Protocols/fake-stuff.hpp" +#include "Protocols/Shamir.hpp" #include "Processor/Data_Files.hpp" #include "Math/Z2k.hpp" #include "Math/gfp.hpp" @@ -35,6 +41,7 @@ using namespace std; string prep_data_prefix; +ez::ezOptionParser opt; void make_bit_triples(const gf2n& key,int N,int ntrip,Dtype dtype,bool zero) { @@ -158,6 +165,60 @@ void make_bits(const typename T::mac_type& key, int N, int ntrip, bool zero, delete[] outf; } +template +void make_dabits(const typename T::mac_type& key, int N, int ntrip, bool zero, + const typename T::bit_type::mac_type& bit_key = { }) +{ + Files files(N, key, + get_prep_sub_dir(prep_data_prefix, N) + + DataPositions::dtype_names[DATA_DABIT] + "-" + T::type_short()); + SeededPRNG G; + for (int i = 0; i < ntrip; i++) + { + bool bit = not zero && G.get_bit(); + files.template output_shares(bit); + files.template output_shares::bit_type>(bit, bit_key); + } +} + +template +void make_edabits(const typename T::mac_type& key, int N, int ntrip, bool zero, false_type, + const typename T::bit_type::mac_type& bit_key = {}) +{ + vector lengths; + opt.get("-e")->getInts(lengths); + for (auto length : lengths) + { + Files files(N, key, + get_prep_sub_dir(prep_data_prefix, N) + + "edaBits-" + to_string(length)); + SeededPRNG G; + bigint value; + int max_size = edabitvec::MAX_SIZE; + for (int i = 0; i < ntrip / max_size; i++) + { + vector as(max_size); + vector bs(length); + for (int j = 0; j < max_size; j++) + { + if (not zero) + G.get_bigint(value, length, true); + as[j] = value; + for (int k = 0; k < length; k++) + bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; + } + for (auto& a : as) + files.template output_shares(a); + for (auto& b : bs) + files.template output_shares(b, bit_key); + } + } +} + +template +void make_edabits(const typename T::mac_type&, int, int, bool, true_type) +{ +} /* N = Number players * ntrip = Number inputs needed @@ -245,6 +306,8 @@ void make_basic(const typename T::mac_type& key, int nplayers, int nitems, bool { make_minimal(key, nplayers, nitems, zero); make_square_tuples(key, nplayers, nitems, T::type_short(), zero); + make_dabits(key, nplayers, nitems, zero); + make_edabits(key, nplayers, nitems, zero, T::clear::characteristic_two); if (T::clear::invertible) { make_inverse(key, nplayers, nitems, zero, prep_data_prefix); @@ -252,6 +315,14 @@ void make_basic(const typename T::mac_type& key, int nplayers, int nitems, bool } } +template +void make_with_mac_key(int nplayers, int default_num, bool zero) +{ + typename T::mac_share_type::open_type key; + generate_mac_keys(key, nplayers, prep_data_prefix); + make_basic(key, nplayers, default_num, zero); +} + template int generate(ez::ezOptionParser& opt); @@ -260,8 +331,6 @@ int main(int argc, const char** argv) insecure("preprocessing"); bigint::init_thread(); - ez::ezOptionParser opt; - opt.syntax = "./Fake-Offline.x [OPTIONS]\n\nOptions with 2 arguments take the form '-X <#gf2n tuples>,<#modp tuples>'"; opt.example = "./Fake-Offline.x 2 -lgp 128 -lg2 128 --default 10000\n./Fake-Offline.x 3 -trip 50000,10000 -btrip 100000\n"; @@ -382,6 +451,15 @@ int main(int argc, const char** argv) "-S", // Flag token. "--security" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + -1, // Number of args expected. + ',', // Delimiter if expecting multiple args. + "edaBit lengths (separate by comma)", // Help description. + "-e", // Flag token. + "--edabits" // Flag token. + ); opt.parse(argc, argv); if (opt.isSet("-Z")) @@ -544,13 +622,17 @@ int generate(ez::ezOptionParser& opt) make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); - make_bits>({}, nplayers, default_num, zero); - make_bits>({}, nplayers, default_num, zero); - make_bits>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + make_with_mac_key>(nplayers, default_num, zero); + make_with_mac_key>>(nplayers, default_num, zero); make_mult_triples({}, nplayers, ntrip2, zero, prep_data_prefix); make_bits({}, nplayers, nbits2, zero); } + else if (nplayers == 4) + make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); @@ -565,17 +647,29 @@ int generate(ez::ezOptionParser& opt) Z2<41> keyt; generate_mac_keys>(keyt, nplayers, prep_data_prefix); - make_minimal>(keyt, nplayers, default_num, zero); + make_minimal>(keyt, nplayers, default_num / 64, zero); gf2n_short keytt; generate_mac_keys>(keytt, nplayers, prep_data_prefix); - make_minimal>(keytt, nplayers, default_num, zero); + make_minimal>(keytt, nplayers, default_num / 64, zero); + + make_dabits(keyp, nplayers, default_num, zero, keytt); + make_edabits(keyp, nplayers, default_num, zero, false_type(), keytt); - make_basic>({}, nplayers, default_num, zero); - make_basic>({}, nplayers, default_num, zero); + if (nplayers > 2) + { + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); - make_basic>({}, nplayers, default_num, zero); - make_basic>({}, nplayers, default_num, zero); + make_with_mac_key>>(nplayers, + default_num, zero); + + make_mult_triples>({}, nplayers, + default_num, zero, prep_data_prefix); + } return 0; } diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index a867afdb3..2379d42d8 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -18,8 +18,6 @@ #include "GC/ShareSecret.hpp" #include "YaoCommon.hpp" -ostream& YaoEvalWire::out = cout; - void YaoEvalWire::random() { set(0); diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index 4de503edf..18b1e4caf 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -26,8 +26,7 @@ class YaoEvalWire : public YaoWire static string name() { return "YaoEvalWire"; } - typedef ostream& out_type; - static ostream& out; + typedef SwitchableOutput out_type; static YaoEvalWire new_reg() { return {}; } diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 750c5ed30..fd652a481 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -30,6 +30,7 @@ YaoEvaluator::YaoEvaluator(int thread_num, YaoEvalMaster& master) : void YaoEvaluator::pre_run() { + processor.out.activate(true); if (not continuous()) receive_to_store(*P); } diff --git a/azure-pipelines.yml b/azure-pipelines.yml index df154ce74..88acb8b94 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -19,6 +19,6 @@ steps: - script: | make - script: - Scripts/setup-ssl.sh + Scripts/setup-ssl.sh 4 - script: Scripts/test_tutorial.sh -C diff --git a/compile.py b/compile.py index c0fc2d9f8..50bcc895d 100755 --- a/compile.py +++ b/compile.py @@ -67,6 +67,10 @@ def main(): help="mixing arithmetic and binary computation") parser.add_option("-Y", "--edabit", action="store_true", dest="edabit", help="mixing arithmetic and binary computation using edaBits") + parser.add_option("-Z", "--split", default=None, dest="split", + help="mixing arithmetic and binary computation " + "using direct conversion if supported " + "(number of parties as argument)") parser.add_option("-C", "--CISC", action="store_true", dest="cisc", help="faster CISC compilation mode") parser.add_option("-v", "--verbose", action="store_true", dest="verbose", diff --git a/doc/index.rst b/doc/index.rst index 648ea82dd..59e7fd8b0 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -49,7 +49,7 @@ length during execution using ``program.set_bit_length(length)``. For binary computation you can do so with ``sint = sbitint.get_type(length)``. -The following option switches from a single computation domain to +The following options switch from a single computation domain to mixed computation when using in conjunction with arithmetic computation: @@ -67,6 +67,14 @@ The implementation of both daBits and edaBits are explained in this paper_. .. _paper: https://eprint.iacr.org/2020/338 +.. cmdoption:: -Z + --split + + Enables mixed computation using local conversion. This has been + used by `Mohassel and Rindal `_ + and `Araki et al. `_ It only + works with additive secret sharing modulo a power of two. + The following options change less fundamental aspects of the computation: