Skip to content

Commit

Permalink
feat(frontend): support configuration overrides for Server.load when …
Browse files Browse the repository at this point in the history
…via_mlir is used
  • Loading branch information
umut-sahin authored and BourgerieQuentin committed Aug 8, 2024
1 parent 79fdb8d commit 46c115a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
8 changes: 6 additions & 2 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,18 @@ def save(self, path: Union[str, Path], via_mlir: bool = False):
shutil.make_archive(path, "zip", self._output_dir)

@staticmethod
def load(path: Union[str, Path]) -> "Server":
def load(path: Union[str, Path], **kwargs) -> "Server":
"""
Load the server from the given path in zip format.
Args:
path (Union[str, Path]):
path to load the server from
kwargs (Dict[str, Any]):
configuration options to overwrite when loading a server saved with `via_mlir`
if server isn't loaded via mlir, kwargs are ignored
Returns:
Server:
server loaded from the filesystem
Expand Down Expand Up @@ -343,7 +347,7 @@ def load(path: Union[str, Path]) -> "Server":
mlir = f.read()

with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**jsonpickle.loads(f.read()))
configuration = Configuration().fork(**jsonpickle.loads(f.read())).fork(**kwargs)

return Server.create(
mlir, configuration, is_simulated, composition_rules=composition_rules
Expand Down
25 changes: 25 additions & 0 deletions frontends/concrete-python/tests/compilation/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,31 @@ def function(x):
server.cleanup()


def test_server_loading_via_mlir_kwargs(helpers):
"""
Test server loading via MLIR with kwarg overrides.
"""

configuration = helpers.configuration().fork(global_p_error=None, p_error=0.001)

@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def function(x, y):
return x == y

inputset = fhe.inputset(fhe.uint4, fhe.uint4)
circuit = function.compile(inputset, configuration)

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)

server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path, via_mlir=True)

server = Server.load(server_path, p_error=0.05)

assert server.complexity < circuit.complexity


def test_circuit_run_with_unused_arg(helpers):
"""
Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.
Expand Down

0 comments on commit 46c115a

Please sign in to comment.