Skip to content

Commit

Permalink
Protocol in dealer model.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Apr 19, 2022
1 parent eee8865 commit 9ef15cc
Show file tree
Hide file tree
Showing 186 changed files with 2,008 additions and 618 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.3.1 (Apr 19, 2022)

- Protocol in dealer model
- Command-line option for security parameter
- Fixed security bug in SPDZ2k (see Section 3.4 of [the updated paper](https://eprint.iacr.org/2018/482))
- Ability to run high-level (Python) code from C++
- More memory capacity due to 64-bit addressing
- Homomorphic encryption for more fields of characteristic two
- Docker container

## 0.3.0 (Feb 17, 2022)

- Semi-honest computation based on threshold semi-homomorphic encryption
Expand Down
8 changes: 4 additions & 4 deletions Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
:param: memory address (int)
"""
code = opcodes['LDMSB']
arg_format = ['sbw','int']
arg_format = ['sbw','long']

class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
""" Copy secret bit register to secret bit memory cell with compile-time
Expand All @@ -315,7 +315,7 @@ class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
:param: memory address (int)
"""
code = opcodes['STMSB']
arg_format = ['sb','int']
arg_format = ['sb','long']
# def __init__(self, *args, **kwargs):
# super(type(self), self).__init__(*args, **kwargs)
# import inspect
Expand All @@ -330,7 +330,7 @@ class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
:param: memory address (int)
"""
code = opcodes['LDMCB']
arg_format = ['cbw','int']
arg_format = ['cbw','long']

class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
""" Copy clear bit register to clear bit memory cell with compile-time
Expand All @@ -340,7 +340,7 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
:param: memory address (int)
"""
code = opcodes['STMCB']
arg_format = ['cb','int']
arg_format = ['cb','long']

class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
""" Copy secret bit memory cell with run-time address to secret bit
Expand Down
15 changes: 14 additions & 1 deletion Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
fixed-length types obtained by :py:obj:`get_type(n)` are the preferred
way of using them, and in some cases required in connection with
container types.
Computation using these types will always be executed as a binary
circuit. See :ref:`protocol-pairs` for the exact protocols.
"""

from Compiler.types import MemValue, read_mem_value, regint, Array, cint
Expand All @@ -17,7 +20,6 @@
from functools import reduce

class bits(Tape.Register, _structure, _bit):
""" Base class for binary registers. """
n = 40
unit = 64
PreOp = staticmethod(floatingpoint.PreOpN)
Expand Down Expand Up @@ -400,12 +402,18 @@ def get_random_bit():
res = sbit()
inst.bitb(res)
return res
@staticmethod
def _check_input_player(player):
if not util.is_constant(player):
raise CompilerError('player must be known at compile time '
'for binary circuit inputs')
@classmethod
def get_input_from(cls, player, n_bits=None):
""" Secret input from :py:obj:`player`.
:param: player (int)
"""
cls._check_input_player(player)
if n_bits is None:
n_bits = cls.n
res = cls()
Expand Down Expand Up @@ -653,6 +661,7 @@ def get_input_from(cls, player):
:param: player (int)
"""
sbits._check_input_player(player)
res = cls.from_vec(sbit() for i in range(n))
inst.inputbvec(n + 3, 0, player, *res.v)
return res
Expand Down Expand Up @@ -780,6 +789,8 @@ def coerce(self, other):
size = other.size
return (other.get_vector(base, min(64, size - base)) \
for base in range(0, size, 64))
if not isinstance(other, type(self)):
return type(self)(other)
return other
def __xor__(self, other):
other = self.coerce(other)
Expand Down Expand Up @@ -1222,6 +1233,7 @@ def get_input_from(cls, player):
:param: player (int)
"""
sbits._check_input_player(player)
v = cls.int_type()
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
Expand Down Expand Up @@ -1287,6 +1299,7 @@ def get_input_from(cls, player):
:param: player (int)
"""
v = [sbit() for i in range(sbitfix.k)]
sbits._check_input_player(player)
inst.inputbvec(len(v) + 3, sbitfix.f, player, *v)
return cls._new(cls.int_type.from_vec(v))
def __init__(self, value=None, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
class BlockAllocator:
""" Manages freed memory blocks. """
def __init__(self):
self.by_logsize = [defaultdict(set) for i in range(32)]
self.by_logsize = [defaultdict(set) for i in range(64)]
self.by_address = {}

def by_size(self, size):
if size >= 2 ** 32:
if size >= 2 ** 64:
raise CompilerError('size exceeds addressing capability')
return self.by_logsize[int(math.log(size, 2))][size]

Expand Down
12 changes: 6 additions & 6 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
"""
__slots__ = []
code = base.opcodes['LDMC']
arg_format = ['cw','int']
arg_format = ['cw','long']

@base.gf2n
@base.vectorize
Expand All @@ -84,7 +84,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
"""
__slots__ = []
code = base.opcodes['LDMS']
arg_format = ['sw','int']
arg_format = ['sw','long']

@base.gf2n
@base.vectorize
Expand All @@ -99,7 +99,7 @@ class stmc(base.DirectMemoryWriteInstruction):
"""
__slots__ = []
code = base.opcodes['STMC']
arg_format = ['c','int']
arg_format = ['c','long']

@base.gf2n
@base.vectorize
Expand All @@ -114,7 +114,7 @@ class stms(base.DirectMemoryWriteInstruction):
"""
__slots__ = []
code = base.opcodes['STMS']
arg_format = ['s','int']
arg_format = ['s','long']

@base.vectorize
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
Expand All @@ -128,7 +128,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
"""
__slots__ = []
code = base.opcodes['LDMINT']
arg_format = ['ciw','int']
arg_format = ['ciw','long']

@base.vectorize
class stmint(base.DirectMemoryWriteInstruction):
Expand All @@ -142,7 +142,7 @@ class stmint(base.DirectMemoryWriteInstruction):
"""
__slots__ = []
code = base.opcodes['STMINT']
arg_format = ['ci','int']
arg_format = ['ci','long']

@base.vectorize
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
Expand Down
13 changes: 11 additions & 2 deletions Compiler/instructions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def reformat(arg_format):
if isinstance(arg_format, list):
__format = []
for __f in arg_format:
if __f in ('int', 'p', 'ci', 'str'):
if __f in ('int', 'long', 'p', 'ci', 'str'):
__format.append(__f)
else:
__format.append(__f[0] + 'g' + __f[1:])
Expand All @@ -360,7 +360,7 @@ class GF2N_Instruction(instruction_cls):
arg_format = instruction_cls.gf2n_arg_format
elif isinstance(instruction_cls.arg_format, itertools.repeat):
__f = next(instruction_cls.arg_format)
if __f != 'int' and __f != 'p':
if __f not in ('int', 'long', 'p'):
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
else:
arg_format = copy.deepcopy(instruction_cls.arg_format)
Expand Down Expand Up @@ -711,6 +711,14 @@ def __init__(self, f):
def __str__(self):
return str(self.i)

class LongArgFormat(IntArgFormat):
@classmethod
def encode(cls, arg):
return struct.pack('>Q', arg)

def __init__(self, f):
self.i = struct.unpack('>Q', f.read(8))[0]

class ImmediateModpAF(IntArgFormat):
@classmethod
def check(cls, arg):
Expand Down Expand Up @@ -768,6 +776,7 @@ def __str__(self):
'i': ImmediateModpAF,
'ig': ImmediateGF2NAF,
'int': IntArgFormat,
'long': LongArgFormat,
'p': PlayerNoAF,
'str': String,
}
Expand Down
13 changes: 13 additions & 0 deletions Compiler/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,7 @@ def multithread(n_threads, n_items=None, max_size=None):
:param n_threads: compile-time (int)
:param n_items: regint/cint/int (default: :py:obj:`n_threads`)
:param max_size: maximum size to be processed at once (default: no limit)
The following executes ``f(0, 8)``, ``f(8, 8)``, and
``f(16, 9)`` in three different threads:
Expand Down Expand Up @@ -1366,6 +1367,18 @@ def _(base, size):
left = (left + 1) // 2
return inputs[0]

def tree_reduce(function, sequence):
""" Round-efficient reduction. The following computes the maximum
of the list :py:obj:`l`::
m = tree_reduce(lambda x, y: x.max(y), l)
:param function: reduction function taking two arguments
:param sequence: list, vector, or array
"""
return util.tree_reduce(function, sequence)

def foreach_enumerate(a):
""" Run-time loop over public data. This uses
``Player-Data/Public-Input/<progname>``. Example:
Expand Down
7 changes: 5 additions & 2 deletions Compiler/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,14 +1079,17 @@ def __repr__(self):
(type(self).__name__, self.X.sizes, self.strides,
self.ksize, self.padding)

def _forward(self, batch):
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
def process(pool, bi, k, i, j):
def m(a, b):
c = a[0] > b[0]
l = [c * x for x in a[1]]
l += [(1 - c) * x for x in b[1]]
return c.if_else(a[0], b[0]), l
red = util.tree_reduce(m, [(x[0], [1]) for x in pool])
red = util.tree_reduce(m, [(x[0], [1] if training else [])
for x in pool])
self.Y[bi][i][j][k] = red[0]
for i, x in enumerate(red[1]):
self.comparisons[bi][k][i] = x
Expand Down
3 changes: 3 additions & 0 deletions Compiler/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None):
self.allocated_mem[mem_type] += size
if len(str(addr)) != len(str(addr + size)) and self.verbose:
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
if addr + size >= 2 ** 32:
raise CompilerError("allocation exceeded for type '%s'" %
mem_type)
self.allocated_mem_blocks[addr,mem_type] = size
if single_size:
from .library import get_thread_number, runtime_error_if
Expand Down
35 changes: 29 additions & 6 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,7 +1710,12 @@ def reveal_to(self, player):
res = Array.create_from(res)
return personal(player, res)

def bit_decompose(self, length):
def bit_decompose(self, length=None):
""" Bit decomposition.
:param length: number of bits
"""
return [personal(self.player, x) for x in self._v.bit_decompose(length)]

def _san(self, other):
Expand Down Expand Up @@ -2144,14 +2149,17 @@ class sint(_secret, _int):
the bit length.
:param val: initialization (sint/cint/regint/int/cgf2n or list
thereof or sbits/sbitvec/sfix)
thereof, sbits/sbitvec/sfix, or :py:class:`personal`)
:param size: vector size (int), defaults to 1 or size of list
When converting :py:class:`~Compiler.GC.types.sbits`, the result is a
vector of bits, and when converting
:py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values
with bit length equal the length of the input.
Initializing from a :py:class:`personal` value implies the
relevant party inputting their value securely.
"""
__slots__ = []
instruction_type = 'modp'
Expand Down Expand Up @@ -4285,6 +4293,7 @@ class sfix(_fix):
""" Secret fixed-point number represented as secret integer, by
multiplying with ``2^f`` and then rounding. See :py:class:`sint`
for security considerations of the underlying integer operations.
The secret integer is stored as the :py:obj:`v` member.
It supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
Expand Down Expand Up @@ -5121,7 +5130,8 @@ class Array(_vectorizable):
array ``a`` and ``i`` being a :py:class:`regint`,
:py:class:`cint`, or a Python integer.
:param length: compile-time integer (int) or :py:obj:`None` for unknown length
:param length: compile-time integer (int) or :py:obj:`None`
for unknown length (need to specify :py:obj:`address`)
:param value_type: basic type
:param address: if given (regint/int), the array will not be allocated
Expand Down Expand Up @@ -5178,6 +5188,8 @@ def delete(self):
self.address = None

