Skip to content

Commit

Permalink
Functionality to call high-level code from C++.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Nov 21, 2024
1 parent e7554cc commit 91321ff
Show file tree
Hide file tree
Showing 245 changed files with 3,868 additions and 1,132 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ callgrind.out.*
Programs/Bytecode/*
Programs/Schedules/*
Programs/Public-Input/*
Programs/Functions
*.com
*.class
*.dll
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "deps/SimplestOT_C"]
path = deps/SimplestOT_C
url = https://github.com/mkskeller/SimplestOT_C
[submodule "deps/sse2neon"]
path = deps/sse2neon
url = https://github.com/DLTcollab/sse2neon
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.4.0 (November 21, 2024)

- Functionality to call high-level code from C++
- Matrix triples from file for all appropriate protocols
- Exit with message on errors instead of uncaught exceptions
- Reduce memory usage for binary memory
- Optimized cint-regint conversion in Dealer protocol
- Fixed security bug: missing MAC check in probabilistic truncation

## 0.3.9 (July 9, 2024)

- Inference with non-sequential PyTorch networks
Expand Down
2 changes: 2 additions & 0 deletions CONFIG
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ CXX = clang++
# use CONFIG.mine to overwrite DIR settings
-include CONFIG.mine

AVX_SIMPLEOT := $(AVX_OT)

ifeq ($(USE_GF2N_LONG),1)
GF2N_LONG = -DUSE_GF2N_LONG
endif
Expand Down
4 changes: 2 additions & 2 deletions Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ class split(base.Instruction):
:param: number of arguments to follow (number of bits times number of additive shares plus one)
:param: source (sint)
:param: first share of least significant bit
:param: second share of least significant bit
:param: first share of least significant bit (sbit)
:param: second share of least significant bit (sbit)
:param: (remaining share of least significant bit)...
:param: (repeat from first share for bit one step higher)...
"""
Expand Down
21 changes: 12 additions & 9 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ def get_type(cls, n):
:py:obj:`v` and the columns by calling :py:obj:`elements`.
"""
class sbitvecn(cls, _structure):
n_bits = n
@staticmethod
def get_type(n):
return cls.get_type(n)
Expand All @@ -757,17 +758,19 @@ def get_input_from(cls, player, size=1, f=0):
:param: player (int)
"""
v = [0] * n
sbits._check_input_player(player)
instructions_base.check_vector_size(size)
for i in range(size):
vv = [sbit() for i in range(n)]
inst.inputbvec(n + 3, f, player, *vv)
for j in range(n):
tmp = vv[j] << i
v[j] = tmp ^ v[j]
sbits._check_input_player(player)
return cls.from_vec(v)
if size == 1:
res = cls.from_vec(sbit() for i in range(n))
inst.inputbvec(n + 3, f, player, *res.v)
return res
else:
elements = []
for i in range(size):
v = sbits.get_type(n)()
inst.inputb(player, n, f, v)
elements.append(v)
return cls(elements)
get_raw_input_from = get_input_from
@classmethod
def from_vec(cls, vector):
Expand Down
9 changes: 7 additions & 2 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def alloc_reg(self, reg, free):
dup = dup.vectorbase
self.alloc[dup] = self.alloc[base]
dup.i = self.alloc[base]
if not dup.dup_count:
dup.dup_count = len(base.duplicates)

def dealloc_reg(self, reg, inst, free):
if reg.vector:
Expand Down Expand Up @@ -275,8 +277,9 @@ def finalize(self, options):
for reg in self.alloc:
for x in reg.get_all():
if x not in self.dealloc and reg not in self.dealloc \
and len(x.duplicates) == 0:
print('Warning: read before write at register', x)
and len(x.duplicates) == x.dup_count:
print('Warning: read before write at register %s/%x' %
(x, id(x)))
print('\tregister trace: %s' % format_trace(x.caller,
'\t\t'))
if options.stop:
Expand Down Expand Up @@ -750,6 +753,8 @@ def eliminate(i):
G.remove_node(i)
merge_nodes.discard(i)
stats[type(instructions[i]).__name__] += 1
for reg in instructions[i].get_def():
self.block.parent.program.base_addresses.pop(reg)
instructions[i] = None
if unused_result:
eliminate(i)
Expand Down
10 changes: 10 additions & 0 deletions Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@


class Compiler:
singleton = None

def __init__(self, custom_args=None, usage=None, execute=False,
split_args=False):
if Compiler.singleton:
raise CompilerError(
"Cannot have more than one compiler instance. "
"It's not possible to run direct compilation programs with "
"compile.py or compile-run.py.")
else:
Compiler.singleton = self

if usage:
self.usage = usage
else:
Expand Down
60 changes: 49 additions & 11 deletions Compiler/dijkstra.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
""" This module implements `Dijkstra's algorithm
<https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm>`_ based on
oblivious RAM. """


from Compiler.oram import *

from Compiler.program import Program
Expand Down Expand Up @@ -222,7 +227,21 @@ def dump(self, msg=''):
print_ln()
print_ln()

def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None,
debug=False):
""" Securely compute Dijstra's algorithm on a secret graph. See
:download:`../Programs/Source/dijkstra_example.mpc` for an
explanation of the required inputs.
:param source: source node (secret or clear-text integer)
:param edges: ORAM representation of edges
:param e_index: ORAM representation of vertices
:param oram_type: ORAM type to use internally (default:
:py:func:`~Compiler.oram.OptimalORAM`)
:param n_loops: when to stop (default: number of edges)
:param int_type: secret integer type (default: sint)
"""
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else -1
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
Expand Down Expand Up @@ -267,27 +286,46 @@ def f(i):
dist.access(v, (basic_type(alt), u), is_shorter)
#previous.access(v, u, is_shorter)
Q.update(v, basic_type(alt), is_shorter)
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s', \
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), \
not_visited.reveal())
if debug:
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s, '
'shorter: %s, running: %s, queue size: %s, last edge: %s',
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(),
not_visited.reveal(), is_shorter.reveal(),
running.reveal(), Q.size.reveal(), last_edge.reveal())
return dist

def convert_graph(G):
""" Convert a `NetworkX directed graph
<https://networkx.org/documentation/stable/reference/classes/digraph.html>`_
to the cleartext representation of what :py:func:`dijkstra` expects. """
G = G.copy()
for u in G:
for v in G[u]:
G[u][v].setdefault('weight', 1)
edges = [None] * (2 * G.size())
e_index = [None] * (len(G))
i = 0
for v in G:
for v in sorted(G):
e_index[v] = i
for u in G[v]:
for u in sorted(G[v]):
edges[i] = [u, G[v][u]['weight'], 0]
i += 1
if not G[v]:
edges[i] = [v, 0, 0]
i += 1
edges[i-1][-1] = 1
return edges, e_index
return list(filter(lambda x: x, edges)), e_index

def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
for u in G:
for v in G[u]:
G[u][v].setdefault('weight', 1)
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None,
int_type=sint):
""" Securely compute Dijstra's algorithm on a cleartext graph.
:param G: directed graph with NetworkX interface
:param source: source node (secret or clear-text integer)
:param n_loops: when to stop (default: number of edges)
:param int_type: secret integer type (default: sint)
"""
edges_list, e_index_list = convert_graph(G)
edges = oram_type(len(edges_list), \
entry_size=(log2(len(G)), log2(len(G)), 1), \
Expand Down
18 changes: 9 additions & 9 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ class stop(base.Instruction):
arg_format = ['i']

class use(base.Instruction):
""" Offline data usage. Necessary to avoid reusage while using
r""" Offline data usage. Necessary to avoid reusage while using
preprocessing from files. Also used to multithreading for expensive
preprocessing.
Expand All @@ -419,7 +419,7 @@ def get_usage(cls, args):
args[2].i}

class use_inp(base.Instruction):
""" Input usage. Necessary to avoid reusage while using
r""" Input usage. Necessary to avoid reusage while using
preprocessing from files.
:param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
Expand Down Expand Up @@ -1738,7 +1738,7 @@ class print_reg_plains(base.IOInstruction):
arg_format = ['s']

class cond_print_plain(base.IOInstruction):
""" Conditionally output clear register (with precision).
r""" Conditionally output clear register (with precision).
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
:param: condition (cint, no output if zero)
Expand Down Expand Up @@ -1989,7 +1989,7 @@ class closeclientconnection(base.IOInstruction):
code = base.opcodes['CLOSECLIENTCONNECTION']
arg_format = ['ci']

class writesharestofile(base.IOInstruction):
class writesharestofile(base.VectorInstruction, base.IOInstruction):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
Expand All @@ -2002,11 +2002,12 @@ class writesharestofile(base.IOInstruction):
__slots__ = []
code = base.opcodes['WRITEFILESHARE']
arg_format = tools.chain(['ci'], itertools.repeat('s'))
vector_index = 1

def has_var_args(self):
return True

class readsharesfromfile(base.IOInstruction):
class readsharesfromfile(base.VectorInstruction, base.IOInstruction):
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
:param: number of arguments to follow / number of shares plus two (int)
Expand All @@ -2018,6 +2019,7 @@ class readsharesfromfile(base.IOInstruction):
__slots__ = []
code = base.opcodes['READFILESHARE']
arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))
vector_index = 2

def has_var_args(self):
return True
Expand Down Expand Up @@ -2341,7 +2343,7 @@ class convint(base.Instruction):

@base.vectorize
class convmodp(base.Instruction):
""" Convert clear integer register (vector) to clear register
r""" Convert clear integer register (vector) to clear register
(vector). If the bit length is zero, the unsigned conversion is
used, otherwise signed conversion is used. This makes a difference
when computing modulo a prime :math:`p`. Signed conversion of
Expand Down Expand Up @@ -2814,13 +2816,11 @@ class check(base.Instruction):
@base.gf2n
@base.vectorize
class sqrs(base.CISC):
""" Secret squaring $s_i = s_j \cdot s_j$. """
r""" Secret squaring $s_i = s_j \cdot s_j$. """
__slots__ = []
arg_format = ['sw', 's']

def expand(self):
if program.options.ring:
return muls(self.args[0], self.args[1], self.args[1])
s = [program.curr_block.new_reg('s') for i in range(6)]
c = [program.curr_block.new_reg('c') for i in range(2)]
square(s[0], s[1])
Expand Down
4 changes: 3 additions & 1 deletion Compiler/instructions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,9 +1200,11 @@ def has_var_args(self):
class VectorInstruction(Instruction):
__slots__ = []
is_vec = lambda self: True
vector_index = 0

def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0]))
return super(VectorInstruction, self).get_code(
len(self.args[self.vector_index]))

class Ciscable(Instruction):
def copy(self, size, subs):
Expand Down
Loading

0 comments on commit 91321ff

Please sign in to comment.