Skip to content

Commit

Permalink
perf(frontend/compiler): support deser keyset using path
Browse files Browse the repository at this point in the history
reduce memory usage by avoiding unecessary copy
  • Loading branch information
youben11 committed Nov 28, 2024
1 parent 2f41262 commit 2f0dc41
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,26 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return std::make_unique<Keyset>(std::move(keyset));
},
"Deserialize a Keyset from bytes.", arg("bytes"))
.def_static(
"deserialize_from_file",
[](const std::string path) {
std::ifstream ifs;
ifs.open(path);
if (!ifs.good()) {
throw std::runtime_error("Failed to open keyset file " + path);
}

auto keysetProto = Message<concreteprotocol::Keyset>();
auto maybeError = keysetProto.readBinaryFromIstream(
ifs, mlir::concretelang::python::DESER_OPTIONS);
if (maybeError.has_failure()) {
throw std::runtime_error("Failed to deserialize keyset." +
maybeError.as_failure().error().mesg);
}
auto keyset = Keyset::fromProto(keysetProto);
return std::make_unique<Keyset>(std::move(keyset));
},
"Deserialize a Keyset from a file.", arg("path"))
.def(
"serialize",
[](Keyset &keySet) {
Expand Down
17 changes: 11 additions & 6 deletions frontends/concrete-python/concrete/fhe/compilation/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def load(self, location: Union[str, Path]):
message = f"Unable to load keys from {location} because it doesn't exist"
raise ValueError(message)

keys = Keys.deserialize(bytes(location.read_bytes()))
keys = Keys.deserialize(location)

# pylint: disable=protected-access
self._specs = None
Expand Down Expand Up @@ -185,20 +185,25 @@ def serialize(self) -> bytes:
return serialized_keyset

@staticmethod
def deserialize(serialized_keys: bytes) -> "Keys":
def deserialize(serialized_keys: Union[Path, bytes]) -> "Keys":
"""
Deserialize keys from bytes.
Deserialize keys from file or buffer.
Args:
serialized_keys (bytes):
previously serialized keys
serialized_keys (Union[Path, bytes]):
previously serialized keys (either Path or buffer)
Returns:
Keys:
deserialized keys
"""

keyset = Keyset.deserialize(serialized_keys)
keyset = None
if isinstance(serialized_keys, Path):
keyset = Keyset.deserialize_from_file(str(serialized_keys))
elif isinstance(serialized_keys, bytes):
keyset = Keyset.deserialize(serialized_keys)
assert keyset is not None, "serialized_keys should be either Path or bytes"

# pylint: disable=protected-access
result = Keys(None)
Expand Down

0 comments on commit 2f0dc41

Please sign in to comment.