Skip to content

Commit

Permalink
test(frontend): TFHE-rs bridge with modules
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Dec 19, 2024
1 parent ba6bf85 commit 559b0fb
Showing 1 changed file with 250 additions and 2 deletions.
252 changes: 250 additions & 2 deletions frontends/concrete-python/tests/execution/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import pytest

from concrete import fhe
import concrete.fhe as fhe
from concrete.fhe import tfhers


Expand Down Expand Up @@ -50,7 +50,7 @@ def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType:


def is_input_and_output_tfhers(
circuit: fhe.Circuit,
circuit: Union[fhe.Circuit, fhe.Module],
lwe_dim: int,
tfhers_ins: List[int],
tfhers_outs: List[int],
Expand Down Expand Up @@ -1092,6 +1092,254 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen(
os.remove(sum_ct_path)


@fhe.module()
class AddModuleOneFunc:
func_count = 1

@fhe.function({"x": "encrypted", "y": "encrypted"})
def add(x, y):
x = tfhers.to_native(x)
y = tfhers.to_native(y)
return tfhers.from_native(x + y, TFHERS_UINT_8_3_2_4096)


@fhe.module()
class AddModuleTwoFunc:
func_count = 2

@fhe.function({"x": "encrypted", "y": "encrypted"})
def add(x, y):
x = tfhers.to_native(x)
y = tfhers.to_native(y)
return tfhers.from_native(x + y, TFHERS_UINT_8_3_2_4096)

@fhe.function({"x": "encrypted"})
def inc(x):
return x + 1


@pytest.mark.parametrize(
"module, func_count, parameters, tfhers_value_range",
[
pytest.param(
AddModuleOneFunc,
1,
{
"x": {"range": [0, 2**6], "status": "encrypted"},
"y": {"range": [0, 2**6], "status": "encrypted"},
},
[0, 2**6],
id="AddModuleOneFunc",
),
pytest.param(
AddModuleTwoFunc,
2,
{
"x": {"range": [0, 2**6], "status": "encrypted"},
"y": {"range": [0, 2**6], "status": "encrypted"},
},
[0, 2**6],
id="AddModuleTwoFunc",
),
],
)
def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen_with_modules(
module, func_count, parameters, tfhers_value_range, helpers
):
"""
Test different operations wrapped by tfhers conversion (2 tfhers inputs).
Encryption/decryption are done in Rust using TFHErs, while Keygen is done in Concrete.
We use modules.
"""

# global dtype to use
dtype = TFHERS_UINT_8_3_2_4096
# global function to use
function = lambda x, y: x + y

# there is no point of using the cache here as new keys will be generated everytime
config = helpers.configuration().fork(
use_insecure_key_cache=False, insecure_key_cache_location=None
)

# Only valid when running in multi
if config.parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI:
return

inputset = [
tuple(tfhers.TFHERSInteger(dtype, arg) for arg in inpt)
for inpt in helpers.generate_inputset(parameters)
]
if func_count == 1:
add_module = module.compile({"add": inputset}, config)
else:
assert func_count == 2
add_module = module.compile({"add": inputset, "inc": [(i,) for i in range(10)]}, config)

assert is_input_and_output_tfhers(
add_module,
dtype.params.polynomial_size,
[0, 1],
[
0,
],
)

sample = helpers.generate_sample(parameters)

###### TFHErs Keygen ##########################################################
_, client_key_path = tempfile.mkstemp()
_, server_key_path = tempfile.mkstemp()
_, sk_path = tempfile.mkstemp()

tfhers_utils = (
f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils"
)

assert (
os.system(
f"{tfhers_utils} keygen -s {server_key_path} -c {client_key_path} --output-lwe-sk {sk_path}"
)
== 0
)

###### Concrete Keygen ########################################################
tfhers_bridge = tfhers.new_bridge(add_module)

with open(sk_path, "rb") as f:
sk_buff = f.read()

if func_count == 1:
# set sk for input 0 and generate the remaining keys
tfhers_bridge.keygen_with_initial_keys({0: sk_buff}, force=True)
else:
assert func_count == 2
with pytest.raises(RuntimeError, match="Module contains more than one function"):
tfhers_bridge.keygen_with_initial_keys({0: sk_buff}, force=True)
tfhers_bridge.keygen_with_initial_keys({("add", 0): sk_buff}, force=True)

###### Full Concrete Execution ################################################
concrete_encoded_sample = (dtype.encode(v) for v in sample)
concrete_encoded_result = add_module.add.encrypt_run_decrypt(*concrete_encoded_sample)
assert (dtype.decode(concrete_encoded_result) == function(*sample)).all()

###### TFHErs Encryption ######################################################

# encrypt inputs
ct1, ct2 = sample
_, ct1_path = tempfile.mkstemp()
_, ct2_path = tempfile.mkstemp()

tfhers_utils = (
f"{os.path.dirname(os.path.abspath(__file__))}/../tfhers-utils/target/release/tfhers_utils"
)
assert (
os.system(f"{tfhers_utils} encrypt-with-key --value={ct1} -c {ct1_path} --lwe-sk {sk_path}")
== 0
)
assert (
os.system(f"{tfhers_utils} encrypt-with-key --value={ct2} -c {ct2_path} --lwe-sk {sk_path}")
== 0
)

# import ciphertexts and run
cts = []
with open(ct1_path, "rb") as f:
buff = f.read()
if func_count == 1:
cts.append(tfhers_bridge.import_value(buff, 0))
else:
assert func_count == 2
with pytest.raises(RuntimeError, match="Module contains more than one function"):
cts.append(tfhers_bridge.import_value(buff, 0))
cts.append(tfhers_bridge.import_value(buff, 0, func_name="add"))
with open(ct2_path, "rb") as f:
buff = f.read()
if func_count == 1:
cts.append(tfhers_bridge.import_value(buff, 1))
else:
assert func_count == 2
with pytest.raises(RuntimeError, match="Module contains more than one function"):
cts.append(tfhers_bridge.import_value(buff, 1))
cts.append(tfhers_bridge.import_value(buff, 1, func_name="add"))
os.remove(ct1_path)
os.remove(ct2_path)

tfhers_encrypted_result = add_module.add.run(*cts)

# concrete decryption should work
decrypted = add_module.add.decrypt(tfhers_encrypted_result)
assert (dtype.decode(decrypted) == function(*sample)).all() # type: ignore

# tfhers decryption
if func_count == 1:
buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0) # type: ignore
else:
assert func_count == 2
with pytest.raises(RuntimeError, match="Module contains more than one function"):
buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0) # type: ignore
buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0, func_name="add") # type: ignore
_, ct_out_path = tempfile.mkstemp()
_, pt_path = tempfile.mkstemp()
with open(ct_out_path, "wb") as f:
f.write(buff)

