diff --git a/backends/concrete-cpu/implementation/Cargo.lock b/backends/concrete-cpu/implementation/Cargo.lock index dee61949b3..247a1f2b10 100644 --- a/backends/concrete-cpu/implementation/Cargo.lock +++ b/backends/concrete-cpu/implementation/Cargo.lock @@ -223,7 +223,8 @@ dependencies = [ [[package]] name = "concrete-csprng" version = "0.4.1" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90518357249582c16a6b64d7410243dfb3109d5bf0ad1665c058c9a59f2fc4cc" dependencies = [ "aes", "libc", @@ -864,8 +865,9 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "tfhe" -version = "0.8.0" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e2648b0df14216576ea543bb9021beed019d7eb43fedacdbc24a0c095d33d2a" dependencies = [ "aligned-vec", "bincode", @@ -886,8 +888,9 @@ dependencies = [ [[package]] name = "tfhe-versionable" -version = "0.3.2" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feeb340d850c65660b321e5379a28b2f3b226c61163de0a12766aedfe8575a29" dependencies = [ "aligned-vec", "num-complex", @@ -897,8 +900,9 @@ dependencies = [ [[package]] name = "tfhe-versionable-derive" -version = "0.3.2" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d985f9645ed62be4aefb9c06ec70563291ec475036ebcd2cf95c5429a12e8a" dependencies = [ "proc-macro2", "quote", diff --git a/backends/concrete-cpu/implementation/Cargo.toml b/backends/concrete-cpu/implementation/Cargo.toml index 686674dd65..55be115d27 100644 --- a/backends/concrete-cpu/implementation/Cargo.toml +++ b/backends/concrete-cpu/implementation/Cargo.toml @@ -10,9 +10,7 @@ crate-type = ["lib", "staticlib"] [dependencies] -concrete-csprng = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", optional = true, features = [ - "generator_fallback", -] } +concrete-csprng = { version = "0.4.1", optional = true, features = ["generator_fallback"] } concrete-cpu-noise-model = { path = "../noise-model/" } concrete-security-curves = { path = "../../../tools/parameter-curves/concrete-security-curves-rust" } libc = { version = "0.2", default-features = false } @@ -31,16 +29,16 @@ serde = "~1" rayon = { version = "1.6", optional = true } once_cell = { version = "1.16", optional = true } -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer"] } +tfhe = { version = "0.8.6", features = ["integer"] } [target.x86_64-unknown-unix-gnu.dependencies] -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "x86_64-unix"] } +tfhe = { version = "0.8.6", features = ["integer", "x86_64-unix"] } [target.aarch64-unknown-unix-gnu.dependencies] -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "aarch64-unix"] } +tfhe = { version = "0.8.6", features = ["integer", "aarch64-unix"] } [target.x86_64-pc-windows-gnu.dependencies] -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "x86_64"] } +tfhe = { version = "0.8.6", features = ["integer", "x86_64"] } [features] default = ["parallel", "std", "csprng"] diff --git a/backends/concrete-cpu/implementation/include/concrete-cpu.h b/backends/concrete-cpu/implementation/include/concrete-cpu.h index 4d9cc35d6a..2f51dbd84b 100644 --- a/backends/concrete-cpu/implementation/include/concrete-cpu.h +++ b/backends/concrete-cpu/implementation/include/concrete-cpu.h @@ -372,12 +372,14 @@ void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out, size_t concrete_cpu_lwe_array_to_tfhers_int8(const uint64_t *lwe_vec_buffer, uint8_t *buffer, size_t buffer_len, - struct TfhersFheIntDescription fheint_desc); + size_t n_elem, + struct TfhersFheIntDescription desc); size_t concrete_cpu_lwe_array_to_tfhers_uint8(const uint64_t *lwe_vec_buffer, uint8_t *buffer, size_t buffer_len, - struct TfhersFheIntDescription fheuint_desc); + size_t n_elem, + struct TfhersFheIntDescription desc); size_t concrete_cpu_lwe_ciphertext_size_u64(size_t lwe_dimension); @@ -418,16 +420,18 @@ size_t concrete_cpu_serialize_lwe_secret_key_u64(const uint64_t *lwe_sk, uint8_t *out_buffer, size_t out_buffer_len); -size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts); +size_t concrete_cpu_tfhers_fheint_buffer_size_u64(size_t lwe_size, size_t n_cts, size_t n_elem); -int64_t concrete_cpu_tfhers_int8_to_lwe_array(const uint8_t *serialized_data_ptr, - size_t serialized_data_len, +int64_t concrete_cpu_tfhers_int8_to_lwe_array(const uint8_t *buffer, + size_t buffer_len, uint64_t *lwe_vec_buffer, + size_t n_elem, struct TfhersFheIntDescription desc); int64_t concrete_cpu_tfhers_uint8_to_lwe_array(const uint8_t *buffer, size_t buffer_len, uint64_t *lwe_vec_buffer, + size_t n_elem, struct TfhersFheIntDescription desc); size_t concrete_cpu_tfhers_unknown_noise_level(void); diff --git a/backends/concrete-cpu/implementation/src/c_api/fheint.rs b/backends/concrete-cpu/implementation/src/c_api/fheint.rs index 3eecab1e98..eea2040513 100644 --- a/backends/concrete-cpu/implementation/src/c_api/fheint.rs +++ b/backends/concrete-cpu/implementation/src/c_api/fheint.rs @@ -150,8 +150,7 @@ pub fn tfhers_int8_description(fheuint: FheInt8) -> TfhersFheIntDescription { } } -#[no_mangle] -pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( +unsafe fn tfhers_uint8_to_lwe_array( buffer: *const u8, buffer_len: usize, lwe_vec_buffer: *mut u64, @@ -165,15 +164,11 @@ pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( return 1; } + let lwe_size = desc.lwe_size; + let n_cts = desc.n_cts; // collect LWEs from fheuint let (radix, _, _) = fheuint.into_raw_parts(); let blocks = radix.blocks(); - let first_ct = match blocks.first() { - Some(value) => &value.ct, - None => return 1, - }; - let lwe_size = first_ct.lwe_size().0; - let n_cts = blocks.len(); // copy LWEs to C buffer. Note that lsb is cts[0] let lwe_vector: &mut [u64] = slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size); for (i, block) in blocks.iter().enumerate() { @@ -184,31 +179,76 @@ pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( }) } +unsafe fn tfhers_uint8_array_to_lwe_array( + buffer: *const u8, + buffer_len: usize, + mut lwe_vec_buffer: *mut u64, + desc: TfhersFheIntDescription, +) -> i64 { + nounwind(|| { + let fheuint_array: Vec = super::utils::unsafe_deserialize(buffer, buffer_len); + // TODO - Use conformance check + let fheuint_desc = tfhers_uint8_description(fheuint_array[0].clone()); + if !fheuint_desc.is_similar(&desc) { + return 1; + } + + let lwe_size: usize = desc.lwe_size; + let n_cts: usize = desc.n_cts; + let blocks_size = n_cts * lwe_size; + // collect LWEs from fheuint + for fheuint in fheuint_array { + let (radix, _, _) = fheuint.into_raw_parts(); + let blocks = radix.blocks(); + // copy LWEs to C buffer. Note that lsb is cts[0] + let lwe_vector: &mut [u64] = + slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size); + for (i, block) in blocks.iter().enumerate() { + lwe_vector[i * lwe_size..(i + 1) * lwe_size] + .copy_from_slice(block.ct.clone().into_container().as_slice()); + } + // shift to next block + lwe_vec_buffer = lwe_vec_buffer.add(blocks_size); + } + 0 + }) +} + #[no_mangle] -pub unsafe extern "C" fn concrete_cpu_tfhers_int8_to_lwe_array( - serialized_data_ptr: *const u8, - serialized_data_len: usize, +pub unsafe extern "C" fn concrete_cpu_tfhers_uint8_to_lwe_array( + buffer: *const u8, + buffer_len: usize, + lwe_vec_buffer: *mut u64, + n_elem: usize, + desc: TfhersFheIntDescription, +) -> i64 { + assert!(n_elem > 0); + if n_elem == 1 { + tfhers_uint8_to_lwe_array(buffer, buffer_len, lwe_vec_buffer, desc) + } else { + tfhers_uint8_array_to_lwe_array(buffer, buffer_len, lwe_vec_buffer, desc) + } +} + +unsafe fn tfhers_int8_to_lwe_array( + buffer: *const u8, + buffer_len: usize, lwe_vec_buffer: *mut u64, desc: TfhersFheIntDescription, ) -> i64 { nounwind(|| { - let fheint: FheInt8 = - super::utils::safe_deserialize(serialized_data_ptr, serialized_data_len); + let fheint: FheInt8 = super::utils::safe_deserialize(buffer, buffer_len); // TODO - Use conformance check let fheint_desc = tfhers_int8_description(fheint.clone()); if !fheint_desc.is_similar(&desc) { return 1; } - // collect LWEs from fheuint + let lwe_size = desc.lwe_size; + let n_cts = desc.n_cts; + // collect LWEs from fheint let (radix, _, _) = fheint.into_raw_parts(); let blocks = radix.blocks(); - let first_ct = match blocks.first() { - Some(value) => &value.ct, - None => return 1, - }; - let lwe_size = first_ct.lwe_size().0; - let n_cts = blocks.len(); // copy LWEs to C buffer. Note that lsb is cts[0] let lwe_vector: &mut [u64] = slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size); for (i, block) in blocks.iter().enumerate() { @@ -219,41 +259,98 @@ pub unsafe extern "C" fn concrete_cpu_tfhers_int8_to_lwe_array( }) } +unsafe fn tfhers_int8_array_to_lwe_array( + buffer: *const u8, + buffer_len: usize, + mut lwe_vec_buffer: *mut u64, + desc: TfhersFheIntDescription, +) -> i64 { + nounwind(|| { + let fheint_array: Vec = super::utils::unsafe_deserialize(buffer, buffer_len); + // TODO - Use conformance check + let fheint_desc = tfhers_int8_description(fheint_array[0].clone()); + if !fheint_desc.is_similar(&desc) { + return 1; + } + + let lwe_size: usize = desc.lwe_size; + let n_cts: usize = desc.n_cts; + let blocks_size = n_cts * lwe_size; + // collect LWEs from fheint + for fheint in fheint_array { + let (radix, _, _) = fheint.into_raw_parts(); + let blocks = radix.blocks(); + // copy LWEs to C buffer. Note that lsb is cts[0] + let lwe_vector: &mut [u64] = + slice::from_raw_parts_mut(lwe_vec_buffer, n_cts * lwe_size); + for (i, block) in blocks.iter().enumerate() { + lwe_vector[i * lwe_size..(i + 1) * lwe_size] + .copy_from_slice(block.ct.clone().into_container().as_slice()); + } + // shift to next block + lwe_vec_buffer = lwe_vec_buffer.add(blocks_size); + } + 0 + }) +} + +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_tfhers_int8_to_lwe_array( + buffer: *const u8, + buffer_len: usize, + lwe_vec_buffer: *mut u64, + n_elem: usize, + desc: TfhersFheIntDescription, +) -> i64 { + assert!(n_elem > 0); + if n_elem == 1 { + tfhers_int8_to_lwe_array(buffer, buffer_len, lwe_vec_buffer, desc) + } else { + tfhers_int8_array_to_lwe_array(buffer, buffer_len, lwe_vec_buffer, desc) + } +} + #[no_mangle] pub extern "C" fn concrete_cpu_tfhers_fheint_buffer_size_u64( lwe_size: usize, n_cts: usize, + n_elem: usize, ) -> usize { // TODO - that is fragile // all FheUint should have the same size, but we use a big one to be safe let meta_fheuint = core::mem::size_of::(); let meta_ct = core::mem::size_of::(); - // FheUint[metadata, ciphertexts[ciphertext[metadata, lwe_buffer] * n_cts]] + headers - (meta_fheuint + (meta_ct + lwe_size * 8/*u64*/) * n_cts) + 201 + if n_elem <= 1 { + // FheUint[metadata, ciphertexts[ciphertext[metadata, lwe_buffer] * n_cts]] + headers + (meta_fheuint + (meta_ct + lwe_size * 8/*u64*/) * n_cts) + 201 + } else { + let meta_vec: usize = core::mem::size_of::>(); + // Vec[FheUint[metadata, ciphertexts[ciphertext[metadata, lwe_buffer] * n_cts]] + headers] * n_elem + meta_vec + ((meta_fheuint + (meta_ct + lwe_size * 8/*u64*/) * n_cts) + 201) * n_elem + } } -#[no_mangle] -pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( +unsafe fn lwe_array_to_tfhers_uint8( lwe_vec_buffer: *const u64, buffer: *mut u8, buffer_len: usize, - fheuint_desc: TfhersFheIntDescription, + desc: TfhersFheIntDescription, ) -> usize { nounwind(|| { // we want to trigger a PBS on TFHErs side assert!( - fheuint_desc.noise_level == NoiseLevel::UNKNOWN.get(), + desc.noise_level == NoiseLevel::UNKNOWN.get(), "noise_level must be unknown" ); // we want to use the max degree as we don't track it on Concrete side assert!( - fheuint_desc.degree == fheuint_desc.message_modulus - 1, + desc.degree == desc.message_modulus - 1, "degree must be the max value (msg_modulus - 1)" ); - let lwe_size = fheuint_desc.lwe_size; - let n_cts = fheuint_desc.n_cts; + let lwe_size = desc.lwe_size; + let n_cts = desc.n_cts; // construct fheuint from LWEs let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size); let mut blocks: Vec = Vec::with_capacity(n_cts); @@ -262,9 +359,9 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(), CiphertextModulus::new_native(), ); - blocks.push(fheuint_desc.ct_from_lwe(lwe_ct)); + blocks.push(desc.ct_from_lwe(lwe_ct)); } - let fheuint = match FheUint8::from_expanded_blocks(blocks, fheuint_desc.data_kind()) { + let fheuint = match FheUint8::from_expanded_blocks(blocks, desc.data_kind()) { Ok(value) => value, Err(_e) => { return 0; @@ -274,28 +371,90 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( }) } -#[no_mangle] -pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_int8( +unsafe fn lwe_array_to_tfhers_uint8_array( lwe_vec_buffer: *const u64, buffer: *mut u8, buffer_len: usize, - fheint_desc: TfhersFheIntDescription, + n_elem: usize, + desc: TfhersFheIntDescription, ) -> usize { nounwind(|| { // we want to trigger a PBS on TFHErs side assert!( - fheint_desc.noise_level == NoiseLevel::UNKNOWN.get(), + desc.noise_level == NoiseLevel::UNKNOWN.get(), "noise_level must be unknown" ); // we want to use the max degree as we don't track it on Concrete side assert!( - fheint_desc.degree == fheint_desc.message_modulus - 1, + desc.degree == desc.message_modulus - 1, "degree must be the max value (msg_modulus - 1)" ); - let lwe_size = fheint_desc.lwe_size; - let n_cts = fheint_desc.n_cts; + let lwe_size = desc.lwe_size; + let n_cts = desc.n_cts; // construct fheuint from LWEs + let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_elem * n_cts * lwe_size); + let mut fheuint_array: Vec = Vec::with_capacity(n_elem); + for i in 0..n_elem { + let elem_offset = i * n_cts * lwe_size; + let mut blocks: Vec = Vec::with_capacity(n_cts); + for j in 0..n_cts { + let lwe_ct = LweCiphertext::>::from_container( + lwe_vector[elem_offset + j * lwe_size..elem_offset + (j + 1) * lwe_size] + .to_vec(), + CiphertextModulus::new_native(), + ); + blocks.push(desc.ct_from_lwe(lwe_ct)); + } + let fheuint = match FheUint8::from_expanded_blocks(blocks, desc.data_kind()) { + Ok(value) => value, + Err(_) => { + return 0; + } + }; + fheuint_array.push(fheuint); + } + super::utils::unsafe_serialize(&fheuint_array, buffer, buffer_len) + }) +} + +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_uint8( + lwe_vec_buffer: *const u64, + buffer: *mut u8, + buffer_len: usize, + n_elem: usize, + desc: TfhersFheIntDescription, +) -> usize { + assert!(n_elem > 0); + if n_elem == 1 { + lwe_array_to_tfhers_uint8(lwe_vec_buffer, buffer, buffer_len, desc) + } else { + lwe_array_to_tfhers_uint8_array(lwe_vec_buffer, buffer, buffer_len, n_elem, desc) + } +} + +unsafe fn lwe_array_to_tfhers_int8( + lwe_vec_buffer: *const u64, + buffer: *mut u8, + buffer_len: usize, + desc: TfhersFheIntDescription, +) -> usize { + nounwind(|| { + // we want to trigger a PBS on TFHErs side + assert!( + desc.noise_level == NoiseLevel::UNKNOWN.get(), + "noise_level must be unknown" + ); + // we want to use the max degree as we don't track it on Concrete side + assert!( + desc.degree == desc.message_modulus - 1, + "degree must be the max value (msg_modulus - 1)" + ); + + let lwe_size = desc.lwe_size; + let n_cts = desc.n_cts; + // construct fheint from LWEs let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_cts * lwe_size); let mut blocks: Vec = Vec::with_capacity(n_cts); for i in 0..n_cts { @@ -303,15 +462,79 @@ pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_int8( lwe_vector[i * lwe_size..(i + 1) * lwe_size].to_vec(), CiphertextModulus::new_native(), ); - blocks.push(fheint_desc.ct_from_lwe(lwe_ct)); + blocks.push(desc.ct_from_lwe(lwe_ct)); } - let fheuint = match FheInt8::from_expanded_blocks(blocks, fheint_desc.data_kind()) { + let fheint = match FheInt8::from_expanded_blocks(blocks, desc.data_kind()) { Ok(value) => value, Err(_) => { return 0; } }; - super::utils::safe_serialize(&fheuint, buffer, buffer_len) + super::utils::safe_serialize(&fheint, buffer, buffer_len) + }) +} + +unsafe fn lwe_array_to_tfhers_int8_array( + lwe_vec_buffer: *const u64, + buffer: *mut u8, + buffer_len: usize, + n_elem: usize, + desc: TfhersFheIntDescription, +) -> usize { + nounwind(|| { + // we want to trigger a PBS on TFHErs side + assert!( + desc.noise_level == NoiseLevel::UNKNOWN.get(), + "noise_level must be unknown" + ); + // we want to use the max degree as we don't track it on Concrete side + assert!( + desc.degree == desc.message_modulus - 1, + "degree must be the max value (msg_modulus - 1)" + ); + + let lwe_size = desc.lwe_size; + let n_cts = desc.n_cts; + // construct fheint from LWEs + let lwe_vector: &[u64] = slice::from_raw_parts(lwe_vec_buffer, n_elem * n_cts * lwe_size); + let mut fheint_array: Vec = Vec::with_capacity(n_elem); + for i in 0..n_elem { + let elem_offset = i * n_cts * lwe_size; + let mut blocks: Vec = Vec::with_capacity(n_cts); + for j in 0..n_cts { + let lwe_ct = LweCiphertext::>::from_container( + lwe_vector[elem_offset + j * lwe_size..elem_offset + (j + 1) * lwe_size] + .to_vec(), + CiphertextModulus::new_native(), + ); + blocks.push(desc.ct_from_lwe(lwe_ct)); + } + let fheint = match FheInt8::from_expanded_blocks(blocks, desc.data_kind()) { + Ok(value) => value, + Err(_) => { + return 0; + } + }; + fheint_array.push(fheint); + } + + super::utils::unsafe_serialize(&fheint_array, buffer, buffer_len) }) } + +#[no_mangle] +pub unsafe extern "C" fn concrete_cpu_lwe_array_to_tfhers_int8( + lwe_vec_buffer: *const u64, + buffer: *mut u8, + buffer_len: usize, + n_elem: usize, + desc: TfhersFheIntDescription, +) -> usize { + assert!(n_elem > 0); + if n_elem == 1 { + lwe_array_to_tfhers_int8(lwe_vec_buffer, buffer, buffer_len, desc) + } else { + lwe_array_to_tfhers_int8_array(lwe_vec_buffer, buffer, buffer_len, n_elem, desc) + } +} diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h index 1aacf41efa..8d74364ef2 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ClientLib/ClientLib.h @@ -37,7 +37,8 @@ namespace clientlib { Result importTfhersInteger(llvm::ArrayRef buffer, TfhersFheIntDescription integerDesc, uint32_t encryptionKeyId, - double encryptionVariance); + double encryptionVariance, + std::vector shape = {}); Result> exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc); diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt index b3e0639cea..2beee8d8e9 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt @@ -39,7 +39,6 @@ declare_mlir_python_sources( concrete/compiler/compilation_context.py concrete/compiler/tfhers_int.py concrete/compiler/utils.py - concrete/compiler/wrapper.py concrete/__init__.py concrete/lang/__init__.py concrete/lang/dialects/__init__.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 5c5b264a59..112bf2b91e 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -243,77 +243,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m.def("check_gpu_runtime_enabled", &checkGPURuntimeEnabled); m.def("check_cuda_device_available", &checkCudaDeviceAvailable); - pybind11::class_(m, "TfhersFheIntDescription") - .def(pybind11::init([](size_t width, bool is_signed, - size_t message_modulus, size_t carry_modulus, - size_t degree, size_t lwe_size, size_t n_cts, - size_t noise_level, bool ks_first) { - auto desc = TfhersFheIntDescription(); - desc.width = width; - desc.is_signed = is_signed; - desc.message_modulus = message_modulus; - desc.carry_modulus = carry_modulus; - desc.degree = degree; - desc.lwe_size = lwe_size; - desc.n_cts = n_cts; - desc.noise_level = noise_level; - desc.ks_first = ks_first; - return desc; - })) - .def_static("UNKNOWN_NOISE_LEVEL", - [] { return concrete_cpu_tfhers_unknown_noise_level(); }) - .def_property( - "width", [](TfhersFheIntDescription &desc) { return desc.width; }, - [](TfhersFheIntDescription &desc, size_t width) { - desc.width = width; - }) - .def_property( - "message_modulus", - [](TfhersFheIntDescription &desc) { return desc.message_modulus; }, - [](TfhersFheIntDescription &desc, size_t message_modulus) { - desc.message_modulus = message_modulus; - }) - .def_property( - "carry_modulus", - [](TfhersFheIntDescription &desc) { return desc.carry_modulus; }, - [](TfhersFheIntDescription &desc, size_t carry_modulus) { - desc.carry_modulus = carry_modulus; - }) - .def_property( - "degree", [](TfhersFheIntDescription &desc) { return desc.degree; }, - [](TfhersFheIntDescription &desc, size_t degree) { - desc.degree = degree; - }) - .def_property( - "lwe_size", - [](TfhersFheIntDescription &desc) { return desc.lwe_size; }, - [](TfhersFheIntDescription &desc, size_t lwe_size) { - desc.lwe_size = lwe_size; - }) - .def_property( - "n_cts", [](TfhersFheIntDescription &desc) { return desc.n_cts; }, - [](TfhersFheIntDescription &desc, size_t n_cts) { - desc.n_cts = n_cts; - }) - .def_property( - "noise_level", - [](TfhersFheIntDescription &desc) { return desc.noise_level; }, - [](TfhersFheIntDescription &desc, size_t noise_level) { - desc.noise_level = noise_level; - }) - .def_property( - "is_signed", - [](TfhersFheIntDescription &desc) { return desc.is_signed; }, - [](TfhersFheIntDescription &desc, bool is_signed) { - desc.is_signed = is_signed; - }) - .def_property( - "ks_first", - [](TfhersFheIntDescription &desc) { return desc.ks_first; }, - [](TfhersFheIntDescription &desc, bool ks_first) { - desc.ks_first = ks_first; - }); - pybind11::enum_(m, "Backend") .value("CPU", mlir::concretelang::Backend::CPU, "Circuit codegen targets cpu.") @@ -1993,15 +1922,118 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( "Return the `circuit` ClientCircuit.", arg("circuit")) .doc() = "Client-side / Encryption program"; + // ------------------------------------------------------------------------------// + // TFHERS INTEGER DESCRIPTION // + // ------------------------------------------------------------------------------// + + pybind11::class_(m, "TfhersFheIntDescription") + .def(pybind11::init([](size_t width, bool is_signed, + size_t message_modulus, size_t carry_modulus, + size_t degree, size_t lwe_size, size_t n_cts, + size_t noise_level, bool ks_first) { + auto desc = TfhersFheIntDescription(); + desc.width = width; + desc.is_signed = is_signed; + desc.message_modulus = message_modulus; + desc.carry_modulus = carry_modulus; + desc.degree = degree; + desc.lwe_size = lwe_size; + desc.n_cts = n_cts; + desc.noise_level = noise_level; + desc.ks_first = ks_first; + return desc; + }), + arg("width"), arg("is_signed"), arg("lwe_size"), arg("n_cts"), + arg("degree"), arg("noise_level"), arg("message_modulus"), + arg("carry_modulus"), arg("ks_first")) + .def_static("get_unknown_noise_level", + [] { return concrete_cpu_tfhers_unknown_noise_level(); }) + .def_property( + "width", [](TfhersFheIntDescription &desc) { return desc.width; }, + [](TfhersFheIntDescription &desc, size_t width) { + desc.width = width; + }) + .def_property( + "message_modulus", + [](TfhersFheIntDescription &desc) { return desc.message_modulus; }, + [](TfhersFheIntDescription &desc, size_t message_modulus) { + desc.message_modulus = message_modulus; + }) + .def_property( + "carry_modulus", + [](TfhersFheIntDescription &desc) { return desc.carry_modulus; }, + [](TfhersFheIntDescription &desc, size_t carry_modulus) { + desc.carry_modulus = carry_modulus; + }) + .def_property( + "degree", [](TfhersFheIntDescription &desc) { return desc.degree; }, + [](TfhersFheIntDescription &desc, size_t degree) { + desc.degree = degree; + }) + .def_property( + "lwe_size", + [](TfhersFheIntDescription &desc) { return desc.lwe_size; }, + [](TfhersFheIntDescription &desc, size_t lwe_size) { + desc.lwe_size = lwe_size; + }) + .def_property( + "n_cts", [](TfhersFheIntDescription &desc) { return desc.n_cts; }, + [](TfhersFheIntDescription &desc, size_t n_cts) { + desc.n_cts = n_cts; + }) + .def_property( + "noise_level", + [](TfhersFheIntDescription &desc) { return desc.noise_level; }, + [](TfhersFheIntDescription &desc, size_t noise_level) { + desc.noise_level = noise_level; + }) + .def_property( + "is_signed", + [](TfhersFheIntDescription &desc) { return desc.is_signed; }, + [](TfhersFheIntDescription &desc, bool is_signed) { + desc.is_signed = is_signed; + }) + .def_property( + "ks_first", + [](TfhersFheIntDescription &desc) { return desc.ks_first; }, + [](TfhersFheIntDescription &desc, bool ks_first) { + desc.ks_first = ks_first; + }) + .def("__str__", + [](TfhersFheIntDescription &desc) { + std::ostringstream stringStream; + stringStream << "tfhers_int_description"; + return stringStream.str(); + }) + .doc() = "TFHE-rs integer description"; + m.def("import_tfhers_int", [](const pybind11::bytes &serialized_fheuint, TfhersFheIntDescription info, uint32_t encryptionKeyId, - double encryptionVariance) { + double encryptionVariance, std::vector shape) { const std::string &buffer_str = serialized_fheuint; std::vector buffer(buffer_str.begin(), buffer_str.end()); auto arrayRef = llvm::ArrayRef(buffer); auto valueOrError = ::concretelang::clientlib::importTfhersInteger( - arrayRef, info, encryptionKeyId, encryptionVariance); + arrayRef, info, encryptionKeyId, encryptionVariance, shape); if (valueOrError.has_error()) { throw std::runtime_error(valueOrError.error().mesg); } diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index ca8dd267e4..533dfb4804 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -18,6 +18,7 @@ ServerKeyset, Keyset, Compiler, + TfhersFheIntDescription, TransportValue, Value, ServerProgram, @@ -47,10 +48,7 @@ from .compilation_feedback import MoreCircuitCompilationFeedback from .compilation_context import CompilationContext -from .tfhers_int import ( - TfhersExporter, - TfhersFheIntDescription, -) +from .tfhers_int import TfhersExporter Parameter = Union[ LweSecretKeyParam, BootstrapKeyParam, KeyswitchKeyParam, PackingKeyswitchKeyParam diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index 9491743c2e..17b9e65139 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -1,183 +1,19 @@ """Import and export TFHErs integers into Concrete.""" +from typing import Tuple + # pylint: disable=no-name-in-module,import-error, from mlir._mlir_libs._concretelang._compiler import ( import_tfhers_int as _import_tfhers_int, export_tfhers_int as _export_tfhers_int, - TfhersFheIntDescription as _TfhersFheIntDescription, + TfhersFheIntDescription, TransportValue, ) -from .wrapper import WrapperCpp # pylint: enable=no-name-in-module,import-error -class TfhersFheIntDescription(WrapperCpp): - """A helper class to create `TfhersFheIntDescription`s.""" - - def __init__(self, desc: _TfhersFheIntDescription): - """ - Wrap the native C++ object. - - Args: - desc (_TfhersFheIntDescription): - object to wrap - - Raises: - TypeError: - if `desc` is not of type `_TfhersFheIntDescription` - """ - - if not isinstance(desc, _TfhersFheIntDescription): - raise TypeError( - f"desc must be of type _TfhersFheIntDescription, not {type(desc)}" - ) - - super().__init__(desc) - - @staticmethod - # pylint: disable=arguments-differ - def new( - width: int, - is_signed: bool, - message_modulus: int, - carry_modulus: int, - degree: int, - lwe_size: int, - n_cts: int, - noise_level: int, - ks_first: bool, - ) -> "TfhersFheIntDescription": - """Create a TfhersFheIntDescription. - - Args: - width (int): integer width - is_signed (bool): signed or unsigned - message_modulus (int): message modulus (not its log2) - carry_modulus (int): carry modulus (not its log2) - degree (int): degree - lwe_size (int): LWE size - n_cts (int): number of ciphertexts - noise_level (int): noise level - ks_first (bool): PBS order (keyswitch first, or bootstrap first) - - Returns: - TfhersFheIntDescription: TFHErs integer description - """ - return TfhersFheIntDescription( - _TfhersFheIntDescription( - width, - is_signed, - message_modulus, - carry_modulus, - degree, - lwe_size, - n_cts, - noise_level, - ks_first, - ) - ) - - @property - def width(self) -> int: - """Total integer bitwidth""" - return self.cpp().width - - @width.setter - def width(self, width: int): - self.cpp().width = width - - @property - def is_signed(self) -> bool: - """Is the integer signed""" - return self.cpp().is_signed - - @is_signed.setter - def is_signed(self, is_signed: bool): - self.cpp().is_signed = is_signed - - @property - def message_modulus(self) -> int: - """Modulus of the message part in each ciphertext""" - return self.cpp().message_modulus - - @message_modulus.setter - def message_modulus(self, message_modulus: int): - self.cpp().message_modulus = message_modulus - - @property - def carry_modulus(self) -> int: - """Modulus of the carry part in each ciphertext""" - return self.cpp().carry_modulus - - @carry_modulus.setter - def carry_modulus(self, carry_modulus: int): - self.cpp().carry_modulus = carry_modulus - - @property - def degree(self) -> int: - """Tracks the number of operations that have been done""" - return self.cpp().degree - - @degree.setter - def degree(self, degree: int): - self.cpp().degree = degree - - @property - def lwe_size(self) -> int: - """LWE size""" - return self.cpp().lwe_size - - @lwe_size.setter - def lwe_size(self, lwe_size: int): - self.cpp().lwe_size = lwe_size - - @property - def n_cts(self) -> int: - """Number of ciphertexts""" - return self.cpp().n_cts - - @n_cts.setter - def n_cts(self, n_cts: int): - self.cpp().n_cts = n_cts - - @property - def noise_level(self) -> int: - """Noise level""" - return self.cpp().noise_level - - @noise_level.setter - def noise_level(self, noise_level: int): - self.cpp().noise_level = noise_level - - @staticmethod - def get_unknown_noise_level() -> int: - """Get unknow noise level value. - - Returns: - int: unknown noise level value - """ - return _TfhersFheIntDescription.UNKNOWN_NOISE_LEVEL() - - @property - def ks_first(self) -> bool: - """Keyswitch placement relative to the bootsrap in a PBS""" - return self.cpp().ks_first - - @ks_first.setter - def ks_first(self, ks_first: bool): - self.cpp().ks_first = ks_first - - def __str__(self) -> str: - return ( - f"tfhers_int_description" - ) - - class TfhersExporter: """A helper class to import and export TFHErs big integers.""" @@ -201,11 +37,15 @@ def export_int(value: TransportValue, info: TfhersFheIntDescription) -> bytes: raise TypeError( f"info must be of type TfhersFheIntDescription, not {type(info)}" ) - return bytes(_export_tfhers_int(value, info.cpp())) + return bytes(_export_tfhers_int(value, info)) @staticmethod def import_int( - buffer: bytes, info: TfhersFheIntDescription, keyid: int, variance: float + buffer: bytes, + info: TfhersFheIntDescription, + keyid: int, + variance: float, + shape: Tuple[int, ...], ) -> TransportValue: """Unserialize and convert from TFHErs to Concrete value. @@ -214,6 +54,7 @@ def import_int( info (TfhersFheIntDescription): description of the TFHErs integer to import keyid (int): id of the key used for encryption variance (float): variance used for encryption + shape (Tuple[int, ...]): expected shape Raises: TypeError: if wrong input types @@ -231,4 +72,6 @@ def import_int( raise TypeError(f"keyid must be of type int, not {type(keyid)}") if not isinstance(variance, float): raise TypeError(f"variance must be of type float, not {type(variance)}") - return _import_tfhers_int(buffer, info.cpp(), keyid, variance) + if not isinstance(shape, tuple): + raise TypeError(f"shape must be of type tuple(int, ...), not {type(shape)}") + return _import_tfhers_int(buffer, info, keyid, variance, shape) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/wrapper.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/wrapper.py deleted file mode 100644 index 5420905361..0000000000 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/wrapper.py +++ /dev/null @@ -1,39 +0,0 @@ -# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. -# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information. - -"""Wrapper for native Cpp objects.""" - - -class WrapperCpp: - """Wrapper base class for native Cpp objects. - - Initialization should mainly store the wrapped object, and future calls to the wrapper will be forwarded - to it. A static wrap method is provided to be more explicit. Wrappers should always be constructed using - the new method, which construct the Cpp object using the provided arguments, then wrap it. Classes that - inherit from this class should preferably type check the wrapped object during calls to init, and - reimplement the new method if the class is meant to be constructed. - """ - - def __init__(self, cpp_obj): - self._cpp_obj = cpp_obj - - @classmethod - def wrap(cls, cpp_obj) -> "WrapperCpp": - """Wrap the Cpp object into a Python object. - - Args: - cpp_obj: object to wrap - - Returns: - WrapperCpp: wrapper - """ - return cls(cpp_obj) - - @staticmethod - def new(*args, **kwargs): - """Create a new wrapper by building the underlying object with a specific set of arguments.""" - raise RuntimeError("This class shouldn't be built") - - def cpp(self): - """Return the Cpp wrapped object.""" - return self._cpp_obj diff --git a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp index 3fee8e9586..36f49e8ec7 100644 --- a/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp +++ b/compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp @@ -7,11 +7,13 @@ #include #include #include +#include #include #include #include #include "boost/outcome.h" +#include "capnp/common.h" #include "concrete-cpu.h" #include "concrete-protocol.capnp.h" #include "concretelang/ClientLib/ClientLib.h" @@ -190,10 +192,11 @@ Result ClientProgram::getClientCircuit(std::string circuitName) { Result importTfhersInteger(llvm::ArrayRef buffer, TfhersFheIntDescription integerDesc, uint32_t encryptionKeyId, - double encryptionVariance) { + double encryptionVariance, + std::vector shape) { // Select conversion function based on integer description - std::function conversion_func; if (integerDesc.width == 8) { @@ -211,10 +214,19 @@ Result importTfhersInteger(llvm::ArrayRef buffer, return StringError(errorMsg); } - auto dims = std::vector({integerDesc.n_cts, integerDesc.lwe_size}); + // construct the different dimensions + size_t tensorFlatSize = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + std::vector abstractDims(shape.begin(), shape.end()); + abstractDims.push_back(integerDesc.n_cts); + std::vector concreteDims(abstractDims.begin(), abstractDims.end()); + concreteDims.push_back(integerDesc.lwe_size); + std::vector dims(concreteDims.begin(), concreteDims.end()); + auto outputTensor = Tensor::fromDimensions(dims); - auto err = conversion_func(buffer.data(), buffer.size(), - outputTensor.values.data(), integerDesc); + auto err = + conversion_func(buffer.data(), buffer.size(), outputTensor.values.data(), + tensorFlatSize, integerDesc); if (err) { return StringError("couldn't convert fheint to lwe array"); } @@ -223,9 +235,10 @@ Result importTfhersInteger(llvm::ArrayRef buffer, auto lwe = value.asBuilder().initTypeInfo().initLweCiphertext(); lwe.setIntegerPrecision(64); // dimensions - lwe.initAbstractShape().setDimensions({(uint32_t)integerDesc.n_cts}); + lwe.initAbstractShape().setDimensions( + ::kj::ArrayPtr(abstractDims.data(), abstractDims.size())); lwe.initConcreteShape().setDimensions( - {(uint32_t)integerDesc.n_cts, (uint32_t)integerDesc.lwe_size}); + ::kj::ArrayPtr(concreteDims.data(), concreteDims.size())); // encryption auto encryption = lwe.initEncryption(); encryption.setLweDimension((uint32_t)integerDesc.lwe_size - 1); @@ -247,10 +260,9 @@ Result importTfhersInteger(llvm::ArrayRef buffer, Result> exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) { // Select conversion function based on integer description - std::function conversion_func; - std::function buffer_size_func; if (integerDesc.width == 8) { if (integerDesc.is_signed) { // fheint8 conversion_func = concrete_cpu_lwe_array_to_tfhers_int8; @@ -274,12 +286,20 @@ exportTfhersInteger(TransportValue value, TfhersFheIntDescription integerDesc) { if (!tensorOrError.has_value()) { return StringError("couldn't get tensor from value"); } + auto concreteShape = tensorOrError.value().dimensions; + assert(concreteShape.size() >= 2); + std::vector tensorShape(concreteShape.begin(), + concreteShape.end() - + 2 /* remove radix and lwe dims */); + size_t tensorFlatSize = std::accumulate( + tensorShape.begin(), tensorShape.end(), 1, std::multiplies()); + // TODO: compute new buffer size of tensor size_t buffer_size = concrete_cpu_tfhers_fheint_buffer_size_u64( - integerDesc.lwe_size, integerDesc.n_cts); + integerDesc.lwe_size, integerDesc.n_cts, tensorFlatSize); std::vector buffer(buffer_size, 0); auto flat_data = tensorOrError.value().values; auto size = conversion_func(flat_data.data(), buffer.data(), buffer.size(), - integerDesc); + tensorFlatSize, integerDesc); if (size == 0) { return StringError("couldn't convert lwe array to fheint8"); } diff --git a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py index de08b36842..bff25a89c8 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/bridge.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/bridge.py @@ -3,7 +3,7 @@ """ # pylint: disable=import-error,no-member,no-name-in-module -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from concrete.compiler import LweSecretKey, TfhersExporter, TfhersFheIntDescription @@ -20,22 +20,31 @@ class Bridge: a non-tfhers type output_types (List[Optional[TFHERSIntegerType]]): maps every output to a type. None means a non-tfhers type + input_shapes (List[Optional[Tuple[int, ...]]]): maps every input to a shape. None means + a non-tfhers type + output_shapes (List[Optional[Tuple[int, ...]]]): maps every output to a shape. None means + a non-tfhers type """ circuit: "fhe.Circuit" input_types: List[Optional[TFHERSIntegerType]] - output_types: List[Optional[TFHERSIntegerType]] + input_shapes: List[Optional[Tuple[int, ...]]] + output_shapes: List[Optional[Tuple[int, ...]]] def __init__( self, circuit: "fhe.Circuit", input_types: List[Optional[TFHERSIntegerType]], output_types: List[Optional[TFHERSIntegerType]], + input_shapes: List[Optional[Tuple[int, ...]]], + output_shapes: List[Optional[Tuple[int, ...]]], ): self.circuit = circuit self.input_types = input_types self.output_types = output_types + self.input_shapes = input_shapes + self.output_shapes = output_shapes def _input_type(self, input_idx: int) -> Optional[TFHERSIntegerType]: """Return the type of a certain input. @@ -59,6 +68,28 @@ def _output_type(self, output_idx: int) -> Optional[TFHERSIntegerType]: """ return self.output_types[output_idx] + def _input_shape(self, input_idx: int) -> Optional[Tuple[int, ...]]: + """Return the shape of a certain input. + + Args: + input_idx (int): the input index to get the shape of + + Returns: + Optional[Tuple[int, ...]]: input shape. None means a non-tfhers type + """ + return self.input_shapes[input_idx] + + def _output_shape(self, output_idx: int) -> Optional[Tuple[int, ...]]: # pragma: no cover + """Return the shape of a certain output. + + Args: + output_idx (int): the output index to get the shape of + + Returns: + Optional[Tuple[int, ...]]: output shape. None means a non-tfhers type + """ + return self.output_shapes[output_idx] + def _input_keyid(self, input_idx: int) -> int: return self.circuit.client.specs.program_info.input_keyid_at( input_idx, self.circuit.function_name @@ -90,7 +121,7 @@ def _description_from_type( # this should imply running a PBS on TFHErs side noise_level = TfhersFheIntDescription.get_unknown_noise_level() - return TfhersFheIntDescription.new( + return TfhersFheIntDescription( bit_width, signed, message_modulus, @@ -113,14 +144,15 @@ def import_value(self, buffer: bytes, input_idx: int) -> Value: fhe.TransportValue: imported value """ input_type = self._input_type(input_idx) - if input_type is None: # pragma: no cover + input_shape = self._input_shape(input_idx) + if input_type is None or input_shape is None: # pragma: no cover msg = "input at 'input_idx' is not a TFHErs value" raise ValueError(msg) fheint_desc = self._description_from_type(input_type) keyid = self._input_keyid(input_idx) variance = self._input_variance(input_idx) - return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance)) + return Value(TfhersExporter.import_int(buffer, fheint_desc, keyid, variance, input_shape)) def export_value(self, value: Value, output_idx: int) -> bytes: """Export a value as a serialized TFHErs integer. @@ -218,21 +250,24 @@ def new_bridge(circuit: "fhe.Circuit") -> Bridge: Returns: Bridge: TFHErs bridge """ - input_types = [ - ( - input_node.output.dtype - if isinstance(input_node.output.dtype, TFHERSIntegerType) - else None - ) - for input_node in circuit.graph.ordered_inputs() - ] - output_types = [ - ( - output_node.output.dtype - if isinstance(output_node.output.dtype, TFHERSIntegerType) - else None - ) - for output_node in circuit.graph.ordered_outputs() - ] - - return Bridge(circuit, input_types, output_types) + input_types: List[Optional[TFHERSIntegerType]] = [] + input_shapes: List[Optional[Tuple[int, ...]]] = [] + for input_node in circuit.graph.ordered_inputs(): + if isinstance(input_node.output.dtype, TFHERSIntegerType): + input_types.append(input_node.output.dtype) + input_shapes.append(input_node.output.shape) + else: + input_types.append(None) + input_shapes.append(None) + + output_types: List[Optional[TFHERSIntegerType]] = [] + output_shapes: List[Optional[Tuple[int, ...]]] = [] + for output_node in circuit.graph.ordered_outputs(): + if isinstance(output_node.output.dtype, TFHERSIntegerType): + output_types.append(output_node.output.dtype) + output_shapes.append(output_node.output.shape) + else: # pragma: no cover + output_types.append(None) + output_shapes.append(None) + + return Bridge(circuit, input_types, output_types, input_shapes, output_shapes) diff --git a/frontends/concrete-python/tests/execution/test_tfhers.py b/frontends/concrete-python/tests/execution/test_tfhers.py index c747916084..1ac6487ffd 100644 --- a/frontends/concrete-python/tests/execution/test_tfhers.py +++ b/frontends/concrete-python/tests/execution/test_tfhers.py @@ -5,7 +5,7 @@ import json import os import tempfile -from typing import List +from typing import List, Union import numpy as np import pytest @@ -336,6 +336,33 @@ def lut_add_lut(x, y): TFHERS_INT_8_3_2_4096, id="signed(x) + signed(y)", ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**7 - 1], "status": "encrypted", "shape": (2,)}, + "y": {"range": [0, 2**7 - 1], "status": "encrypted", "shape": (2,)}, + }, + TFHERS_UINT_8_3_2_4096, + id="tensor(x) + tensor(y)", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [0, 2**7 - 1], "status": "encrypted", "shape": (3, 2)}, + "y": {"range": [0, 2**7 - 1], "status": "encrypted", "shape": (3, 2)}, + }, + TFHERS_UINT_8_3_2_4096, + id="tensor_2d(x) + tensor_2d(y)", + ), + pytest.param( + lambda x, y: x + y, + { + "x": {"range": [-(2**6), -2], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 2**6 - 1], "status": "encrypted", "shape": (3,)}, + }, + TFHERS_INT_8_3_2_4096, + id="tensor_signed(x) + tensor_signed(y)", + ), pytest.param( lambda x, y: x - y, { @@ -373,22 +400,22 @@ def lut_add_lut(x, y): id="signed(x) * signed(y)", ), pytest.param( - lambda x, y: x * y, + lut_add_lut, { - "x": {"range": [-(2**3), 2**2], "status": "encrypted"}, - "y": {"range": [-(2**2), 2**3], "status": "encrypted"}, + "x": {"range": [0, 2**7 - 1], "status": "encrypted"}, + "y": {"range": [0, 2**7 - 1], "status": "encrypted"}, }, - TFHERS_INT_8_3_2_4096, - id="signed(x) * signed(y)", + TFHERS_UINT_8_3_2_4096, + id="lut_add_lut", ), pytest.param( lut_add_lut, { - "x": {"range": [0, 2**7 - 1], "status": "encrypted"}, - "y": {"range": [0, 2**7 - 1], "status": "encrypted"}, + "x": {"range": [0, 2**7 - 1], "status": "encrypted", "shape": (2, 2)}, + "y": {"range": [0, 2**7 - 1], "status": "encrypted", "shape": (2, 2)}, }, TFHERS_UINT_8_3_2_4096, - id="lut_add_lut", + id="tensor_lut_add_lut", ), ], ) @@ -403,6 +430,8 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + is_tensor = all([param.get("shape") is not None for param in parameters.values()]) + # Only valid when running in multi if helpers.configuration().parameter_selection_strategy != fhe.ParameterSelectionStrategy.MULTI: return @@ -440,8 +469,8 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( # serialize key _, key_path = tempfile.mkstemp() serialized_key = tfhers_bridge.serialize_input_secret_key(input_idx=0) - with open(key_path, "wb") as f: - f.write(serialized_key) + with open(key_path, "wb") as fw: + fw.write(serialized_key) ct1, ct2 = sample _, ct1_path = tempfile.mkstemp() @@ -462,46 +491,75 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( == 0 ) + def prepare_value(concrete_value, repeat_int: int = 1) -> str: + if isinstance(concrete_value, (int, np.integer)): + assert repeat_int >= 1 + values = [ + concrete_value, + ] * repeat_int + elif isinstance(concrete_value, np.ndarray): + values = concrete_value.flatten().tolist() + else: + raise TypeError( + f"concrete_value should either be int or ndarray, not {type(concrete_value)}" + ) + return "--value=" + ",".join(map(str, values)) + # encrypt inputs and incremnt them by one in TFHErs + repeat_int = 1 + if is_tensor and isinstance(ct1, np.ndarray): + repeat_int = int(np.prod(ct1.shape)) assert ( os.system( - f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value=1 -c {ct_one_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key " + f"{'--signed' if dtype.is_signed else ''} " + f"{prepare_value(1, repeat_int)} -c {ct_one_path} --client-key {client_key_path}" ) == 0 ) sample = [s + 1 for s in sample] assert ( os.system( - f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct1} -c {ct1_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key " + f"{'--signed' if dtype.is_signed else ''} " + f"{prepare_value(ct1)} -c {ct1_path} --client-key {client_key_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} encrypt-with-key {'--signed' if dtype.is_signed else ''} --value={ct2} -c {ct2_path} --client-key {client_key_path}" + f"{tfhers_utils} encrypt-with-key " + f"{'--signed' if dtype.is_signed else ''} " + f"{prepare_value(ct2)} -c {ct2_path} --client-key {client_key_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} add {'--signed' if dtype.is_signed else ''} -c {ct1_path} {ct_one_path} -s {server_key_path} -o {ct1_path}" + f"{tfhers_utils} add " + f"{'--signed' if dtype.is_signed else ''} " + f"{'--tensor' if is_tensor else ''} " + f"-c {ct1_path} {ct_one_path} -s {server_key_path} -o {ct1_path}" ) == 0 ) assert ( os.system( - f"{tfhers_utils} add {'--signed' if dtype.is_signed else ''} -c {ct2_path} {ct_one_path} -s {server_key_path} -o {ct2_path}" + f"{tfhers_utils} add " + f"{'--signed' if dtype.is_signed else ''} " + f"{'--tensor' if is_tensor else ''} " + f"-c {ct2_path} {ct_one_path} -s {server_key_path} -o {ct2_path}" ) == 0 ) # import ciphertexts and run cts = [] - with open(ct1_path, "rb") as f: - buff = f.read() + with open(ct1_path, "rb") as fr: + buff = fr.read() cts.append(tfhers_bridge.import_value(buff, 0)) - with open(ct2_path, "rb") as f: - buff = f.read() + with open(ct2_path, "rb") as fr: + buff = fr.read() cts.append(tfhers_bridge.import_value(buff, 1)) os.remove(ct1_path) os.remove(ct2_path) @@ -510,18 +568,21 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( # concrete decryption should work decrypted = circuit.decrypt(tfhers_encrypted_result) - assert (dtype.decode(decrypted) == function(*sample)).all() # type: ignore + assert isinstance(decrypted, (list, np.ndarray)) + decoded = dtype.decode(decrypted) + assert (decoded == function(*sample)).all() # type: ignore # tfhers decryption buff = tfhers_bridge.export_value(tfhers_encrypted_result, output_idx=0) # type: ignore _, ct_out_path = tempfile.mkstemp() _, pt_path = tempfile.mkstemp() - with open(ct_out_path, "wb") as f: - f.write(buff) + with open(ct_out_path, "wb") as fw: + fw.write(buff) assert ( os.system( f"{tfhers_utils} decrypt-with-key" + f"{' --tensor ' if is_tensor else ''}" f"{' --signed ' if dtype.is_signed else ''}" f" -c {ct_out_path} --lwe-sk {key_path} -p {pt_path}" ) @@ -529,7 +590,13 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( ) with open(pt_path, "r", encoding="utf-8") as f: - result = int(f.read()) + result: Union[int, np.ndarray] + if is_tensor: + assert isinstance(decoded, np.ndarray) + result_raw = list(map(int, f.read().split(","))) + result = np.array(result_raw).reshape(decoded.shape) + else: + result = int(f.read()) # close remaining tempfiles os.remove(key_path) @@ -539,7 +606,10 @@ def test_tfhers_binary_encrypted_complete_circuit_concrete_keygen( os.remove(client_key_path) os.remove(server_key_path) - assert result == function(*sample) + if is_tensor: + assert (result == function(*sample)).all() + else: + assert result == function(*sample) @pytest.mark.parametrize( diff --git a/frontends/concrete-python/tests/tfhers-utils/Cargo.lock b/frontends/concrete-python/tests/tfhers-utils/Cargo.lock index a0166f6d12..2c54fd07cc 100644 --- a/frontends/concrete-python/tests/tfhers-utils/Cargo.lock +++ b/frontends/concrete-python/tests/tfhers-utils/Cargo.lock @@ -172,7 +172,8 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "concrete-csprng" version = "0.4.1" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90518357249582c16a6b64d7410243dfb3109d5bf0ad1665c058c9a59f2fc4cc" dependencies = [ "aes", "libc", @@ -507,8 +508,9 @@ dependencies = [ [[package]] name = "tfhe" -version = "0.8.0" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e2648b0df14216576ea543bb9021beed019d7eb43fedacdbc24a0c095d33d2a" dependencies = [ "aligned-vec", "bincode", @@ -529,8 +531,9 @@ dependencies = [ [[package]] name = "tfhe-versionable" -version = "0.3.2" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feeb340d850c65660b321e5379a28b2f3b226c61163de0a12766aedfe8575a29" dependencies = [ "aligned-vec", "num-complex", @@ -540,8 +543,9 @@ dependencies = [ [[package]] name = "tfhe-versionable-derive" -version = "0.3.2" -source = "git+https://github.com/zama-ai/tfhe-rs.git?rev=483a4fe#483a4fecf1a6b2e2a7827251a04608d018c83002" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d985f9645ed62be4aefb9c06ec70563291ec475036ebcd2cf95c5429a12e8a" dependencies = [ "proc-macro2", "quote", diff --git a/frontends/concrete-python/tests/tfhers-utils/Cargo.toml b/frontends/concrete-python/tests/tfhers-utils/Cargo.toml index 1fa2ae0a90..71e3cf64c8 100644 --- a/frontends/concrete-python/tests/tfhers-utils/Cargo.toml +++ b/frontends/concrete-python/tests/tfhers-utils/Cargo.toml @@ -10,13 +10,13 @@ serde = "1" clap = { version = "4.5.16", features = ["derive"] } -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer"] } +tfhe = { version = "0.8.6", features = ["integer"] } [target.x86_64-unknown-linux-gnu.dependencies] -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "x86_64-unix"] } +tfhe = { version = "0.8.6", features = ["integer", "x86_64-unix"] } [target.aarch64-unknown-linux-gnu.dependencies] -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "aarch64-unix"] } +tfhe = { version = "0.8.6", features = ["integer", "aarch64-unix"] } [target.x86_64-pc-windows-gnu.dependencies] -tfhe = { git = "https://github.com/zama-ai/tfhe-rs.git", rev = "483a4fe", features = ["integer", "x86_64"] } +tfhe = { version = "0.8.6", features = ["integer", "x86_64"] } diff --git a/frontends/concrete-python/tests/tfhers-utils/src/main.rs b/frontends/concrete-python/tests/tfhers-utils/src/main.rs index 07095d9e1c..5543c90291 100644 --- a/frontends/concrete-python/tests/tfhers-utils/src/main.rs +++ b/frontends/concrete-python/tests/tfhers-utils/src/main.rs @@ -48,14 +48,30 @@ fn set_server_key_from_file(path: &String) { set_server_key(sk); } -fn encrypt_with_key_u8(value: u8, client_key: ClientKey, ciphertext_path: &String) { - let ct = FheUint8::encrypt(value, &client_key); - safe_save(ciphertext_path, &ct) +fn encrypt_with_key_u8(value: Vec, client_key: ClientKey, ciphertext_path: &String) { + if value.len() == 1 { + let ct = FheUint8::encrypt(value[0], &client_key); + safe_save(ciphertext_path, &ct) + } else { + let cts: Vec = value + .iter() + .map(|v| FheUint8::encrypt(v.clone(), &client_key)) + .collect(); + unsafe_save(ciphertext_path, &cts); + } } -fn encrypt_with_key_i8(value: i8, client_key: ClientKey, ciphertext_path: &String) { - let ct = FheInt8::encrypt(value, &client_key); - safe_save(ciphertext_path, &ct) +fn encrypt_with_key_i8(value: Vec, client_key: ClientKey, ciphertext_path: &String) { + if value.len() == 1 { + let ct = FheInt8::encrypt(value[0], &client_key); + safe_save(ciphertext_path, &ct) + } else { + let cts: Vec = value + .iter() + .map(|v| FheInt8::encrypt(v.clone(), &client_key)) + .collect(); + unsafe_save(ciphertext_path, &cts); + } } fn decrypt_with_key( @@ -63,17 +79,38 @@ fn decrypt_with_key( ciphertext_path: &String, plaintext_path: Option<&String>, signed: bool, + tensor: bool, ) { let string_result: String; - if signed { - let fheint: FheInt8 = safe_load(ciphertext_path); - let result: i8 = fheint.decrypt(&client_key); - string_result = result.to_string(); + if tensor { + if signed { + let fheint_array: Vec = unsafe_load(ciphertext_path); + let results: Vec = fheint_array + .iter() + .map(|v| v.decrypt(&client_key)) + .collect(); + let results_str: Vec = results.iter().map(|v| v.to_string()).collect(); + string_result = results_str.join(","); + } else { + let fheint_array: Vec = unsafe_load(ciphertext_path); + let results: Vec = fheint_array + .iter() + .map(|v| v.decrypt(&client_key)) + .collect(); + let results_str: Vec = results.iter().map(|v| v.to_string()).collect(); + string_result = results_str.join(","); + } } else { - let fheuint: FheUint8 = safe_load(ciphertext_path); - let result: u8 = fheuint.decrypt(&client_key); - string_result = result.to_string(); + if signed { + let fheint: FheInt8 = safe_load(ciphertext_path); + let result: i8 = fheint.decrypt(&client_key); + string_result = result.to_string(); + } else { + let fheuint: FheUint8 = safe_load(ciphertext_path); + let result: u8 = fheuint.decrypt(&client_key); + string_result = result.to_string(); + } } if let Some(path) = plaintext_path { @@ -84,24 +121,47 @@ fn decrypt_with_key( } } -fn sum(cts_paths: Vec<&String>, out_ct_path: &String, signed: bool) { +fn sum(cts_paths: Vec<&String>, out_ct_path: &String, signed: bool, tensor: bool) { if cts_paths.is_empty() { panic!("can't call sum with 0 ciphertexts"); } - if signed { - let mut acc: FheInt8 = safe_load(cts_paths[0]); - for ct_path in cts_paths[1..].iter() { - let fheuint: FheInt8 = safe_load(ct_path); - acc += fheuint; + if tensor { + if signed { + let mut acc: Vec = unsafe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheint_array: Vec = unsafe_load(ct_path); + for (i, inc_value) in fheint_array.iter().enumerate() { + acc[i] += inc_value; + } + } + unsafe_save(out_ct_path, &acc) + } else { + // fails here + let mut acc: Vec = unsafe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheuint_array: Vec = unsafe_load(ct_path); + for (i, inc_value) in fheuint_array.iter().enumerate() { + acc[i] += inc_value; + } + } + unsafe_save(out_ct_path, &acc) } - safe_save(out_ct_path, &acc) } else { - let mut acc: FheUint8 = safe_load(cts_paths[0]); - for ct_path in cts_paths[1..].iter() { - let fheuint: FheUint8 = safe_load(ct_path); - acc += fheuint; + if signed { + let mut acc: FheInt8 = safe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheint: FheInt8 = safe_load(ct_path); + acc += fheint; + } + safe_save(out_ct_path, &acc) + } else { + let mut acc: FheUint8 = safe_load(cts_paths[0]); + for ct_path in cts_paths[1..].iter() { + let fheuint: FheUint8 = safe_load(ct_path); + acc += fheuint; + } + safe_save(out_ct_path, &acc) } - safe_save(out_ct_path, &acc) } } @@ -171,7 +231,8 @@ fn main() { .help("value to encrypt") .action(ArgAction::Set) .required(true) - .num_args(1), + .value_delimiter(',') + .num_args(1..), ) .arg( Arg::new("signed") @@ -218,6 +279,12 @@ fn main() { .help("decrypt as a signed integer") .action(ArgAction::SetTrue), ) + .arg( + Arg::new("tensor") + .long("tensor") + .help("decrypt as a tensor") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertext") .short('c') @@ -274,6 +341,12 @@ fn main() { .help("consider ciphertexts as signed integers") .action(ArgAction::SetTrue), ) + .arg( + Arg::new("tensor") + .long("tensor") + .help("consider ciphertexts as tensors") + .action(ArgAction::SetTrue), + ) .arg( Arg::new("ciphertexts") .short('c') @@ -309,7 +382,6 @@ fn main() { .arg( Arg::new("output-lwe-sk") .long("output-lwe-sk") - .default_value("lwe_secret_key") .help("output lwe key path") .action(ArgAction::Set) .num_args(1), @@ -337,7 +409,10 @@ fn main() { match matches.subcommand() { Some(("encrypt-with-key", encrypt_matches)) => { - let value_str = encrypt_matches.get_one::("value").unwrap(); + let value_str: Vec<&String> = encrypt_matches + .get_many::("value") + .unwrap() + .collect(); let ciphertext_path = encrypt_matches.get_one::("ciphertext").unwrap(); let signed = encrypt_matches.get_flag("signed"); @@ -351,10 +426,10 @@ fn main() { } if signed { - let value: i8 = value_str.parse().unwrap(); + let value: Vec = value_str.iter().map(|v| v.parse().unwrap()).collect(); encrypt_with_key_i8(value, client_key, ciphertext_path) } else { - let value: u8 = value_str.parse().unwrap(); + let value: Vec = value_str.iter().map(|v| v.parse().unwrap()).collect(); encrypt_with_key_u8(value, client_key, ciphertext_path) } } @@ -362,6 +437,7 @@ fn main() { let ciphertext_path = decrypt_mtches.get_one::("ciphertext").unwrap(); let plaintext_path = decrypt_mtches.get_one::("plaintext"); let signed = decrypt_mtches.get_flag("signed"); + let tensor = decrypt_mtches.get_flag("tensor"); let client_key: ClientKey; if let Some(lwe_sk_path) = decrypt_mtches.get_one::("lwe-sk") { @@ -371,38 +447,39 @@ fn main() { } else { panic!("no key specified"); } - decrypt_with_key(client_key, ciphertext_path, plaintext_path, signed) + decrypt_with_key(client_key, ciphertext_path, plaintext_path, signed, tensor) } Some(("add", add_mtches)) => { let server_key_path = add_mtches.get_one::("server-key").unwrap(); let cts_path = add_mtches.get_many::("ciphertexts").unwrap(); let output_ct_path = add_mtches.get_one::("output-ciphertext").unwrap(); let signed = add_mtches.get_flag("signed"); + let tensor = add_mtches.get_flag("tensor"); set_server_key_from_file(server_key_path); - sum(cts_path.collect(), output_ct_path, signed) + sum(cts_path.collect(), output_ct_path, signed, tensor) } Some(("keygen", keygen_mtches)) => { let client_key_path = keygen_mtches.get_one::("client-key").unwrap(); let server_key_path = keygen_mtches.get_one::("server-key").unwrap(); - let output_lwe_path = keygen_mtches.get_one::("output-lwe-sk").unwrap(); + let output_lwe_path = keygen_mtches.get_one::("output-lwe-sk"); // we keygen based on an initial secret key if provided, otherwise we keygen from scratch if let Some(lwe_sk_path) = keygen_mtches.get_one::("lwe-sk") { let client_key = keygen_from_lwe(lwe_sk_path); let server_key = client_key.generate_server_key(); - let lwe_secret_key = unsafe_load(lwe_sk_path); + // we already have the initial lwe-sk, so no need to write it to output-lwe-sk write_keys( client_key_path, server_key_path, - output_lwe_path, + &String::new(), Some(client_key), Some(server_key), - Some(lwe_secret_key), + None, ) } else { - keygen(client_key_path, server_key_path, output_lwe_path) + keygen(client_key_path, server_key_path, output_lwe_path.unwrap()) } } _ => unreachable!(), // If all subcommands are defined above, anything else is unreachable