Skip to content

Commit

Permalink
refactor(compiler): remove TfhersFheIntDescription wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Nov 7, 2024
1 parent 05585a3 commit 8c494a0
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 245 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -243,77 +243,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
m.def("check_gpu_runtime_enabled", &checkGPURuntimeEnabled);
m.def("check_cuda_device_available", &checkCudaDeviceAvailable);

pybind11::class_<TfhersFheIntDescription>(m, "TfhersFheIntDescription")
.def(pybind11::init([](size_t width, bool is_signed,
size_t message_modulus, size_t carry_modulus,
size_t degree, size_t lwe_size, size_t n_cts,
size_t noise_level, bool ks_first) {
auto desc = TfhersFheIntDescription();
desc.width = width;
desc.is_signed = is_signed;
desc.message_modulus = message_modulus;
desc.carry_modulus = carry_modulus;
desc.degree = degree;
desc.lwe_size = lwe_size;
desc.n_cts = n_cts;
desc.noise_level = noise_level;
desc.ks_first = ks_first;
return desc;
}))
.def_static("UNKNOWN_NOISE_LEVEL",
[] { return concrete_cpu_tfhers_unknown_noise_level(); })
.def_property(
"width", [](TfhersFheIntDescription &desc) { return desc.width; },
[](TfhersFheIntDescription &desc, size_t width) {
desc.width = width;
})
.def_property(
"message_modulus",
[](TfhersFheIntDescription &desc) { return desc.message_modulus; },
[](TfhersFheIntDescription &desc, size_t message_modulus) {
desc.message_modulus = message_modulus;
})
.def_property(
"carry_modulus",
[](TfhersFheIntDescription &desc) { return desc.carry_modulus; },
[](TfhersFheIntDescription &desc, size_t carry_modulus) {
desc.carry_modulus = carry_modulus;
})
.def_property(
"degree", [](TfhersFheIntDescription &desc) { return desc.degree; },
[](TfhersFheIntDescription &desc, size_t degree) {
desc.degree = degree;
})
.def_property(
"lwe_size",
[](TfhersFheIntDescription &desc) { return desc.lwe_size; },
[](TfhersFheIntDescription &desc, size_t lwe_size) {
desc.lwe_size = lwe_size;
})
.def_property(
"n_cts", [](TfhersFheIntDescription &desc) { return desc.n_cts; },
[](TfhersFheIntDescription &desc, size_t n_cts) {
desc.n_cts = n_cts;
})
.def_property(
"noise_level",
[](TfhersFheIntDescription &desc) { return desc.noise_level; },
[](TfhersFheIntDescription &desc, size_t noise_level) {
desc.noise_level = noise_level;
})
.def_property(
"is_signed",
[](TfhersFheIntDescription &desc) { return desc.is_signed; },
[](TfhersFheIntDescription &desc, bool is_signed) {
desc.is_signed = is_signed;
})
.def_property(
"ks_first",
[](TfhersFheIntDescription &desc) { return desc.ks_first; },
[](TfhersFheIntDescription &desc, bool ks_first) {
desc.ks_first = ks_first;
});

pybind11::enum_<mlir::concretelang::Backend>(m, "Backend")
.value("CPU", mlir::concretelang::Backend::CPU,
"Circuit codegen targets cpu.")
Expand Down Expand Up @@ -1993,6 +1922,109 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"Return the `circuit` ClientCircuit.", arg("circuit"))
.doc() = "Client-side / Encryption program";

// ------------------------------------------------------------------------------//
// TFHERS INTEGER DESCRIPTION //
// ------------------------------------------------------------------------------//

