From 6eb8cb09a8e0cd7f6c0add58de81c5e68ffafc23 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 13 Dec 2024 16:37:35 +0100 Subject: [PATCH] refactor(integer): factorize expansion code --- tfhe/src/integer/ciphertext/compact_list.rs | 301 ++++++++------------ 1 file changed, 113 insertions(+), 188 deletions(-) diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index 82272c68ad..9a66d119dc 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -9,6 +9,7 @@ use crate::integer::encryption::{create_clear_radix_block_iterator, KnowsMessage use crate::integer::parameters::CompactCiphertextListConformanceParams; pub use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::integer::{CompactPublicKey, ServerKey}; +use crate::shortint::ciphertext::Degree; #[cfg(feature = "zk-pok")] use crate::shortint::ciphertext::ProvenCompactCiphertextListConformanceParams; use crate::shortint::parameters::{ @@ -545,6 +546,85 @@ impl IntegerUnpackingToShortintCastingModeHelper { } } +type ExpansionHelperCallback<'a, ListType> = &'a dyn Fn( + &ListType, + ShortintCompactCiphertextListCastingMode<'_>, +) -> Result, crate::Error>; + +fn expansion_helper( + expansion_mode: IntegerCompactCiphertextListExpansionMode<'_>, + ct_list: &ListType, + list_degree: Degree, + info: &[DataKind], + is_packed: bool, + list_expansion_fn: ExpansionHelperCallback<'_, ListType>, +) -> Result, crate::Error> { + if is_packed + && matches!( + expansion_mode, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking + ) + { + return Err(crate::Error::new(String::from( + WRONG_UNPACKING_MODE_ERR_MSG, + ))); + } + + match expansion_mode { + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( + key_switching_key_view, + ) => { + let dest_sks = &key_switching_key_view.key.dest_server_key; + let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); + let functions = if is_packed { + function_helper.generate_unpack_and_sanitize_functions(info) + } else { + function_helper.generate_sanitize_without_unpacking_functions(info) + }; + + list_expansion_fn( + ct_list, + ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key: key_switching_key_view.key, + functions: Some(functions.as_slice()), + }, + ) + } + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { + let expanded_blocks = + list_expansion_fn(ct_list, ShortintCompactCiphertextListCastingMode::NoCasting)?; + + if is_packed { + let mut conformance_params = sks.key.conformance_params(); + conformance_params.degree = list_degree; + + for ct in expanded_blocks.iter() { + if !ct.is_conformant(&conformance_params) { + return Err(crate::Error::new( + "This compact list is not conformant with the given server key" + .to_string(), + )); + } + } + + Ok(unpack_and_sanitize_message_and_carries( + expanded_blocks, + sks, + info, + )) + } else { + Ok(sanitize_boolean_blocks(expanded_blocks, sks, info)) + } + } + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => { + list_expansion_fn(ct_list, ShortintCompactCiphertextListCastingMode::NoCasting) + } + } +} + impl CompactCiphertextList { pub fn is_packed(&self) -> bool { self.ct_list.degree.get() @@ -694,66 +774,14 @@ impl CompactCiphertextList { ) -> crate::Result { let is_packed = self.is_packed(); - if is_packed - && matches!( - expansion_mode, - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking - ) - { - return Err(crate::Error::new(String::from( - WRONG_UNPACKING_MODE_ERR_MSG, - ))); - } - - let expanded_blocks = match expansion_mode { - IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( - key_switching_key_view, - ) => { - let dest_sks = &key_switching_key_view.key.dest_server_key; - let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - let functions = if is_packed { - function_helper.generate_unpack_and_sanitize_functions(&self.info) - } else { - function_helper.generate_sanitize_without_unpacking_functions(&self.info) - }; - - self.ct_list - .expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary { - casting_key: key_switching_key_view.key, - functions: Some(functions.as_slice()), - })? - } - IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { - let expanded_blocks = self - .ct_list - .expand(ShortintCompactCiphertextListCastingMode::NoCasting)?; - - if is_packed { - let degree = self.ct_list.degree; - let mut conformance_params = sks.key.conformance_params(); - conformance_params.degree = degree; - - for ct in expanded_blocks.iter() { - if !ct.is_conformant(&conformance_params) { - return Err(crate::Error::new( - "This compact list is not conformant with the given server key" - .to_string(), - )); - } - } - - unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) - } else { - sanitize_boolean_blocks(expanded_blocks, sks, &self.info) - } - } - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self - .ct_list - .expand(ShortintCompactCiphertextListCastingMode::NoCasting)?, - }; + let expanded_blocks = expansion_helper( + expansion_mode, + &self.ct_list, + self.ct_list.degree, + &self.info, + is_packed, + &crate::shortint::ciphertext::CompactCiphertextList::expand, + )?; Ok(CompactCiphertextListExpander::new( expanded_blocks, @@ -822,78 +850,27 @@ impl ProvenCompactCiphertextList { ) -> crate::Result { let is_packed = self.is_packed(); - if is_packed - && matches!( + // Type annotation needed rust is not able to coerce the type on its own, also forces us to + // use a trait object + let callback: ExpansionHelperCallback<'_, _> = &|ct_list, expansion_mode| { + crate::shortint::ciphertext::ProvenCompactCiphertextList::verify_and_expand( + ct_list, + crs, + &public_key.key, + metadata, expansion_mode, - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking ) - { - return Err(crate::Error::new(String::from( - WRONG_UNPACKING_MODE_ERR_MSG, - ))); - } - - let expanded_blocks = match expansion_mode { - IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( - key_switching_key_view, - ) => { - let dest_sks = &key_switching_key_view.key.dest_server_key; - let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - let functions = if is_packed { - function_helper.generate_unpack_and_sanitize_functions(&self.info) - } else { - function_helper.generate_sanitize_without_unpacking_functions(&self.info) - }; - self.ct_list.verify_and_expand( - crs, - &public_key.key, - metadata, - ShortintCompactCiphertextListCastingMode::CastIfNecessary { - casting_key: key_switching_key_view.key, - functions: Some(functions.as_slice()), - }, - )? - } - IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { - let expanded_blocks = self.ct_list.verify_and_expand( - crs, - &public_key.key, - metadata, - ShortintCompactCiphertextListCastingMode::NoCasting, - )?; - - if is_packed { - let degree = self.ct_list.proved_lists[0].0.degree; - let mut conformance_params = sks.key.conformance_params(); - conformance_params.degree = degree; - - for ct in expanded_blocks.iter() { - if !ct.is_conformant(&conformance_params) { - return Err(crate::Error::new( - "This compact list is not conformant with the given server key" - .to_string(), - )); - } - } - - unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) - } else { - sanitize_boolean_blocks(expanded_blocks, sks, &self.info) - } - } - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => { - self.ct_list.verify_and_expand( - crs, - &public_key.key, - metadata, - ShortintCompactCiphertextListCastingMode::NoCasting, - )? - } }; + let expanded_blocks = expansion_helper( + expansion_mode, + &self.ct_list, + self.ct_list.proved_lists[0].0.degree, + &self.info, + is_packed, + callback, + )?; + Ok(CompactCiphertextListExpander::new( expanded_blocks, self.info.clone(), @@ -910,66 +887,14 @@ impl ProvenCompactCiphertextList { ) -> crate::Result { let is_packed = self.is_packed(); - if is_packed - && matches!( - expansion_mode, - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking - ) - { - return Err(crate::Error::new(String::from( - WRONG_UNPACKING_MODE_ERR_MSG, - ))); - } - - let expanded_blocks = match expansion_mode { - IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( - key_switching_key_view, - ) => { - let dest_sks = &key_switching_key_view.key.dest_server_key; - let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - let functions = if is_packed { - function_helper.generate_unpack_and_sanitize_functions(&self.info) - } else { - function_helper.generate_sanitize_without_unpacking_functions(&self.info) - }; - self.ct_list.expand_without_verification( - ShortintCompactCiphertextListCastingMode::CastIfNecessary { - casting_key: key_switching_key_view.key, - functions: Some(functions.as_slice()), - }, - )? - } - IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { - let expanded_blocks = self.ct_list.expand_without_verification( - ShortintCompactCiphertextListCastingMode::NoCasting, - )?; - - if is_packed { - let degree = self.ct_list.proved_lists[0].0.degree; - let mut conformance_params = sks.key.conformance_params(); - conformance_params.degree = degree; - - for ct in expanded_blocks.iter() { - if !ct.is_conformant(&conformance_params) { - return Err(crate::Error::new( - "This compact list is not conformant with the given server key" - .to_string(), - )); - } - } - - unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) - } else { - sanitize_boolean_blocks(expanded_blocks, sks, &self.info) - } - } - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self - .ct_list - .expand_without_verification(ShortintCompactCiphertextListCastingMode::NoCasting)?, - }; + let expanded_blocks = expansion_helper( + expansion_mode, + &self.ct_list, + self.ct_list.proved_lists[0].0.degree, + &self.info, + is_packed, + &crate::shortint::ciphertext::ProvenCompactCiphertextList::expand_without_verification, + )?; Ok(CompactCiphertextListExpander::new( expanded_blocks,