Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(gpu): port fix to compression encoding #1903

Merged
merged 2 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ template <typename Torus> struct int_decompression {
Torus *tmp_extracted_lwe;
uint32_t *tmp_indexes_array;

int_radix_lut<Torus> *carry_extract_lut;
int_radix_lut<Torus> *decompression_rescale_lut;

int_decompression(cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, int_radix_params encryption_params,
Expand All @@ -83,7 +83,7 @@ template <typename Torus> struct int_decompression {
Torus lwe_accumulator_size = (compression_params.glwe_dimension *
compression_params.polynomial_size +
1);
carry_extract_lut = new int_radix_lut<Torus>(
decompression_rescale_lut = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, encryption_params, 1,
num_radix_blocks, allocate_gpu_memory);

Expand All @@ -96,18 +96,28 @@ template <typename Torus> struct int_decompression {
num_radix_blocks * lwe_accumulator_size * sizeof(Torus), streams[0],
gpu_indexes[0]);

// Carry extract LUT
auto carry_extract_f = [encryption_params](Torus x) -> Torus {
return x / encryption_params.message_modulus;
// Rescale is done using an identity LUT
// Here we do not divide by message_modulus
// Example: in the 2_2 case we are mapping a 2 bits message onto a 4 bits
// space, we want to keep the original 2 bits value in the 4 bits space,
// so we apply the identity and the encoding will rescale it for us.
auto decompression_rescale_f = [encryption_params](Torus x) -> Torus {
return x;
};

generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], carry_extract_lut->get_lut(0, 0),
auto effective_compression_message_modulus =
encryption_params.carry_modulus;
auto effective_compression_carry_modulus = 1;

generate_device_accumulator_with_encoding<Torus>(
streams[0], gpu_indexes[0], decompression_rescale_lut->get_lut(0, 0),
encryption_params.glwe_dimension, encryption_params.polynomial_size,
effective_compression_message_modulus,
effective_compression_carry_modulus,
encryption_params.message_modulus, encryption_params.carry_modulus,
carry_extract_f);
decompression_rescale_f);

carry_extract_lut->broadcast_lut(streams, gpu_indexes, 0);
decompression_rescale_lut->broadcast_lut(streams, gpu_indexes, 0);
}
}
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
Expand All @@ -116,8 +126,8 @@ template <typename Torus> struct int_decompression {
cuda_drop_async(tmp_extracted_lwe, streams[0], gpu_indexes[0]);
cuda_drop_async(tmp_indexes_array, streams[0], gpu_indexes[0]);

carry_extract_lut->release(streams, gpu_indexes, gpu_count);
delete carry_extract_lut;
decompression_rescale_lut->release(streams, gpu_indexes, gpu_count);
delete decompression_rescale_lut;
}
};
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ void generate_device_accumulator_bivariate_with_factor(
cudaStream_t stream, uint32_t gpu_index, Torus *acc_bivariate,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t message_modulus,
uint32_t carry_modulus, std::function<Torus(Torus, Torus)> f, int factor);

template <typename Torus>
void generate_device_accumulator_with_encoding(
cudaStream_t stream, uint32_t gpu_index, Torus *acc,
uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t input_message_modulus, uint32_t input_carry_modulus,
uint32_t output_message_modulus, uint32_t output_carry_modulus,
std::function<Torus(Torus)> f);

