From 46c115ab35388cc083bdd0c361105a5ae200d33e Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 8 Aug 2024 11:15:16 +0300 Subject: [PATCH] feat(frontend): support configuration overrides for Server.load when via_mlir is used --- .../concrete/fhe/compilation/server.py | 8 ++++-- .../tests/compilation/test_circuit.py | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 5978d51825..47b415a3ae 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -299,7 +299,7 @@ def save(self, path: Union[str, Path], via_mlir: bool = False): shutil.make_archive(path, "zip", self._output_dir) @staticmethod - def load(path: Union[str, Path]) -> "Server": + def load(path: Union[str, Path], **kwargs) -> "Server": """ Load the server from the given path in zip format. @@ -307,6 +307,10 @@ def load(path: Union[str, Path]) -> "Server": path (Union[str, Path]): path to load the server from + kwargs (Dict[str, Any]): + configuration options to overwrite when loading a server saved with `via_mlir` + if server isn't loaded via mlir, kwargs are ignored + Returns: Server: server loaded from the filesystem @@ -343,7 +347,7 @@ def load(path: Union[str, Path]) -> "Server": mlir = f.read() with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f: - configuration = Configuration().fork(**jsonpickle.loads(f.read())) + configuration = Configuration().fork(**jsonpickle.loads(f.read())).fork(**kwargs) return Server.create( mlir, configuration, is_simulated, composition_rules=composition_rules diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index 8e6e5bc8fa..b08b22866b 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -397,6 +397,31 @@ def function(x): server.cleanup() +def test_server_loading_via_mlir_kwargs(helpers): + """ + Test server loading via MLIR with kwarg overrides. + """ + + configuration = helpers.configuration().fork(global_p_error=None, p_error=0.001) + + @fhe.compiler({"x": "encrypted", "y": "encrypted"}) + def function(x, y): + return x == y + + inputset = fhe.inputset(fhe.uint4, fhe.uint4) + circuit = function.compile(inputset, configuration) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + + server_path = tmp_dir_path / "server.zip" + circuit.server.save(server_path, via_mlir=True) + + server = Server.load(server_path, p_error=0.05) + + assert server.complexity < circuit.complexity + + def test_circuit_run_with_unused_arg(helpers): """ Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.