From 11bf8d90664d1e1958dff25f47128ec272414522 Mon Sep 17 00:00:00 2001 From: Bourgerie Quentin Date: Wed, 25 Sep 2024 19:09:21 +0200 Subject: [PATCH] fix(frontend-python): Fixing default circuit name in tfhe-rs bridge --- .../Python/concrete/compiler/client_parameters.py | 2 +- frontends/concrete-python/concrete/fhe/tfhers/bridge.py | 2 +- frontends/concrete-python/examples/tfhers/example.py | 2 +- frontends/concrete-python/tests/execution/test_tfhers.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py index 87dbf4214e..bf2fa1659a 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/client_parameters.py @@ -54,7 +54,7 @@ def lwe_secret_key_param(self, key_id: int) -> LweSecretKeyParam: raise TypeError(f"key_id must be of type int, not {type(key_id)}") return LweSecretKeyParam.wrap(self.cpp().lwe_secret_key_param(key_id)) - def input_keyid_at(self, input_idx: int, circuit_name: str = "main") -> int: + def input_keyid_at(self, input_idx: int, circuit_name: str = "") -> int: """Get the keyid of a selected encrypted input in a given circuit. Args: diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index b22dc18505..73f530aaad 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -232,7 +232,7 @@ def keygen_with_initial_keys( def new_bridge( circuit: "fhe.Circuit", - func_name: str = "main", + func_name: str = "", ) -> Bridge: """Create a TFHErs bridge from a circuit. diff --git a/frontends/concrete-python/examples/tfhers/example.py b/frontends/concrete-python/examples/tfhers/example.py index 2d5eb6b664..a808d4d350 100644 --- a/frontends/concrete-python/examples/tfhers/example.py +++ b/frontends/concrete-python/examples/tfhers/example.py @@ -75,7 +75,7 @@ def ccompilee(): inputset = [(tfhers_int(120), tfhers_int(120))] circuit = compiler.compile(inputset) - tfhers_bridge = tfhers.new_bridge(circuit=circuit, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit=circuit) return circuit, tfhers_bridge diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index 1fdc01875c..8589cdb724 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -382,7 +382,7 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( assert (dtype.decode(concrete_encoded_result) == function(*sample)).all() ###### TFHErs Encryption & Computation ######################################## - tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit) # serialize key _, key_path = tempfile.mkstemp() @@ -617,7 +617,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_concrete_keygen( assert (dtype.decode(concrete_encoded_result) == function(*sample)).all() ###### TFHErs Encryption ###################################################### - tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit) # serialize key _, key_path = tempfile.mkstemp() @@ -780,7 +780,7 @@ def test_tfhers_binary_encrypted_complete_circuit_tfhers_keygen( ) ###### Concrete Keygen ######################################################## - tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit) with open(sk_path, "rb") as f: sk_buff = f.read() @@ -1044,7 +1044,7 @@ def test_tfhers_one_tfhers_one_native_complete_circuit_tfhers_keygen( ) ###### Concrete Keygen ######################################################## - tfhers_bridge = tfhers.new_bridge(circuit, func_name="main") + tfhers_bridge = tfhers.new_bridge(circuit) with open(sk_path, "rb") as f: sk_buff = f.read()