Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(hlapi): separate hlapi compressed list from integer #1902

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,127 @@
use std::convert::Infallible;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use tfhe_versionable::{
Unversionize, UnversionizeError, Upgrade, Version, Versionize, VersionizeOwned,
VersionsDispatch,
};

use crate::{CompressedCiphertextList, Tag};
use crate::core_crypto::commons::math::random::{Deserialize, Serialize};
use crate::high_level_api::compressed_ciphertext_list::InnerCompressedCiphertextList;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList;
use crate::{CompressedCiphertextList, SerializedKind, Tag};

#[derive(Clone, Serialize, Deserialize)]
pub(crate) enum InnerCompressedCiphertextListV0 {
Cpu(crate::integer::ciphertext::CompressedCiphertextList),
#[cfg(feature = "gpu")]
Cuda(CudaCompressedCiphertextList),
}

#[derive(serde::Serialize)]
pub struct InnerCompressedCiphertextListV0Version<'vers>(
<InnerCompressedCiphertextListV0 as Versionize>::Versioned<'vers>,
);

impl<'vers> From<&'vers InnerCompressedCiphertextListV0>
for InnerCompressedCiphertextListV0Version<'vers>
{
fn from(value: &'vers InnerCompressedCiphertextListV0) -> Self {
Self(value.versionize())
}
}

#[derive(::serde::Serialize, ::serde::Deserialize)]
pub struct InnerCompressedCiphertextListV0Owned(
<InnerCompressedCiphertextListV0 as VersionizeOwned>::VersionedOwned,
);

impl From<InnerCompressedCiphertextListV0> for InnerCompressedCiphertextListV0Owned {
fn from(value: InnerCompressedCiphertextListV0) -> Self {
Self(value.versionize_owned())
}
}

impl TryFrom<InnerCompressedCiphertextListV0Owned> for InnerCompressedCiphertextListV0 {
type Error = UnversionizeError;

fn try_from(value: InnerCompressedCiphertextListV0Owned) -> Result<Self, Self::Error> {
Self::unversionize(value.0)
}
}

impl Version for InnerCompressedCiphertextListV0 {
type Ref<'vers>
= InnerCompressedCiphertextListV0Version<'vers>
where
Self: 'vers;

type Owned = InnerCompressedCiphertextListV0Owned;
}

impl Versionize for InnerCompressedCiphertextListV0 {
type Versioned<'vers> =
<crate::integer::ciphertext::CompressedCiphertextList as VersionizeOwned>::VersionedOwned;

fn versionize(&self) -> Self::Versioned<'_> {
match self {
Self::Cpu(inner) => inner.clone().versionize_owned(),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
let cpu_data = with_thread_local_cuda_streams(|streams| {
inner.to_compressed_ciphertext_list(streams)
});
cpu_data.versionize_owned()
}
}
}
}

impl VersionizeOwned for InnerCompressedCiphertextListV0 {
type VersionedOwned =
<crate::integer::ciphertext::CompressedCiphertextList as VersionizeOwned>::VersionedOwned;

fn versionize_owned(self) -> Self::VersionedOwned {
match self {
Self::Cpu(inner) => inner.versionize_owned(),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
let cpu_data = with_thread_local_cuda_streams(|streams| {
inner.to_compressed_ciphertext_list(streams)
});
cpu_data.versionize_owned()
}
}
}
}

impl Unversionize for InnerCompressedCiphertextListV0 {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
Ok(Self::Cpu(
crate::integer::ciphertext::CompressedCiphertextList::unversionize(versioned)?,
))
}
}

impl Upgrade<InnerCompressedCiphertextList> for InnerCompressedCiphertextListV0 {
type Error = Infallible;

fn upgrade(self) -> Result<InnerCompressedCiphertextList, Self::Error> {
Ok(match self {
Self::Cpu(cpu) => InnerCompressedCiphertextList::Cpu(cpu.packed_list),
#[cfg(feature = "gpu")]
Self::Cuda(cuda) => InnerCompressedCiphertextList::Cuda(cuda.packed_list),
})
}
}

#[derive(VersionsDispatch)]
#[allow(unused)]
pub(crate) enum InnerCompressedCiphertextListVersions {
V0(InnerCompressedCiphertextListV0),
V1(InnerCompressedCiphertextList),
}

#[derive(Version)]
pub struct CompressedCiphertextListV0(crate::integer::ciphertext::CompressedCiphertextList);
Expand All @@ -23,12 +143,43 @@ pub struct CompressedCiphertextListV1 {
tag: Tag,
}

impl Upgrade<CompressedCiphertextList> for CompressedCiphertextListV1 {
impl Upgrade<CompressedCiphertextListV2> for CompressedCiphertextListV1 {
type Error = Infallible;

fn upgrade(self) -> Result<CompressedCiphertextListV2, Self::Error> {
Ok(CompressedCiphertextListV2 {
inner: InnerCompressedCiphertextListV0::Cpu(self.inner),
tag: self.tag,
})
}
}

#[derive(Version)]
pub struct CompressedCiphertextListV2 {
inner: InnerCompressedCiphertextListV0,
tag: Tag,
}

impl Upgrade<CompressedCiphertextList> for CompressedCiphertextListV2 {
type Error = Infallible;

fn upgrade(self) -> Result<CompressedCiphertextList, Self::Error> {
let (block_kinds, msg_modulus) = match &self.inner {
InnerCompressedCiphertextListV0::Cpu(inner) => {
(&inner.info, inner.packed_list.message_modulus)
}
#[cfg(feature = "gpu")]
InnerCompressedCiphertextListV0::Cuda(inner) => {
(&inner.info, inner.packed_list.message_modulus)
}
};
let info = block_kinds
.iter()
.map(|kind| SerializedKind::from_data_kind(*kind, msg_modulus))
.collect();
Ok(CompressedCiphertextList {
inner: crate::high_level_api::compressed_ciphertext_list::InnerCompressedCiphertextList::Cpu(self.inner),
inner: self.inner.upgrade()?,
info,
tag: self.tag,
})
}
Expand All @@ -38,5 +189,6 @@ impl Upgrade<CompressedCiphertextList> for CompressedCiphertextListV1 {
pub enum CompressedCiphertextListVersions {
V0(CompressedCiphertextListV0),
V1(CompressedCiphertextListV1),
V2(CompressedCiphertextList),
V2(CompressedCiphertextListV2),
V3(CompressedCiphertextList),
}
Loading
Loading