From 5b5dcf27b862e9220ce0d7485bff63a733841d03 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 9 Dec 2024 17:11:20 +0100 Subject: [PATCH] perf(frontend/compiler): support ser keyset using path reduce memory usage by avoiding unecessary copy --- .../lib/Bindings/Python/CompilerAPIModule.cpp | 15 +++++++++++++ .../concrete/fhe/compilation/keys.py | 21 ++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 3fe1c05b1..d69942796 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1675,6 +1675,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return pybind11::bytes(keySetSerialize(keySet)); }, "Serialize a Keyset to bytes.") + .def( + "serialize_to_file", + [](Keyset &keySet, const std::string path) { + std::ofstream ofs; + ofs.open(path); + if (!ofs.good()) { + throw std::runtime_error("Failed to open keyset file " + path); + } + auto keysetProto = keySet.toProto(); + auto maybeBuffer = keysetProto.writeBinaryToOstream(ofs); + if (maybeBuffer.has_failure()) { + throw std::runtime_error("Failed to serialize keys."); + } + }, + "Serialize a Keyset to bytes.") .def( "serialize_lwe_secret_key_as_glwe", [](Keyset &keyset, size_t keyIndex, size_t glwe_dimension, diff --git a/frontends/concrete-python/concrete/fhe/compilation/keys.py b/frontends/concrete-python/concrete/fhe/compilation/keys.py index 76356ca7c..9ef2a698c 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/keys.py +++ b/frontends/concrete-python/concrete/fhe/compilation/keys.py @@ -115,7 +115,7 @@ def save(self, location: Union[str, Path]): message = f"Unable to save keys to {location} because it already exists" raise ValueError(message) - location.write_bytes(self.serialize()) + self.serialize_to_file(location) def load(self, location: Union[str, Path]): """ @@ -171,6 +171,8 @@ def serialize(self) -> bytes: Serialize keys into bytes. Serialized keys are not encrypted, so be careful how you store/transfer them! + `serialize_to_file` is supposed to be more performant as it avoid copying the buffer + between the Compiler and the Frontend. Returns: bytes: @@ -184,6 +186,23 @@ def serialize(self) -> bytes: serialized_keyset = self._keyset.serialize() return serialized_keyset + def serialize_to_file(self, path: Path): + """ + Serialize keys into a file. + + Serialized keys are not encrypted, so be careful how you store/transfer them! + This is supposed to be more performant than `serialize` as it avoid copying the buffer + between the Compiler and the Frontend. + + Args: + path (Path): where to save serialized keys + """ + if self._keyset is None: + message = "Keys cannot be serialized before they are generated" + raise RuntimeError(message) + + self._keyset.serialize_to_file(str(path)) + @staticmethod def deserialize(serialized_keys: Union[Path, bytes]) -> "Keys": """