From 6a424539c93f5489a6d09360f0092224552d94d8 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 25 Aug 2022 13:20:46 +1000 Subject: [PATCH] SoftSpokenOT. --- .gitignore | 4 + .gitmodules | 12 +- BMR/Party.cpp | 2 + BMR/RealGarbleWire.h | 2 - BMR/RealGarbleWire.hpp | 2 +- BMR/RealProgramParty.hpp | 4 +- BMR/Register.h | 22 +--- BMR/Register.hpp | 12 +- BMR/Register_inline.h | 6 +- CHANGELOG.md | 17 ++- CONFIG | 18 +-- Compiler/GC/instructions.py | 16 ++- Compiler/GC/types.py | 83 ++++++++++--- Compiler/comparison.py | 21 ++-- Compiler/compilerLib.py | 11 +- Compiler/floatingpoint.py | 32 ++--- Compiler/instructions.py | 69 ++++++----- Compiler/instructions_base.py | 10 +- Compiler/library.py | 29 +++-- Compiler/ml.py | 13 ++- Compiler/mpc_math.py | 4 +- Compiler/program.py | 18 ++- Compiler/types.py | 131 ++++++++++++++------- Compiler/util.py | 3 + ECDSA/P256Element.h | 2 +- ECDSA/fake-spdz-ecdsa-party.cpp | 8 +- ECDSA/ot-ecdsa-party.hpp | 9 +- FHE/Ciphertext.cpp | 2 +- FHE/NTL-Subs.cpp | 6 +- FHE/PPData.cpp | 3 +- FHEOffline/PairwiseMachine.cpp | 17 ++- FHEOffline/PairwiseMachine.h | 25 ++-- FHEOffline/SimpleGenerator.cpp | 2 +- FHEOffline/SimpleGenerator.h | 6 +- FHEOffline/SimpleMachine.cpp | 14 +++ FHEOffline/SimpleMachine.h | 21 +++- GC/AtlasShare.h | 5 - GC/CcdShare.h | 5 - GC/FakeSecret.cpp | 3 +- GC/MaliciousCcdShare.h | 5 - GC/Processor.h | 2 +- GC/Secret.hpp | 2 +- GC/Secret_inline.h | 8 +- GC/ShareParty.hpp | 22 ++-- GC/ShareSecret.hpp | 9 +- GC/ShareThread.h | 1 + GC/ShareThread.hpp | 6 + GC/ThreadMaster.hpp | 10 +- GC/TinierShare.h | 22 +++- GC/TinyMC.h | 10 ++ GC/TinyShare.h | 5 - HOSTS.example | 5 - Machines/OTMachine.cpp | 2 +- Machines/Tinier.cpp | 12 +- Machines/TripleMachine.cpp | 20 +++- Machines/mama-party.cpp | 32 +++-- Machines/spdz2k-party.cpp | 8 +- Machines/tinier-party.cpp | 4 +- Makefile | 94 ++++++++++----- Math/Square.h | 2 + Math/Square.hpp | 9 ++ Math/Zp_Data.cpp | 4 + Math/Zp_Data.h | 2 + Math/bigint.h | 2 +- Networking/CryptoPlayer.cpp | 30 ++++- Networking/CryptoPlayer.h | 5 + Networking/Player.cpp | 119 ++++++++----------- Networking/Player.h | 70 +++++------ Networking/PlayerBuffer.h | 23 ++++ Networking/PlayerCtSocket.h | 169 +++++++++++++++++++++++++++ Networking/Receiver.h | 5 + Networking/Sender.h | 5 + OT/BaseOT.cpp | 125 ++++++++------------ OT/BaseOT.h | 11 +- OT/BitMatrix.h | 3 + OT/BitMatrix.hpp | 7 +- OT/MamaRectangle.h | 5 + OT/NPartyTripleGenerator.h | 5 +- OT/NPartyTripleGenerator.hpp | 38 +++++- OT/OTExtension.cpp | 9 ++ OT/OTExtension.h | 2 +- OT/OTExtensionWithMatrix.cpp | 137 +++++++++++++++++++++- OT/OTExtensionWithMatrix.h | 28 ++++- OT/OTMultiplier.h | 16 +-- OT/OTMultiplier.hpp | 75 ++++++++++-- OT/OTTripleSetup.h | 25 +++- OT/Rectangle.h | 2 + OT/Rectangle.hpp | 7 ++ OT/TripleMachine.h | 7 +- Processor/BaseMachine.cpp | 17 ++- Processor/BaseMachine.h | 11 +- Processor/Instruction.hpp | 15 +-- Processor/Machine.hpp | 23 ++-- Processor/NoFilePrep.h | 2 +- Processor/OfflineMachine.hpp | 16 ++- Processor/Online-Thread.hpp | 3 + Processor/OnlineMachine.h | 5 + Processor/OnlineMachine.hpp | 10 -- Processor/OnlineOptions.cpp | 9 +- Processor/Processor.h | 4 +- Processor/Processor.hpp | 27 +++-- Programs/Source/mnist_full_C.mpc | 1 + Programs/Source/test_args.mpc | 2 +- Programs/Source/test_gc.mpc | 8 +- Protocols/ChaiGearPrep.h | 2 +- Protocols/ChaiGearPrep.hpp | 4 +- Protocols/ChaiGearShare.h | 1 + Protocols/CowGearShare.h | 1 + Protocols/FakeInput.h | 2 +- Protocols/LowGearKeyGen.hpp | 4 + Protocols/MAC_Check.h | 11 ++ Protocols/MAC_Check.hpp | 34 +++++- Protocols/MAC_Check_Base.h | 3 + Protocols/MaliciousRepPrep.hpp | 1 + Protocols/MamaPrep.h | 3 +- Protocols/MamaPrep.hpp | 4 +- Protocols/MamaShare.h | 8 +- Protocols/NoShare.h | 17 ++- Protocols/ProtocolSetup.h | 12 ++ Protocols/ReplicatedPrep.hpp | 8 +- Protocols/SecureShuffle.hpp | 10 +- Protocols/SemiPrep.h | 2 + Protocols/SemiPrep.hpp | 15 +++ Protocols/Share.h | 6 +- Protocols/ShareInterface.h | 2 +- Protocols/Spdz2kPrep.h | 2 +- Protocols/Spdz2kShare.h | 4 +- Protocols/fake-stuff.h | 2 + Protocols/fake-stuff.hpp | 19 ++- README.md | 168 ++++++++++++++------------ Scripts/build.sh | 12 +- Scripts/test_tutorial.sh | 2 +- Tools/Coordinator.cpp | 78 +++++++++++++ Tools/Coordinator.h | 43 +++++++ Tools/NetworkOptions.h | 2 +- Tools/PointerVector.h | 2 +- Tools/Subroutines.cpp | 4 +- Tools/Subroutines.h | 9 +- Tools/Waksman.h | 9 +- Tools/benchmarking.cpp | 6 +- Tools/benchmarking.h | 2 +- Tools/int.h | 5 + Tools/intrinsics.h | 1 + Tools/octetStream.h | 29 ++++- Tools/random.cpp | 35 +----- Tools/random.h | 66 +++++++++-- Tools/time-func.cpp | 3 +- Utils/Fake-Offline.cpp | 35 +++--- Utils/binary-example.cpp | 2 +- Utils/pairwise-offline.cpp | 2 +- Yao/YaoEvalWire.cpp | 2 +- Yao/YaoEvalWire.h | 2 - Yao/YaoGarbleWire.cpp | 2 +- Yao/YaoGarbleWire.h | 2 - azure-pipelines.yml | 11 +- SimpleOT => deps/SimpleOT | 0 deps/SimplestOT_C | 1 + deps/libOTe | 1 + mpir => deps/mpir | 0 simde => deps/simde | 0 doc/Doxyfile | 2 +- doc/add-protocol.rst | 17 +++ doc/compilation.rst | 182 +++++++++++++++++++++++++++++ doc/conf.py | 10 +- doc/gen-readme.sh | 4 + doc/index.rst | 194 +------------------------------ doc/instructions.rst | 5 +- doc/low-level.rst | 18 ++- doc/machine-learning.rst | 4 +- doc/requirements.txt | 3 +- doc/troubleshooting.rst | 2 +- 171 files changed, 2179 insertions(+), 1023 deletions(-) delete mode 100644 HOSTS.example create mode 100644 Networking/PlayerBuffer.h create mode 100644 Networking/PlayerCtSocket.h create mode 100644 Tools/Coordinator.cpp create mode 100644 Tools/Coordinator.h rename SimpleOT => deps/SimpleOT (100%) create mode 160000 deps/SimplestOT_C create mode 160000 deps/libOTe rename mpir => deps/mpir (100%) rename simde => deps/simde (100%) create mode 100644 doc/compilation.rst create mode 100755 doc/gen-readme.sh diff --git a/.gitignore b/.gitignore index 9a4dd72e2..0d770b1ec 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,7 @@ _build/ # environment .env + +# temp doc files +doc/readme.md +doc/xml diff --git a/.gitmodules b/.gitmodules index 32dca28be..7dea81d36 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,18 @@ [submodule "SimpleOT"] - path = SimpleOT + path = deps/SimpleOT url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] - path = mpir + path = deps/mpir url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion [submodule "simde"] - path = simde + path = deps/simde url = https://github.com/simd-everywhere/simde +[submodule "deps/libOTe"] + path = deps/libOTe + url = https://github.com/mkskeller/softspoken-implementation +[submodule "deps/SimplestOT_C"] + path = deps/SimplestOT_C + url = https://github.com/mkskeller/SimplestOT_C diff --git a/BMR/Party.cpp b/BMR/Party.cpp index beddd64cf..0fe11a0f1 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) : } cout << "Compiler: " << prev << endl; P = new PlainPlayer(N, 0); + Share::MAC_Check::setup(*P); if (argc > 4) threshold = atoi(argv[4]); cout << "Threshold for multi-threaded evaluation: " << threshold << endl; @@ -280,6 +281,7 @@ FakeProgramParty::~FakeProgramParty() cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes() << " GB" << endl; #endif + Share::MAC_Check::teardown(); } void FakeProgramParty::_compute_prfs_outputs(Key* keys) diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h index 9fa2dc521..115d0bcaa 100644 --- a/BMR/RealGarbleWire.h +++ b/BMR/RealGarbleWire.h @@ -48,8 +48,6 @@ class RealGarbleWire : public PRFRegister static void inputbvec(GC::Processor>& processor, ProcessorBase& input_processor, const vector& args); - RealGarbleWire(const Register& reg) : PRFRegister(reg) {} - void garble(PRFOutputs& prf_output, const RealGarbleWire& left, const RealGarbleWire& right); diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 760a20b89..c9e31fc65 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -110,7 +110,7 @@ void RealGarbleWire::inputbvec( { GarbleInputter inputter; processor.inputbvec(inputter, input_processor, args, - inputter.party.P->my_num()); + *inputter.party.P); } template diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 64efc5506..70208ec50 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -97,8 +97,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : if (online_opts.live_prep) { mac_key.randomize(prng); - if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({*P, true}); prep = new typename T::LivePrep(0, usage); } else @@ -107,6 +105,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : prep = new Sub_Data_Files(N, prep_dir, usage); } + T::MAC_Check::setup(*P); MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); @@ -219,6 +218,7 @@ RealProgramParty::~RealProgramParty() delete garble_inputter; delete garble_protocol; cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; + T::MAC_Check::teardown(); } template diff --git a/BMR/Register.h b/BMR/Register.h index f348f7b7e..6a15a720c 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -152,7 +152,7 @@ class Register { * for pipelining matters. */ - Register(int n_parties); + Register(); void init(int n_parties); void init(int rfd, int n_parties); @@ -278,10 +278,6 @@ class ProgramRegister : public Phase, public Register static int threshold(int) { throw not_implemented(); } - static Register new_reg(); - static Register tmp_reg() { return new_reg(); } - static Register and_reg() { return new_reg(); } - template static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } @@ -306,8 +302,6 @@ class ProgramRegister : public Phase, public Register void other_input(Input&, int) {} char get_output() { return 0; } - - ProgramRegister(const Register& reg) : Register(reg) {} }; class PRFRegister : public ProgramRegister @@ -319,8 +313,6 @@ class PRFRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - PRFRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const PRFRegister& left, const PRFRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); @@ -396,8 +388,6 @@ class EvalRegister : public ProgramRegister static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>& proc); - EvalRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const ProgramRegister& left, const ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); @@ -427,8 +417,6 @@ class GarbleRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - GarbleRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const Register& left, const Register& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char value = -1); @@ -452,8 +440,6 @@ class RandomRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - RandomRegister(const Register& reg) : ProgramRegister(reg) {} - void randomize(); void op(const Register& left, const Register& right, Function func); @@ -469,12 +455,6 @@ class RandomRegister : public ProgramRegister }; -inline Register::Register(int n_parties) : - garbled_entry(n_parties), external(NO_SIGNAL), - mask(NO_SIGNAL), keys(n_parties) -{ -} - inline void KeyVector::operator=(const KeyVector& other) { resize(other.size()); diff --git a/BMR/Register.hpp b/BMR/Register.hpp index bd214a858..617906945 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor, const vector& args) { NoOpInputter inputter; - int my_num = -1; - try - { - my_num = ProgramParty::s().P->my_num(); - } - catch (exception&) - { - } - processor.inputbvec(inputter, input_processor, args, my_num); + processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P); } template @@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor, { EvalInputter inputter; processor.inputbvec(inputter, input_processor, args, - ProgramParty::s().P->my_num()); + *ProgramParty::s().P); } template diff --git a/BMR/Register_inline.h b/BMR/Register_inline.h index 6a275da64..7694c464d 100644 --- a/BMR/Register_inline.h +++ b/BMR/Register_inline.h @@ -9,10 +9,10 @@ #include "CommonParty.h" #include "Party.h" - -inline Register ProgramRegister::new_reg() +inline Register::Register() : + garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL), + mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties()) { - return Register(CommonParty::s().get_n_parties()); } #endif /* BMR_REGISTER_INLINE_H_ */ diff --git a/CHANGELOG.md b/CHANGELOG.md index ac6435805..e8e015348 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.3 (Aug 25, 2022) + +- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate +- Fix security bug in MAC check when using multithreading +- Fix security bug to prevent selective failure attack by checking earlier +- Fix security bug in Mama: insufficient sacrifice. +- Inverse permutation (@Quitlox) +- Easier direct compilation (@eriktaubeneck) +- Generally allow element-vector operations +- Increase maximum register size to 2^54 +- Client example in Python +- Uniform base OTs across platforms +- Multithreaded base OT computation +- Faster random bit generation in two-player Semi(2k) + ## 0.3.2 (May 27, 2022) - Secure shuffling @@ -7,7 +22,7 @@ The changelog explains changes pulled through from the private development repos - Documented BGV encryption interface - Optimized matrix multiplication in dealer protocol - Fixed security bug in homomorphic encryption parameter generation -- Fixed Security bug in Temi matrix multiplication +- Fixed security bug in Temi matrix multiplication ## 0.3.1 (Apr 19, 2022) diff --git a/CONFIG b/CONFIG index cef15e0b4..fb9db2009 100644 --- a/CONFIG +++ b/CONFIG @@ -31,24 +31,21 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx ARCH = -march=native MACHINE := $(shell uname -m) +ARM := $(shell uname -m | grep x86; echo $$?) OS := $(shell uname -s) ifeq ($(MACHINE), x86_64) -# set this to 0 to avoid using AVX for OT ifeq ($(OS), Linux) -CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?) -ifeq ($(CHECK_AVX), 0) AVX_OT = 1 else AVX_OT = 0 endif else -AVX_OT = 1 -endif -else ARCH = AVX_OT = 0 endif +USE_KOS = 0 + # allow to set compiler in CONFIG.mine CXX = g++ @@ -87,7 +84,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) @@ -98,3 +95,10 @@ ifeq ($(USE_NTL),1) CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy endif endif + +ifeq ($(USE_KOS),1) +CFLAGS += -DUSE_KOS +else +CFLAGS += -std=c++17 +LDLIBS += -llibOTe -lcryptoTools +endif diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index e53b71879..2b5ec46ad 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -342,7 +342,8 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): code = opcodes['STMCB'] arg_format = ['cb','long'] -class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit memory cell with run-time address to secret bit register. @@ -351,8 +352,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] + direct = staticmethod(ldmsb) -class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit register to secret bit memory cell with run-time address. @@ -361,8 +364,10 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMSBI'] arg_format = ['sb','ci'] + direct = staticmethod(stmsb) -class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit memory cell with run-time address to clear bit register. @@ -371,8 +376,10 @@ class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMCBI'] arg_format = ['cbw','ci'] + direct = staticmethod(ldmcb) -class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit register to clear bit memory cell with run-time address. @@ -381,6 +388,7 @@ class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMCBI'] arg_format = ['cb','ci'] + direct = staticmethod(stmcb) class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5530432bb..396769e03 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -198,6 +198,8 @@ def __and__(self, other): return 0 elif self.is_long_one(other): return self + elif isinstance(other, _vec): + return other & other.from_vec([self]) else: return self._and(other) @read_mem_value @@ -241,6 +243,13 @@ def zero_if_not(self, condition): return self * condition else: return self * cbit.conv(condition) + def expand(self, length): + if self.n in (length, None): + return self + elif self.n == 1: + return self.get_type(length).bit_compose([self] * length) + else: + raise CompilerError('cannot expand from %s to %s' % (self.n, length)) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -295,8 +304,15 @@ def clear_op(self, other, c_inst, ci_inst, op): return op(self, cbits(other)) __add__ = lambda self, other: \ self.clear_op(other, inst.addcb, inst.addcbi, operator.add) - __sub__ = lambda self, other: \ - self.clear_op(-other, inst.addcb, inst.addcbi, operator.add) + def __sub__(self, other): + try: + return self + -other + except: + return type(self)(regint(self) - regint(other)) + def __rsub__(self, other): + return type(self)(other - regint(self)) + def __neg__(self): + return type(self)(-regint(self)) def _xor(self, other): if isinstance(other, (sbits, sbitvec)): return NotImplemented @@ -589,7 +605,15 @@ def trans(cls, rows): rows = list(rows) if len(rows) == 1 and rows[0].n <= rows[0].unit: return rows[0].bit_decompose() - n_columns = rows[0].n + for row in rows: + try: + n_columns = row.n + break + except: + pass + for i in range(len(rows)): + if util.is_zero(rows[i]): + rows[i] = cls.get_type(n_columns)(0) for row in rows: assert(row.n == n_columns) if n_columns == 1 and len(rows) <= cls.unit: @@ -613,7 +637,7 @@ def bit_adder(*args, **kwargs): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(_vec): +class sbitvec(_vec, _bit): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. @@ -656,6 +680,7 @@ class sbitvec(_vec): [1, 0, 1] """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) + is_clear = False @classmethod def get_type(cls, n): """ Create type for fixed-length vector of registers of secret bits. @@ -691,10 +716,11 @@ def from_vec(cls, vector): res.v = _complement_two_extend(list(vector), n)[:n] return res def __init__(self, other=None, size=None): - assert size in (None, 1) if other is not None: if util.is_constant(other): - self.v = [sbit((other >> i) & 1) for i in range(n)] + t = sbits.get_type(size or 1) + self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1)) + for i in range(n)] elif isinstance(other, _vec): self.v = self.bit_extend(other.v, n) elif isinstance(other, (list, tuple)): @@ -702,6 +728,7 @@ def __init__(self, other=None, size=None): else: self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n + assert size is None or size == self.v[0].n @classmethod def load_mem(cls, address, size=None): if size not in (None, 1): @@ -733,8 +760,9 @@ def store_in_mem(self, address): def reveal(self): return util.untuplify([x.reveal() for x in self.elements()]) @classmethod - def two_power(cls, nn): - return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) + def two_power(cls, nn, size=1): + return cls.from_vec( + [0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1)) def coerce(self, other): if util.is_constant(other): return self.from_vec(util.bit_decompose(other, n)) @@ -818,16 +846,14 @@ def coerce(self, other): return other def __xor__(self, other): other = self.coerce(other) - return self.from_vec(x ^ y for x, y in zip(self.v, other)) + return self.from_vec(x ^ y for x, y in zip(*self.expand(other))) def __and__(self, other): - return self.from_vec(x & y for x, y in zip(self.v, other.v)) + return self.from_vec(x & y for x, y in zip(*self.expand(other))) + def __invert__(self): + return self.from_vec(~x for x in self.v) def if_else(self, x, y): assert(len(self.v) == 1) - try: - return self.from_vec(util.if_else(self.v[0], a, b) \ - for a, b in zip(x, y)) - except: - return util.if_else(self.v[0], x, y) + return util.if_else(self.v[0], x, y) def __iter__(self): return iter(self.v) def __len__(self): @@ -890,6 +916,24 @@ def tree_reduce(self, function): elements = red.elements() elements += odd return self.from_vec(sbitvec(elements).v) + @classmethod + def comp_result(cls, x): + return cls.get_type(1).from_vec([x]) + def expand(self, other, expand=True): + m = 1 + for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []): + try: + m = max(m, x.n) + except: + pass + res = [] + for y in self, other: + if isinstance(y, int): + res.append([x * sbits.get_type(m)().long_one() + for x in util.bit_decompose(y, len(self.v))]) + else: + res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) + return res class bit(object): n = 1 @@ -1139,7 +1183,7 @@ def pow2(self, k): :param k: bit length of input """ return _sbitintbase.pow2(self, k) -class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): +class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ Vector of signed integers for parallel binary computation:: @@ -1176,7 +1220,8 @@ def __add__(self, other): return self other = self.coerce(other) assert(len(self.v) == len(other.v)) - v = sbitint.bit_adder(self.v, other.v) + a, b = self.expand(other) + v = sbitint.bit_adder(a, b) return self.from_vec(v) __radd__ = __add__ def __mul__(self, other): @@ -1184,7 +1229,7 @@ def __mul__(self, other): return self.from_vec(other * x for x in self.v) elif isinstance(other, sbitfixvec): return NotImplemented - other_bits = util.bit_decompose(other) + _, other_bits = self.expand(other, False) m = float('inf') for x in itertools.chain(self.v, other_bits): try: @@ -1228,6 +1273,8 @@ class cbitfix(object): store_in_mem = lambda self, *args: self.v.store_in_mem(*args) @classmethod def _new(cls, value): + if isinstance(value, list): + return [cls._new(x) for x in value] res = cls() if cls.k < value.unit: bits = value.bit_decompose(cls.k) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 84bdd22b6..1a139ef6d 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -87,15 +87,14 @@ def LtzRing(a, k): carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] return sint.conv(msb) - return - elif program.options.ring: + else: from . import floatingpoint require_ring_size(k, 'comparison') m = k - 1 shift = int(program.options.ring) - k r_prime, r_bin = MaskingBitsInRing(k) tmp = a - r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) @@ -190,7 +189,7 @@ def TruncLeakyInRing(a, k, m, signed): r = sint.bit_compose(r_bits) if signed: a += (1 << (k - 1)) - shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() + shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False) masked = shifted >> n_shift u = sint() BitLTL(u, masked, r_bits[:n_bits], 0) @@ -231,7 +230,7 @@ def Mod2mRing(a_prime, a, k, m, signed): shift = int(program.options.ring) - m r_prime, r_bin = MaskingBitsInRing(m, True) tmp = a + r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift u = sint() BitLTL(u, c_prime, r_bin[:m], 0) res = (u << m) + c_prime - r_prime @@ -261,7 +260,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) modc(c_prime, c, c2m) if const_rounds: BitLTC1(u, c_prime, r, kappa) @@ -510,7 +509,7 @@ def PreMulC_with_inverses_and_vectors(p, a): movs(w[0], r[0]) movs(a_vec[0], a[0]) vmuls(k, t[0], w, a_vec) - vasm_open(k, m, t[0]) + vasm_open(k, True, m, t[0]) PreMulC_end(p, a, c, m, z) def PreMulC_with_inverses(p, a): @@ -538,7 +537,7 @@ def PreMulC_with_inverses(p, a): w[1][0] = r[0][0] for i in range(k): muls(t[0][i], w[1][i], a[i]) - asm_open(m[i], t[0][i]) + asm_open(True, m[i], t[0][i]) PreMulC_end(p, a, c, m, z) def PreMulC_without_inverses(p, a): @@ -563,7 +562,7 @@ def PreMulC_without_inverses(p, a): #adds(tt[0][i], t[0][i], a[i]) #subs(tt[1][i], tt[0][i], a[i]) #startopen(tt[1][i]) - asm_open(u[i], t[0][i]) + asm_open(True, u[i], t[0][i]) for i in range(k-1): muls(v[i], r[i+1], s[i]) w[0] = r[0] @@ -579,7 +578,7 @@ def PreMulC_without_inverses(p, a): mulm(z[i], s[i], u_inv[i]) for i in range(k): muls(t[1][i], w[i], a[i]) - asm_open(m[i], t[1][i]) + asm_open(True, m[i], t[1][i]) PreMulC_end(p, a, c, m, z) def PreMulC_end(p, a, c, m, z): @@ -646,7 +645,7 @@ def Mod2(a_0, a, k, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index eb800ba48..4a4706ff6 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -181,7 +181,8 @@ def build_option_parser(self): action="store_true", dest="invperm", help="speedup inverse permutation (only use in two-party, " - "semi-honest environment)") + "semi-honest environment)" + ) parser.add_option( "-C", "--CISC", @@ -244,11 +245,9 @@ def build_vars(self): self.VARS[op.__name__] = op # add open and input separately due to name conflict - self.VARS["open"] = instructions.asm_open self.VARS["vopen"] = instructions.vasm_open self.VARS["gopen"] = instructions.gasm_open self.VARS["vgopen"] = instructions.vgasm_open - self.VARS["input"] = instructions.asm_input self.VARS["ginput"] = instructions.gasm_input self.VARS["comparison"] = comparison @@ -268,7 +267,6 @@ def build_vars(self): "sgf2nuint", "sgf2nuint32", "sgf2nfloat", - "sfloat", "cfloat", "squant", ]: @@ -276,6 +274,9 @@ def build_vars(self): def prep_compile(self, name=None): self.parse_args() + if len(self.args) < 1 and name is None: + self.parser.print_help() + exit(1) self.build_program(name=name) self.build_vars() @@ -372,7 +373,7 @@ def compile_func(self): ) self.prep_compile(self.compile_name) print( - f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}" + "Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__) ) self.compile_function() self.finalize_compile() diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index d3d3f8c50..94a47f1bf 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,7 +28,7 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_edabit: + if program.Program.prog.use_edabit(): r_prime, r = types.sint.get_edabit(k) elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) @@ -36,7 +36,7 @@ def maskRing(a, k): else: r = [types.sint.get_random_bit() for i in range(k)] r_prime = types.sint.bit_compose(r) - c = ((a + r_prime) << shift).reveal() >> shift + c = ((a + r_prime) << shift).reveal(False) >> shift return c, r def maskField(a, k, kappa): @@ -47,7 +47,7 @@ def maskField(a, k, kappa): comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) # always signed due to usage in equality testing a += two_power(k) - asm_open(c, a + two_power(k) * r_dprime + r_prime) + asm_open(True, c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc @@ -233,7 +233,7 @@ def Inv(a): ldi(one, 1) inverse(t[0], t[1]) s = t[0]*a - asm_open(c[0], s) + asm_open(True, c[0], s) # avoid division by zero for benchmarking divc(c[1], one, c[0]) #divc(c[1], c[0], one) @@ -281,7 +281,7 @@ def BitDecRingRaw(a, k, m): else: r_bits = [types.sint.get_random_bit() for i in range(m)] r = types.sint.bit_compose(r_bits) - shifted = ((a - r) << n_shift).reveal() + shifted = ((a - r) << n_shift).reveal(False) masked = shifted >> n_shift bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return bits @@ -299,7 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): r = [types.sint() for i in range(m)] comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) pow2 = two_power(k + kappa) - asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) + asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) instructions_base.reset_global_vector_size() return res @@ -341,10 +341,10 @@ def B2U_from_Pow2(pow2a, l, kappa): if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l assert n_shift > 0 - c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift + c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(t, kappa) - asm_open(c, pow2a + two_power(l) * t + + asm_open(True, c, pow2a + two_power(l) * t + sum(two_power(i) * r[i] for i in range(l))) comparison.program.curr_tape.require_bit_length(l + kappa) c = list(r_bits[0].bit_decompose_clear(c, l)) @@ -386,11 +386,11 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): r_dprime += t1 - t2 if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l - c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift + c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(rk, kappa) r_dprime += two_power(l) * rk - asm_open(c, a + r_dprime + r_prime) + asm_open(True, c, a + r_dprime + r_prime) for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) @@ -416,7 +416,7 @@ def TruncInRing(to_shift, l, pow2m): rev *= pow2m r_bits = [types.sint.get_random_bit() for i in range(l)] r = types.sint.bit_compose(r_bits) - shifted = (rev - (r << n_shift)).reveal() + shifted = (rev - (r << n_shift)).reveal(False) masked = shifted >> n_shift bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l)) return types.sint.bit_compose(reversed(bits)) @@ -457,7 +457,7 @@ def Int2FL(a, gamma, l, kappa=None): v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) else: v = 2**(l-gamma+1) * t - p = (p + gamma - 1 - l) * (1 -z) + p = (p + gamma - 1 - l) * z.bit_not() return v, p, z, s def FLRound(x, mode): @@ -530,7 +530,7 @@ def TruncPrRing(a, k, m, signed=True): msb = r_bits[-1] n_shift = n_ring - (k + 1) tmp = a + r - masked = (tmp << n_shift).reveal() + masked = (tmp << n_shift).reveal(False) shifted = (masked << 1 >> (n_shift + m + 1)) overflow = msb.bit_xor(masked >> (n_ring - 1)) res = shifted - upper + \ @@ -551,7 +551,7 @@ def TruncPrField(a, k, m, kappa=None): k, m, kappa, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime - c = (b + r).reveal() + c = (b + r).reveal(False) c_prime = c % two_to_m a_prime = c_prime - r_prime d = (a - a_prime) / two_to_m @@ -667,14 +667,14 @@ def get_bits_loop(): def _(): for i in range(bit_length): tbits[j][i].link(sint.get_random_bit()) - c = regint(BITLT(tbits[j], pbits, bit_length).reveal()) + c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False)) done[j].link(c) return (sum(done) != a.size) for j in range(a.size): for i in range(bit_length): movs(bbits[i][j], tbits[j][i]) b = sint.bit_compose(bbits) - c = (a-b).reveal() + c = (a-b).reveal(False) cmodp = c t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 058b6ff4f..91809ba44 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -387,6 +387,14 @@ class use(base.Instruction): code = base.opcodes['USE'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), + find_in_dict(data_types, args[1].i)): + args[2].i} + class use_inp(base.Instruction): """ Input usage. Necessary to avoid reusage while using preprocessing from files. @@ -398,6 +406,13 @@ class use_inp(base.Instruction): code = base.opcodes['USE_INP'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), 'input', args[1].i): + args[2].i} + class use_edabit(base.Instruction): """ edaBit usage. Necessary to avoid reusage while using preprocessing from files. Also used to multithreading for expensive @@ -410,6 +425,10 @@ class use_edabit(base.Instruction): code = base.opcodes['USE_EDABIT'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i} + class use_matmul(base.Instruction): """ Matrix multiplication usage. Used for multithreading of preprocessing. @@ -471,6 +490,11 @@ class use_prep(base.Instruction): code = base.opcodes['USE_PREP'] arg_format = ['str','int'] + @classmethod + def get_usage(cls, args): + return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp', + args[0].str): args[1].i} + class nplayers(base.Instruction): """ Store number of players in clear integer register. @@ -783,30 +807,6 @@ def has_var_args(self): return True -### -### Special GF(2) arithmetic instructions -### - -@base.vectorize -class gmulbitc(base.MulBase): - r""" Clear GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITC'] - arg_format = ['cgw','cg','cg'] - - def is_gf2n(self): - return True - -@base.vectorize -class gmulbitm(base.MulBase): - r""" Secret GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITM'] - arg_format = ['sgw','sg','cg'] - - def is_gf2n(self): - return True - ### ### Arithmetic with immediate values ### @@ -1707,6 +1707,7 @@ class writesockets(base.IOInstruction): from registers into a socket for a specified client id. If the protocol uses MACs, the client should be different for every party. + :param: number of arguments to follow :param: client id (regint) :param: message type (must be 0) :param: vector size (int) @@ -2162,14 +2163,19 @@ class gconvgf2n(base.Instruction): class asm_open(base.VarArgsInstruction): """ Reveal secret registers (vectors) to clear registers (vectors). - :param: number of argument to follow (multiple of two) + :param: number of argument to follow (odd number) + :param: check after opening (0/1) :param: destination (cint) :param: source (sint) :param: (repeat the last two)... """ __slots__ = [] code = base.opcodes['OPEN'] - arg_format = tools.cycle(['cw','s']) + arg_format = tools.chain(['int'], tools.cycle(['cw','s'])) + + def merge(self, other): + self.args[0] |= other.args[0] + self.args += other.args[1:] @base.gf2n @base.vectorize @@ -2415,12 +2421,17 @@ class shuffle_base(base.DataInstruction): def logn(n): return int(math.ceil(math.log(n, 2))) + @classmethod + def n_swaps(cls, n): + logn = cls.logn(n) + return logn * 2 ** logn - 2 ** logn + 1 + def add_gen_usage(self, req_node, n): # hack for unknown usage req_node.increment(('bit', 'inverse'), float('inf')) # minimal usage with two relevant parties logn = self.logn(n) - n_switches = logn * 2 ** logn + n_switches = self.n_swaps(n) for i in range(self.n_relevant_parties): req_node.increment((self.field_type, 'input', i), n_switches) # multiplications for bit check @@ -2430,7 +2441,7 @@ def add_gen_usage(self, req_node, n): def add_apply_usage(self, req_node, n, record_size): req_node.increment(('bit', 'inverse'), float('inf')) logn = self.logn(n) - n_switches = logn * 2 ** logn * self.n_relevant_parties + n_switches = self.n_swaps(n) * self.n_relevant_parties if n != 2 ** logn: record_size += 1 req_node.increment((self.field_type, 'triple'), @@ -2548,7 +2559,7 @@ def expand(self): c = [program.curr_block.new_reg('c') for i in range(2)] square(s[0], s[1]) subs(s[2], self.args[1], s[0]) - asm_open(c[0], s[2]) + asm_open(False, c[0], s[2]) mulc(c[1], c[0], c[0]) mulm(s[3], self.args[1], c[0]) adds(s[4], s[3], s[3]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index f7aa48f9b..fb60d908b 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -542,7 +542,7 @@ def add_usage(self, *args): def get_bytes(self): assert len(self.kwargs) < 2 - res = int_to_bytes(opcodes['CISC']) + res = LongArgFormat.encode(opcodes['CISC']) res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1) name = self.function.__name__ String.check(name) @@ -720,7 +720,7 @@ def __str__(self): class LongArgFormat(IntArgFormat): @classmethod def encode(cls, arg): - return struct.pack('>Q', arg) + return list(struct.pack('>Q', arg)) def __init__(self, f): self.i = struct.unpack('>Q', f.read(8))[0] @@ -741,6 +741,8 @@ def check(cls, arg): class PlayerNoAF(IntArgFormat): @classmethod def check(cls, arg): + if not util.is_constant(arg): + raise CompilerError('Player number must be known at compile time') super(PlayerNoAF, cls).check(arg) if arg > 256: raise ArgumentError(arg, 'Player number > 256') @@ -823,7 +825,7 @@ def get_code(self, prefix=0): return (prefix << self.code_length) + self.code def get_encoding(self): - enc = int_to_bytes(self.get_code()) + enc = LongArgFormat.encode(self.get_code()) # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) @@ -958,7 +960,7 @@ def __init__(self, f): except AttributeError: pass read = lambda: struct.unpack('>I', f.read(4))[0] - full_code = read() + full_code = struct.unpack('>Q', f.read(8))[0] code = full_code % (1 << Instruction.code_length) self.size = full_code >> Instruction.code_length self.type = cls.reverse_opcodes[code] diff --git a/Compiler/library.py b/Compiler/library.py index 799f85d29..1da50e9c9 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -243,6 +243,10 @@ def store_in_mem(value, address): try: value.store_in_mem(address) except AttributeError: + if isinstance(value, (list, tuple)): + for i, x in enumerate(value): + store_in_mem(x, address + i) + return # legacy if value.is_clear: if isinstance(address, cint): @@ -261,11 +265,13 @@ def reveal(secret): try: return secret.reveal() except AttributeError: + if secret.is_clear: + return secret if secret.is_gf2n: res = cgf2n() else: res = cint() - instructions.asm_open(res, secret) + instructions.asm_open(True, res, secret) return res @vectorize @@ -883,10 +889,10 @@ def loop_fn(i): def for_range(start, stop=None, step=None): """ Decorator to execute loop bodies consecutively. Arguments work as - in Python :py:func:`range`, but they can by any public + in Python :py:func:`range`, but they can be any public integer. Information has to be passed out via container types such - as :py:class:`~Compiler.types.Array` or declaring registers as - :py:obj:`global`. Note that changing Python data structures such + as :py:class:`~Compiler.types.Array` or using :py:func:`update`. + Note that changing Python data structures such as lists within the loop is not possible, but the compiler cannot warn about this. @@ -901,13 +907,11 @@ def for_range(start, stop=None, step=None): @for_range(n) def _(i): a[i] = i - global x - x += 1 + x.update(x + 1) Note that you cannot overwrite data structures such as - :py:class:`~Compiler.types.Array` in a loop even when using - :py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign` - instead. + :py:class:`~Compiler.types.Array` in a loop. Use + :py:func:`~Compiler.types.Array.assign` instead. """ def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -1518,6 +1522,11 @@ class State: pass state = State() if callable(condition): condition = condition() + try: + if not condition.is_clear: + raise CompilerError('cannot branch on secret values') + except AttributeError: + pass state.condition = regint.conv(condition) state.start_block = instructions.program.curr_block state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ @@ -1889,7 +1898,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): theta = int(ceil(log(k/3.5) / log(2))) base.set_global_vector_size(b.size) - alpha = b.get_type(2 * k).two_power(2*f) + alpha = b.get_type(2 * k).two_power(2*f, size=b.size) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() diff --git a/Compiler/ml.py b/Compiler/ml.py index 02f0f04ed..173c2eac0 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -148,7 +148,7 @@ def argmax(x): """ Compute index of maximum element. :param x: iterable - :returns: sint + :returns: sint or 0 if :py:obj:`x` has length 1 """ def op(a, b): comp = (a[1] > b[1]) @@ -164,7 +164,7 @@ def softmax(x): return softmax_from_exp(exp_for_softmax(x)[0]) def exp_for_softmax(x): - m = util.max(x) + m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2) mv = m.expand_to_vector(len(x)) try: x = x.get_vector() @@ -2384,6 +2384,11 @@ def output_weights(self): for layer in self.layers: layer.output_weights() + def summary(self): + sizes = [var.total_size() for var in self.thetas] + print(sizes) + print('Trainable params:', sum(sizes)) + class Adam(Optimizer): """ Adam/AMSgrad optimizer. @@ -2653,9 +2658,7 @@ def trainable_variables(self): return list(self.opt.thetas) def summary(self): - sizes = [var.total_size() for var in self.trainable_variables] - print(sizes) - print('Trainable params:', sum(sizes)) + self.opt.summary() def build(self, input_shape, batch_size=128): data_input_shape = input_shape diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 47253dc43..abdaf1233 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -295,7 +295,6 @@ class my_fix(type(a)): intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k if types.program.use_split(): - assert not zero_output from Compiler.GC.types import sbitvec if types.program.use_split() == 3: x = a.v.split_to_two_summands(a.k) @@ -327,6 +326,7 @@ class my_fix(type(a)): s = sint.conv(bits[-1]) lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) higher_bits = bits[a.f:n_bits] + bits_to_check = bits[n_bits:-1] else: if types.program.use_edabit(): l = sint.get_edabit(a.f, True) @@ -338,7 +338,7 @@ class my_fix(type(a)): r_bits = [sint.get_random_bit() for i in range(a.k)] r = sint.bit_compose(r_bits) lower_r = sint.bit_compose(r_bits[:a.f]) - shifted = ((a.v - r) << n_shift).reveal() + shifted = ((a.v - r) << n_shift).reveal(False) masked_bits = (shifted >> n_shift).bit_decompose(a.k) lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], r_bits[a.f-1::-1]) diff --git a/Compiler/program.py b/Compiler/program.py index d7b57db90..f92ab4971 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -545,7 +545,7 @@ def use_invperm(self, change=None): """ if change is None: if not self._invperm: - self.relevant_opts.add('invperm') + self.relevant_opts.add("invperm") return self._invperm else: self._invperm = change @@ -1276,7 +1276,7 @@ class Register(_no_truth): "can_eliminate", "duplicates", ] - maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 + maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): """Creates a new register. @@ -1382,6 +1382,20 @@ def link(self, other): for dup in self.duplicates: dup.duplicates = self.duplicates + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + if self.program != other.program: + raise CompilerError( + 'cannot update register with one from another thread') + self.link(other) + @property def is_gf2n(self): return ( diff --git a/Compiler/types.py b/Compiler/types.py index d63295c8f..6a150beee 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,8 +127,14 @@ def vectorized_operation(self, *args, **kwargs): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise VectorMismatch('Different vector sizes of operands: %d/%d' - % (self.size, args[0].size)) + if min(args[0].size, self.size) == 1: + size = max(args[0].size, self.size) + self = self.expand_to_vector(size) + args = list(args) + args[0] = args[0].expand_to_vector(size) + else: + raise VectorMismatch('Different vector sizes of operands: %d/%d' + % (self.size, args[0].size)) set_global_vector_size(self.size) try: res = operation(self, *args, **kwargs) @@ -249,8 +255,11 @@ def __mul__(self, other): try: return self.mul(other) except VectorMismatch: - # try reverse multiplication - return NotImplemented + if type(self) != type(other) and 1 in (self.size, other.size): + # try reverse multiplication + return NotImplemented + else: + raise __radd__ = __add__ __rmul__ = __mul__ @@ -1658,6 +1667,8 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') intoutput(player, self) class localint(Tape._no_truth): @@ -2081,12 +2092,15 @@ def __rsub__(self, other): return self.secret_op(other, subs, submr, subsfi, True) __rsub__.__doc__ = __sub__.__doc__ - @vectorize def __truediv__(self, other): """ Secret field division. :param other: any compatible type """ - return self * (self.clear_type(1) / other) + try: + one = self.clear_type(1, size=other.size) + except AttributeError: + one = self.clear_type(1) + return self * (one / other) @vectorize def __rtruediv__(self, other): @@ -2113,12 +2127,12 @@ def secure_shuffle(self, unit_size=1): @set_instruction_type @vectorize - def reveal(self): + def reveal(self, check=True): """ Reveal secret value publicly. :rtype: relevant clear type """ res = self.clear_type() - asm_open(res, self) + asm_open(check, res, self) return res @set_instruction_type @@ -2166,9 +2180,7 @@ class sint(_secret, _int): signed integer in a restricted range, see below. The same holds for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and exponentation (``**``). Modulo only works if the right-hand - operator is a compile-time power of two, and exponentiation only - works if the base is two or if the exponent is a compile-time - integer. + operator is a compile-time power of two. Most non-linear operations require compile-time parameters for bit length and statistical security. They default to the global @@ -2672,7 +2684,7 @@ def trunc_zeros(self, n_zeros, bit_length=None, signed=True): return comparison.TruncZeros(self, bit_length, n_zeros, signed) @staticmethod - def two_power(n): + def two_power(n, size=None): return floatingpoint.two_power(n) def split_to_n_summands(self, length, n): @@ -2690,7 +2702,6 @@ def split_to_two_summands(self, length, get_carry=False): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -2698,13 +2709,14 @@ def reveal_to(self, player): :returns: :py:class:`personal` """ if not util.is_constant(player): - secret_mask = sint() - player_mask = cint() - inputmaskreg(secret_mask, player_mask, regint.conv(player)) + secret_mask = sint(size=self.size) + player_mask = cint(size=self.size) + inputmaskreg(secret_mask, player_mask, + regint.conv(player).expand_to_vector(self.size)) return personal(player, - (self + secret_mask).reveal() - player_mask) + (self + secret_mask).reveal(False) - player_mask) else: - res = personal(player, self.clear_type()) + res = personal(player, self.clear_type(size=self.size)) privateoutput(self.size, player, res._v, self) return res @@ -2856,6 +2868,10 @@ def __rsub__(self, other): else: return super(sintbit, self).__rsub__(other) + __rand__ = __and__ + __rxor__ = __xor__ + __ror__ = __or__ + class sgf2n(_secret, _gf2n): """ Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A @@ -2873,6 +2889,7 @@ class sgf2n(_secret, _gf2n): instruction_type = 'gf2n' clear_type = cgf2n reg_type = 'sg' + long_one = staticmethod(lambda: 1) @classmethod def get_type(cls, length): @@ -3022,6 +3039,7 @@ class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False + comp_result = staticmethod(lambda x: x) @staticmethod def half_adder(a, b): @@ -3241,12 +3259,16 @@ def wallace_reduction(cls, a, b, c, get_carry=True): del carries[-1] return sums, carries + def expand(self, other): + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + return a, b + def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) from util import bit_not, bit_and, bit_xor + a, b = self.expand(other) n = 1 for x in (a + b): try: @@ -3293,8 +3315,7 @@ def prep_comparison(a, b): a[-1], b[-1] = b[-1], a[-1] def comparison(self, other, const_rounds=False, index=None): - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) + a, b = self.expand(other) self.prep_comparison(a, b) if const_rounds: return self.get_highest_different_bits(a, b, index) @@ -3304,30 +3325,33 @@ def comparison(self, other, const_rounds=False, index=None): def __lt__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 0) + res = util.if_else(not_equal, x, 0) else: - return self.comparison(other, True, 1) + res = self.comparison(other, True, 1) + return self.comp_result(res) def __le__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 1) + res = util.if_else(not_equal, x, x.long_one()) else: - return 1 - self.comparison(other, True, 0) + res = self.comparison(other, True, 0).bit_not() + return self.comp_result(res) def __ge__(self, other): - return 1 - (self < other) + return (self < other).bit_not() def __gt__(self, other): - return 1 - (self <= other) + return (self <= other).bit_not() def __eq__(self, other, bit_length=None, security=None): diff = self ^ other - diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]] - return floatingpoint.KMul(diff_bits) + diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] + return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), + diff_bits)) def __ne__(self, other): - return 1 - (self == other) + return (self == other).bit_not() equal = __eq__ @@ -3881,7 +3905,6 @@ def print_plain(self): def output_if(self, cond): cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size)) - @vectorize def binary_output(self, player=None): """ Write double-precision floating-point number to ``Player-Data/Binary-Output-P-``. @@ -3890,7 +3913,11 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') + set_global_vector_size(self.size) floatoutput(player, self.v, cint(-self.f), cint(0), cint(0)) + reset_global_vector_size() class _single(_number, _secret_structure): """ Representation as single integer preserving the order """ @@ -4124,6 +4151,7 @@ def get_vector(self): class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k'] + is_clear = False def set_precision(cls, f, k = None): cls.f = f @@ -4349,6 +4377,18 @@ def bit_decompose(self, n_bits=None): """ Bit decomposition. """ return self.v.bit_decompose(n_bits or self.k) + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + assert self.f == other.f + self.v.update(other.v) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4737,6 +4777,8 @@ class sfloat(_number, _secret_structure): returning :py:class:`sint`. The other operand can be any of sint/cfix/regint/cint/int/float. + This data type only works with arithmetic computation. + :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) """ __slots__ = ['v', 'p', 'z', 's', 'size'] @@ -4835,6 +4877,9 @@ def get_input_from(cls, player): @vectorize_init @read_mem_value def __init__(self, v, p=None, z=None, s=None, size=None): + if program.options.binary: + raise CompilerError( + 'floating-point operations not supported with binary circuits') self.size = get_global_vector_size() if p is None: if isinstance(v, sfloat): @@ -5227,7 +5272,13 @@ class Array(_vectorizable): def create_from(cls, l): """ Convert Python iterator or vector to array. Basic type will be taken from first element, further elements must to be convertible to - that. """ + that. + + :param l: Python iterable or register vector + :returns: :py:class:`Array` of appropriate type containing the contents + of :py:obj:`l` + + """ if isinstance(l, cls): return l if isinstance(l, _number): @@ -6099,12 +6150,12 @@ def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( self[i], other, res_params) - except AttributeError: + except (AttributeError, CompilerError): # fallback for binary circuits - @library.for_range(other.sizes[1]) + @library.for_range_opt(other.sizes[1]) def _(j): res_matrix[i][j] = 0 - @library.for_range(self.sizes[1]) + @library.for_range_opt(self.sizes[1]) def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix @@ -6223,13 +6274,7 @@ def _(i): res[i] = self.direct_mul_trans(other, indices=indices) def direct_mul_to_matrix(self, other): - """ Matrix multiplication in the virtual machine. - - :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :returns: :py:obj:`Matrix` - - """ + # Obsolete. Use dot(). res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) res.assign_vector(self.direct_mul(other)) return res diff --git a/Compiler/util.py b/Compiler/util.py index 9d84df226..c1bedc27c 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -238,6 +238,9 @@ def mem_size(x): except AttributeError: return 1 +def find_in_dict(d, v): + return list(d.keys())[list(d.values()).index(v)] + class set_by_id(object): def __init__(self, init=[]): self.content = {} diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 4657b5d88..27ea7f75c 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -22,7 +22,7 @@ class P256Element : public ValueInterface EC_POINT* point; public: - typedef void next; + typedef P256Element next; typedef void Square; static const true_type invertible; diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index f0e3257c6..5bef730d5 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -45,12 +45,13 @@ int main(int argc, const char** argv) string prefix = get_prep_sub_dir(PREP_DIR "ECDSA/", 2); read_mac_key(prefix, N, keyp); + pShare::MAC_Check::setup(P); + Share::MAC_Check::setup(P); + DataPositions usage; Sub_Data_Files prep(N, prefix, usage); typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); - BaseMachine machine; - machine.ot_setups.push_back({P, false}); SubProcessor proc(_, MCp, prep, P); pShare sk, __; @@ -60,4 +61,7 @@ int main(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts); + + pShare::MAC_Check::teardown(); + Share::MAC_Check::teardown(); } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 569aa791f..ebf0aea96 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -92,9 +92,6 @@ void run(int argc, const char** argv) P256Element::init(); P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); - BaseMachine machine; - machine.ot_setups.push_back({P, true}); - P256Element::Scalar keyp; SeededPRNG G; keyp.randomize(G); @@ -102,6 +99,9 @@ void run(int argc, const char** argv) typedef T pShare; DataPositions usage; + pShare::MAC_Check::setup(P); + T::MAC_Check::setup(P); + OnlineOptions::singleton.batch_size = 1; typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); @@ -137,4 +137,7 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + + pShare::MAC_Check::teardown(); + T::MAC_Check::teardown(); } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 00e051318..62cbd5281 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -130,7 +130,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk) assert(p != 0); for (auto& x : r) { - G.get(x, params->p0().numBits() - p.numBits() - 1); + G.get(x, params->p0().numBits() - p.numBits() - 1); x *= p; } tmp.from(r, 0); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 794e7431d..f3973026e 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -368,7 +368,8 @@ ZZX Cyclotomic(int N) int phi_N(int N) { if (((N - 1) & N) != 0) - throw runtime_error("compile with NTL support"); + throw runtime_error( + "compile with NTL support (USE_NTL=1 in CONFIG.mine)"); else if (N == 1) return 1; else @@ -418,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly) for (int i=0; i& elem) const */ } -void PPData::from_eval(vector& elem) const +void PPData::from_eval(vector&) const { // avoid warning - elem.empty(); throw not_implemented(); /* diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index b19dd62cf..dd3f8968d 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -17,15 +17,13 @@ PairwiseMachine::PairwiseMachine(Player& P) : { } -PairwiseMachine::PairwiseMachine(int argc, const char** argv) : - MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")), - other_pks(N.num_players(), {setup_p.params, 0}), - pk(other_pks[N.my_num()]), sk(pk) +RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) : + MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise")) { init(); } -void PairwiseMachine::init() +void RealPairwiseMachine::init() { if (use_gf2n) { @@ -63,7 +61,7 @@ PairwiseSetup& PairwiseMachine::setup() } template -void PairwiseMachine::setup_keys() +void RealPairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); @@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys() if (i != N.my_num()) other_pks[i].unpack(os[i]); set_mac_key(s.alphai); + Share::MAC_Check::setup(P); } template -void PairwiseMachine::set_mac_key(T alphai) +void RealPairwiseMachine::set_mac_key(T alphai) { typedef typename T::FD FD; auto& N = P; @@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const bundle.compare(P); } -template void PairwiseMachine::setup_keys(); -template void PairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); diff --git a/FHEOffline/PairwiseMachine.h b/FHEOffline/PairwiseMachine.h index c2283443e..a8a0c649e 100644 --- a/FHEOffline/PairwiseMachine.h +++ b/FHEOffline/PairwiseMachine.h @@ -10,7 +10,7 @@ #include "FHEOffline/SimpleMachine.h" #include "FHEOffline/PairwiseSetup.h" -class PairwiseMachine : public MachineBase +class PairwiseMachine : public virtual MachineBase { public: PairwiseSetup setup_p; @@ -23,15 +23,6 @@ class PairwiseMachine : public MachineBase vector enc_alphas; PairwiseMachine(Player& P); - PairwiseMachine(int argc, const char** argv); - - void init(); - - template - void setup_keys(); - - template - void set_mac_key(T alphai); template PairwiseSetup& setup(); @@ -42,4 +33,18 @@ class PairwiseMachine : public MachineBase void check(Player& P) const; }; +class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine +{ +public: + RealPairwiseMachine(int argc, const char** argv); + + void init(); + + template + void setup_keys(); + + template + void set_mac_key(T alphai); +}; + #endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */ diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp index b2701b2c5..be5ee2c19 100644 --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -12,7 +12,7 @@ template