diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 2dfcfd7f7e..a5d5910ef6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -39,6 +39,10 @@ def __init__(self, module: FheModule): def _function(self) -> FheFunction: return getattr(self._module, self._name) + @property + def function_name(self) -> str: + return self._name + def __str__(self): return self._function.graph.format() @@ -148,7 +152,10 @@ def keygen( initial keys to set before keygen """ self._module.keygen( - force=force, seed=seed, encryption_seed=encryption_seed, initial_keys=initial_keys + force=force, + seed=seed, + encryption_seed=encryption_seed, + initial_keys=initial_keys, ) def encrypt( diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index 73f530aaad..af8dd1701f 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -23,19 +23,16 @@ class Bridge: circuit: "fhe.Circuit" input_types: List[Optional[TFHERSIntegerType]] output_types: List[Optional[TFHERSIntegerType]] - func_name: str def __init__( self, circuit: "fhe.Circuit", input_types: List[Optional[TFHERSIntegerType]], output_types: List[Optional[TFHERSIntegerType]], - func_name: str, ): self.circuit = circuit self.input_types = input_types self.output_types = output_types - self.func_name = func_name def _input_type(self, input_idx: int) -> Optional[TFHERSIntegerType]: """Return the type of a certain input. @@ -60,7 +57,9 @@ def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]: return self.output_types[output_idx] def _input_keyid(self, input_idx: int) -> int: - return self.circuit.client.specs.client_parameters.input_keyid_at(input_idx, self.func_name) + return self.circuit.client.specs.client_parameters.input_keyid_at( + input_idx, self.circuit.function_name + ) def _input_variance(self, input_idx: int) -> float: input_type = self._input_type(input_idx) @@ -230,15 +229,11 @@ def keygen_with_initial_keys( ) -def new_bridge( - circuit: "fhe.Circuit", - func_name: str = "", -) -> Bridge: +def new_bridge(circuit: "fhe.Circuit") -> Bridge: """Create a TFHErs bridge from a circuit. Args: circuit (Circuit): compiled circuit - func_name (str, optional): name of the function to use. Defaults to "main". Returns: Bridge: TFHErs bridge @@ -260,4 +255,4 @@ def new_bridge( for output_node in circuit.graph.ordered_outputs() ] - return Bridge(circuit, input_types, output_types, func_name) + return Bridge(circuit, input_types, output_types)