From 6f3a0e40dac10bf7bf9215f0191c89c7d78fac9c Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 28 Nov 2024 11:43:17 +0100 Subject: [PATCH] perf(frontend/compiler): support deser keyset using path reduce memory usage by avoiding unecessary copy --- .../lib/Bindings/Python/CompilerAPIModule.cpp | 20 +++++++++++++++++++ .../concrete/fhe/compilation/keys.py | 19 ++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index c1183a8128..0d92749237 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1433,6 +1433,26 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return std::make_unique(std::move(keyset)); }, "Deserialize a Keyset from bytes.", arg("bytes")) + .def_static( + "deserialize_from_file", + [](const std::string path) { + std::ifstream ifs; + ifs.open(path); + if (!ifs.good()) { + throw std::runtime_error("Failed to open keyset file " + path); + } + + auto keysetProto = Message(); + auto maybeError = keysetProto.readBinaryFromIstream( + ifs, mlir::concretelang::python::DESER_OPTIONS); + if (maybeError.has_failure()) { + throw std::runtime_error("Failed to deserialize keyset." + + maybeError.as_failure().error().mesg); + } + auto keyset = Keyset::fromProto(keysetProto); + return std::make_unique(std::move(keyset)); + }, + "Deserialize a Keyset from a file.", arg("path")) .def( "serialize", [](Keyset &keySet) { diff --git a/frontends/concrete-python/concrete/fhe/compilation/keys.py b/frontends/concrete-python/concrete/fhe/compilation/keys.py index 72942cb4c1..76356ca7c6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/keys.py +++ b/frontends/concrete-python/concrete/fhe/compilation/keys.py @@ -133,7 +133,7 @@ def load(self, location: Union[str, Path]): message = f"Unable to load keys from {location} because it doesn't exist" raise ValueError(message) - keys = Keys.deserialize(bytes(location.read_bytes())) + keys = Keys.deserialize(location) # pylint: disable=protected-access self._specs = None @@ -185,20 +185,27 @@ def serialize(self) -> bytes: return serialized_keyset @staticmethod - def deserialize(serialized_keys: bytes) -> "Keys": + def deserialize(serialized_keys: Union[Path, bytes]) -> "Keys": """ - Deserialize keys from bytes. + Deserialize keys from file or buffer. + + Prefer using a Path instead of bytes in case of big Keys. It reduces memory usage. Args: - serialized_keys (bytes): - previously serialized keys + serialized_keys (Union[Path, bytes]): + previously serialized keys (either Path or buffer) Returns: Keys: deserialized keys """ - keyset = Keyset.deserialize(serialized_keys) + keyset = None + if isinstance(serialized_keys, Path): + keyset = Keyset.deserialize_from_file(str(serialized_keys)) + elif isinstance(serialized_keys, bytes): + keyset = Keyset.deserialize(serialized_keys) + assert keyset is not None, "serialized_keys should be either Path or bytes" # pylint: disable=protected-access result = Keys(None)