def get_address(self, index):
if isinstance(index, (_secret, _single)):
raise CompilerError('need cleartext index')
key = str(index)
if self.length is not None:
from .GC.types import cbits
Expand Down Expand Up @@ -5211,6 +5223,7 @@ def get_slice(self, index):
if index.step == 0:
raise CompilerError('slice step cannot be zero')
return index.start or 0, \
index.stop if self.length is None else \
min(index.stop or self.length, self.length), index.step or 1

def __getitem__(self, index):
Expand Down Expand Up @@ -5517,7 +5530,15 @@ def print_reveal_nested(self, end='\n'):
:param end: string to print after (default: line break)
"""
library.print_str('%s' + end, self.get_vector().reveal())
if util.is_constant(self.length):
library.print_str('%s' + end, self.get_vector().reveal())
else:
library.print_str('[')
@library.for_range(self.length - 1)
def _(i):
library.print_str('%s, ', self[i].reveal())
library.print_str('%s', self[self.length - 1].reveal())
library.print_str(']' + end)

def reveal_to_binary_output(self, player=None):
""" Reveal to binary output if supported by type.
Expand Down Expand Up @@ -5893,7 +5914,8 @@ def dot(self, other, res_params=None, n_threads=None):
""" Matrix-matrix and matrix-vector multiplication.
:param self: two-dimensional
:param other: Matrix or Array of matching size and type """
:param other: Matrix or Array of matching size and type
:param n_threads: number of threads (default: all in same thread) """
assert len(self.sizes) == 2
if isinstance(other, Array):
assert len(other) == self.sizes[1]
Expand Down Expand Up @@ -5928,6 +5950,7 @@ def _(base, size):
res_matrix.assign_part_vector(
self.get_part(base, size).direct_mul(other), base)
except AttributeError:
assert n_threads is None
if max(res_matrix.sizes) > 1000:
raise AttributeError()
A = self.get_vector()
Expand All @@ -5937,7 +5960,7 @@ def _(base, size):
res_params))
except (AttributeError, AssertionError):
# fallback for sfloat etc.
@library.for_range_opt(self.sizes[0])
@library.for_range_opt_multithread(n_threads, self.sizes[0])
def _(i):
try:
res_matrix[i] = self.value_type.row_matrix_mul(
Expand Down
Loading

0 comments on commit 9ef15cc

Please sign in to comment.