assert (
os.system(
f"{tfhers_utils} decrypt-with-key" f" -c {ct_out_path} --lwe-sk {sk_path} -p {pt_path}"
)
== 0
)

with open(pt_path, "r", encoding="utf-8") as f:
result = int(f.read())
assert result == function(*sample)

###### Compute with TFHErs ####################################################
_, random_ct_path = tempfile.mkstemp()
_, sum_ct_path = tempfile.mkstemp()

# encrypt random value
random_value = np.random.randint(*tfhers_value_range)
assert (
os.system(
f"{tfhers_utils} encrypt-with-key --value={random_value} -c {random_ct_path} --client-key {client_key_path}"
)
== 0
)

# add random value to the result ct
assert (
os.system(
f"{tfhers_utils} add -c {ct_out_path} {random_ct_path} -s {server_key_path} -o {sum_ct_path}"
)
== 0
)

# decrypt result
assert (
os.system(
f"{tfhers_utils} decrypt-with-key -c {sum_ct_path} --lwe-sk {sk_path} -p {pt_path}"
)
== 0
)

with open(pt_path, "r", encoding="utf-8") as f:
tfhers_result = int(f.read())
assert result + random_value == tfhers_result

# close remaining tempfiles
os.remove(client_key_path)
os.remove(server_key_path)
os.remove(sk_path)
os.remove(ct_out_path)
os.remove(pt_path)
os.remove(random_ct_path)
os.remove(sum_ct_path)


@pytest.mark.parametrize(
"function, parameters, tfhers_value_range, dtype",
[
Expand Down

0 comments on commit 559b0fb

Please sign in to comment.