pybind11::class_<TfhersFheIntDescription>(m, "TfhersFheIntDescription")
.def(pybind11::init([](size_t width, bool is_signed,
size_t message_modulus, size_t carry_modulus,
size_t degree, size_t lwe_size, size_t n_cts,
size_t noise_level, bool ks_first) {
auto desc = TfhersFheIntDescription();
desc.width = width;
desc.is_signed = is_signed;
desc.message_modulus = message_modulus;
desc.carry_modulus = carry_modulus;
desc.degree = degree;
desc.lwe_size = lwe_size;
desc.n_cts = n_cts;
desc.noise_level = noise_level;
desc.ks_first = ks_first;
return desc;
}),
arg("width"), arg("is_signed"), arg("lwe_size"), arg("n_cts"),
arg("degree"), arg("noise_level"), arg("message_modulus"),
arg("carry_modulus"), arg("ks_first"))
.def_static("get_unknown_noise_level",
[] { return concrete_cpu_tfhers_unknown_noise_level(); })
.def_property(
"width", [](TfhersFheIntDescription &desc) { return desc.width; },
[](TfhersFheIntDescription &desc, size_t width) {
desc.width = width;
})
.def_property(
"message_modulus",
[](TfhersFheIntDescription &desc) { return desc.message_modulus; },
[](TfhersFheIntDescription &desc, size_t message_modulus) {
desc.message_modulus = message_modulus;
})
.def_property(
"carry_modulus",
[](TfhersFheIntDescription &desc) { return desc.carry_modulus; },
[](TfhersFheIntDescription &desc, size_t carry_modulus) {
desc.carry_modulus = carry_modulus;
})
.def_property(
"degree", [](TfhersFheIntDescription &desc) { return desc.degree; },
[](TfhersFheIntDescription &desc, size_t degree) {
desc.degree = degree;
})
.def_property(
"lwe_size",
[](TfhersFheIntDescription &desc) { return desc.lwe_size; },
[](TfhersFheIntDescription &desc, size_t lwe_size) {
desc.lwe_size = lwe_size;
})
.def_property(
"n_cts", [](TfhersFheIntDescription &desc) { return desc.n_cts; },
[](TfhersFheIntDescription &desc, size_t n_cts) {
desc.n_cts = n_cts;
})
.def_property(
"noise_level",
[](TfhersFheIntDescription &desc) { return desc.noise_level; },
[](TfhersFheIntDescription &desc, size_t noise_level) {
desc.noise_level = noise_level;
})
.def_property(
"is_signed",
[](TfhersFheIntDescription &desc) { return desc.is_signed; },
[](TfhersFheIntDescription &desc, bool is_signed) {
desc.is_signed = is_signed;
})
.def_property(
"ks_first",
[](TfhersFheIntDescription &desc) { return desc.ks_first; },
[](TfhersFheIntDescription &desc, bool ks_first) {
desc.ks_first = ks_first;
})
.def("__str__",
[](TfhersFheIntDescription &desc) {
std::ostringstream stringStream;
stringStream << "tfhers_int_description<width=";
stringStream << desc.width;
stringStream << ", signed=";
stringStream << desc.is_signed;
stringStream << ", msg_mod=";
stringStream << desc.message_modulus;
stringStream << ", carry_mod=";
stringStream << desc.carry_modulus;
stringStream << ", degree=";
stringStream << desc.degree;
stringStream << ", lwe_size=";
stringStream << desc.lwe_size;
stringStream << ", n_cts=";
stringStream << desc.n_cts;
stringStream << ", noise_level=";
stringStream << desc.noise_level;
stringStream << ", ks_first=";
stringStream << desc.ks_first;
stringStream << ">";
return stringStream.str();
})
.doc() = "TFHE-rs integer description";

m.def("import_tfhers_int",
[](const pybind11::bytes &serialized_fheuint,
TfhersFheIntDescription info, uint32_t encryptionKeyId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ServerKeyset,
Keyset,
Compiler,
TfhersFheIntDescription,
TransportValue,
Value,
ServerProgram,
Expand Down Expand Up @@ -47,10 +48,7 @@
from .compilation_feedback import MoreCircuitCompilationFeedback
from .compilation_context import CompilationContext

from .tfhers_int import (
TfhersExporter,
TfhersFheIntDescription,
)
from .tfhers_int import TfhersExporter

Parameter = Union[
LweSecretKeyParam, BootstrapKeyParam, KeyswitchKeyParam, PackingKeyswitchKeyParam
Expand Down
Loading

0 comments on commit 8c494a0

Please sign in to comment.