/*
* generate univariate accumulator (lut) for device pointer
* stream - cuda stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ __host__ void host_integer_decompress(
/// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
/// dimension to a big LWE dimension
auto encryption_params = h_mem_ptr->encryption_params;
auto lut = h_mem_ptr->carry_extract_lut;
auto lut = h_mem_ptr->decompression_rescale_lut;
auto active_gpu_count = get_active_gpu_count(num_radix_blocks, gpu_count);
if (active_gpu_count == 1) {
execute_pbs_async<Torus>(
Expand Down
110 changes: 80 additions & 30 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -627,26 +627,48 @@ void rotate_left(Torus *buffer, int mid, uint32_t array_length) {
std::rotate(buffer, buffer + mid, buffer + array_length);
}

/// Caller needs to ensure that the operation applied is coherent from an
/// encoding perspective.
///
/// For example:
///
/// Input encoding has 2 bits and output encoding has 4 bits, applying the
/// identity lut would map the following:
///
/// 0|00|xx -> 0|00|00
/// 0|01|xx -> 0|00|01
/// 0|10|xx -> 0|00|10
/// 0|11|xx -> 0|00|11
///
/// The reason is the identity function is computed in the input space but the
/// scaling is done in the output space, as there are more bits in the output
/// space, the delta is smaller hence the apparent "division" happening.
template <typename Torus>
void generate_lookup_table(Torus *acc, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t message_modulus,
uint32_t carry_modulus,
std::function<Torus(Torus)> f) {

uint32_t modulus_sup = message_modulus * carry_modulus;
uint32_t box_size = polynomial_size / modulus_sup;
Torus delta = (1ul << 63) / modulus_sup;
void generate_lookup_table_with_encoding(Torus *acc, uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t input_message_modulus,
uint32_t input_carry_modulus,
uint32_t output_message_modulus,
uint32_t output_carry_modulus,
std::function<Torus(Torus)> f) {

uint32_t input_modulus_sup = input_message_modulus * input_carry_modulus;
uint32_t output_modulus_sup = output_message_modulus * output_carry_modulus;
uint32_t box_size = polynomial_size / input_modulus_sup;
auto nbits = sizeof(Torus) * 8;
Torus output_delta =
(static_cast<Torus>(1) << (nbits - 1)) / output_modulus_sup;

memset(acc, 0, glwe_dimension * polynomial_size * sizeof(Torus));

auto body = &acc[glwe_dimension * polynomial_size];

// This accumulator extracts the carry bits
for (int i = 0; i < modulus_sup; i++) {
for (int i = 0; i < input_modulus_sup; i++) {
int index = i * box_size;
for (int j = index; j < index + box_size; j++) {
auto f_eval = f(i);
body[j] = f_eval * delta;
body[j] = f_eval * output_delta;
}
}

Expand All @@ -660,6 +682,16 @@ void generate_lookup_table(Torus *acc, uint32_t glwe_dimension,
rotate_left<Torus>(body, half_box_size, polynomial_size);
}

template <typename Torus>
void generate_lookup_table(Torus *acc, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t message_modulus,
uint32_t carry_modulus,
std::function<Torus(Torus)> f) {
generate_lookup_table_with_encoding(acc, glwe_dimension, polynomial_size,
message_modulus, carry_modulus,
message_modulus, carry_modulus, f);
}

template <typename Torus>
void generate_many_lookup_table(
Torus *acc, uint32_t glwe_dimension, uint32_t polynomial_size,
Expand All @@ -668,7 +700,8 @@ void generate_many_lookup_table(

uint32_t modulus_sup = message_modulus * carry_modulus;
uint32_t box_size = polynomial_size / modulus_sup;
Torus delta = (1ul << 63) / modulus_sup;
auto nbits = sizeof(Torus) * 8;
Torus delta = (static_cast<Torus>(1) << (nbits - 1)) / modulus_sup;

memset(acc, 0, glwe_dimension * polynomial_size * sizeof(Torus));

Expand Down Expand Up @@ -803,28 +836,22 @@ void generate_device_accumulator_bivariate_with_factor(
free(h_lut);
}

/*
* generate accumulator for device pointer
* v_stream - cuda stream
* acc - device pointer for accumulator
* ...
* f - evaluating function with one Torus input
*/
template <typename Torus>
void generate_device_accumulator(cudaStream_t stream, uint32_t gpu_index,
Torus *acc, uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t message_modulus,
uint32_t carry_modulus,
std::function<Torus(Torus)> f) {
void generate_device_accumulator_with_encoding(
cudaStream_t stream, uint32_t gpu_index, Torus *acc,
uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t input_message_modulus, uint32_t input_carry_modulus,
uint32_t output_message_modulus, uint32_t output_carry_modulus,
std::function<Torus(Torus)> f) {

// host lut
Torus *h_lut =
(Torus *)malloc((glwe_dimension + 1) * polynomial_size * sizeof(Torus));

// fill accumulator
generate_lookup_table<Torus>(h_lut, glwe_dimension, polynomial_size,
message_modulus, carry_modulus, f);
generate_lookup_table_with_encoding<Torus>(
h_lut, glwe_dimension, polynomial_size, input_message_modulus,
input_carry_modulus, output_message_modulus, output_carry_modulus, f);

// copy host lut and lut_indexes_vec to device
cuda_memcpy_async_to_gpu(
Expand All @@ -835,6 +862,26 @@ void generate_device_accumulator(cudaStream_t stream, uint32_t gpu_index,
free(h_lut);
}

/*
* generate accumulator for device pointer
* v_stream - cuda stream
* acc - device pointer for accumulator
* ...
* f - evaluating function with one Torus input
*/
template <typename Torus>
void generate_device_accumulator(cudaStream_t stream, uint32_t gpu_index,
Torus *acc, uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t message_modulus,
uint32_t carry_modulus,
std::function<Torus(Torus)> f) {

generate_device_accumulator_with_encoding(
stream, gpu_index, acc, glwe_dimension, polynomial_size, message_modulus,
carry_modulus, message_modulus, carry_modulus, f);
}

/*
* generate many lut accumulator for device pointer
* v_stream - cuda stream
Expand Down Expand Up @@ -1055,7 +1102,8 @@ void host_compute_propagation_simulators_and_group_carries(
message_modulus, carry_modulus);

uint32_t modulus_sup = message_modulus * carry_modulus;
Torus delta = (1ull << 63) / modulus_sup;
auto nbits = sizeof(Torus) * 8;
Torus delta = (static_cast<Torus>(1) << (nbits - 1)) / modulus_sup;
auto simulators = mem->simulators;
auto grouping_pgns = mem->grouping_pgns;
host_radix_split_simulators_and_grouping_pgns<Torus>(
Expand Down Expand Up @@ -1382,8 +1430,8 @@ __host__ void
create_trivial_radix(cudaStream_t stream, uint32_t gpu_index,
Torus *lwe_array_out, Torus const *scalar_array,
uint32_t lwe_dimension, uint32_t num_radix_blocks,
uint32_t num_scalar_blocks, uint64_t message_modulus,
uint64_t carry_modulus) {
uint32_t num_scalar_blocks, Torus message_modulus,
Torus carry_modulus) {

cudaSetDevice(gpu_index);
size_t radix_size = (lwe_dimension + 1) * num_radix_blocks;
Expand All @@ -1403,7 +1451,9 @@ create_trivial_radix(cudaStream_t stream, uint32_t gpu_index,
// Value of the shift we multiply our messages by
// If message_modulus and carry_modulus are always powers of 2 we can simplify
// this
uint64_t delta = ((uint64_t)1 << 63) / (message_modulus * carry_modulus);
auto nbits = sizeof(Torus) * 8;
Torus delta = (static_cast<Torus>(1) << (nbits - 1)) /
(message_modulus * carry_modulus);

device_create_trivial_radix<Torus><<<grid, thds, 0, stream>>>(
lwe_array_out, scalar_array, num_scalar_blocks, lwe_dimension, delta);
Expand Down
16 changes: 9 additions & 7 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,15 @@ impl CudaCompressedCiphertextList {
let end_block_index = start_block_index + current_info.num_blocks() - 1;

Some((
decomp_key.unpack(
&self.packed_list,
current_info,
start_block_index,
end_block_index,
streams,
),
decomp_key
.unpack(
&self.packed_list,
current_info,
start_block_index,
end_block_index,
streams,
)
.unwrap(),
current_info,
))
}
Expand Down
22 changes: 19 additions & 3 deletions tfhe/src/integer/gpu/list_compression/server_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,23 @@ impl CudaDecompressionKey {
start_block_index: usize,
end_block_index: usize,
streams: &CudaStreams,
) -> CudaRadixCiphertext {
) -> Result<CudaRadixCiphertext, crate::Error> {
if self.message_modulus.0 != self.carry_modulus.0 {
return Err(crate::Error::new(format!(
"Tried to unpack values from a list where message modulus \
({:?}) is != carry modulus ({:?}), this is not supported.",
self.message_modulus, self.carry_modulus,
)));
}

if end_block_index >= packed_list.bodies_count {
return Err(crate::Error::new(format!(
"Tried getting index {end_block_index} for CompressedCiphertextList \
with {} elements, out of bound access.",
packed_list.bodies_count
)));
}

let indexes_array = (start_block_index..=end_block_index)
.map(|x| x as u32)
.collect_vec();
Expand Down Expand Up @@ -264,10 +280,10 @@ impl CudaDecompressionKey {

let blocks = vec![first_block_info; output_lwe.0.lwe_ciphertext_count.0];

CudaRadixCiphertext {
Ok(CudaRadixCiphertext {
d_blocks: output_lwe,
info: CudaRadixCiphertextInfo { blocks },
}
})
}
CudaBootstrappingKey::MultiBit(_) => {
panic! {"Compression is currently not compatible with Multi-Bit PBS"}
Expand Down
Loading