From 54fa6ca26099bdf1454ffbe5a4267d737683b6e6 Mon Sep 17 00:00:00 2001 From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com> Date: Sun, 3 Mar 2024 11:14:05 +0200 Subject: [PATCH] Packed QM31 (#416) --- src/core/backend/avx512/mod.rs | 1 + src/core/backend/avx512/qm31.rs | 112 ++++++++++++++++++++++++++++++++ src/core/fields/qm31.rs | 2 +- 3 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 src/core/backend/avx512/qm31.rs diff --git a/src/core/backend/avx512/mod.rs b/src/core/backend/avx512/mod.rs index f553ebd7c..a84fb28f0 100644 --- a/src/core/backend/avx512/mod.rs +++ b/src/core/backend/avx512/mod.rs @@ -1,6 +1,7 @@ pub mod bit_reverse; pub mod cm31; pub mod m31; +pub mod qm31; use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable}; use num_traits::Zero; diff --git a/src/core/backend/avx512/qm31.rs b/src/core/backend/avx512/qm31.rs new file mode 100644 index 000000000..7cbb9b196 --- /dev/null +++ b/src/core/backend/avx512/qm31.rs @@ -0,0 +1,112 @@ +use std::ops::{Add, Mul, Sub}; + +use super::cm31::PackedCM31; +use super::m31::K_BLOCK_SIZE; +use crate::core::fields::qm31::QM31; + +/// AVX implementation for an extension of CM31. +/// See [crate::core::fields::qm31::QM31] for more information. +#[derive(Copy, Clone)] +pub struct PackedQM31(pub [PackedCM31; 2]); +impl PackedQM31 { + pub fn a(&self) -> PackedCM31 { + self.0[0] + } + pub fn b(&self) -> PackedCM31 { + self.0[1] + } + pub fn to_array(&self) -> [QM31; K_BLOCK_SIZE] { + std::array::from_fn(|i| QM31(self.a().to_array()[i], self.b().to_array()[i])) + } +} +impl Add for PackedQM31 { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + Self([self.a() + rhs.a(), self.b() + rhs.b()]) + } +} +impl Sub for PackedQM31 { + type Output = Self; + fn sub(self, rhs: Self) -> Self::Output { + Self([self.a() - rhs.a(), self.b() - rhs.b()]) + } +} +impl Mul for PackedQM31 { + type Output = Self; + fn mul(self, rhs: Self) -> Self::Output { + // Compute using Karatsuba. + // (a + ub) * (c + ud) = + // (ac + (1+2i)bd) + (ad + bc)u = + // ac + bd + 2ibd + (ad + bc)u. + let ac = self.a() * rhs.a(); + let bd = self.b() * rhs.b(); + let bd2 = bd + bd; + // Computes ac + bd. + let ac_p_bd = ac + bd; + // Computes ad + bc. + let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd; + // ac + bd + 2ibd = + // ac + bd -Im(2bd) + iRe(2bd) + let l = PackedCM31([ac_p_bd.a() - bd2.b(), ac_p_bd.b() + bd2.a()]); + Self([l, ad_p_bc]) + } +} + +#[cfg(test)] +mod tests { + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + + use super::*; + use crate::core::backend::avx512::m31::PackedBaseField; + use crate::core::fields::m31::{M31, P}; + + #[test] + fn test_qm31avx512_basic_ops() { + let rng = &mut StdRng::seed_from_u64(0); + let x = PackedQM31([ + PackedCM31([ + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + ]), + PackedCM31([ + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + ]), + ]); + let y = PackedQM31([ + PackedCM31([ + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + ]), + PackedCM31([ + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + PackedBaseField::from_array(std::array::from_fn(|_| { + M31::from(rng.gen::() % P) + })), + ]), + ]); + let sum = x + y; + let diff = x - y; + let prod = x * y; + for i in 0..16 { + assert_eq!(sum.to_array()[i], x.to_array()[i] + y.to_array()[i]); + assert_eq!(diff.to_array()[i], x.to_array()[i] - y.to_array()[i]); + assert_eq!(prod.to_array()[i], x.to_array()[i] * y.to_array()[i]); + } + } +} diff --git a/src/core/fields/qm31.rs b/src/core/fields/qm31.rs index 1deda61b1..bc3c23399 100644 --- a/src/core/fields/qm31.rs +++ b/src/core/fields/qm31.rs @@ -16,7 +16,7 @@ pub const R: CM31 = CM31::from_u32_unchecked(1, 2); /// Equivalent to CM31\[x\] over (x^2 - 1 - 2i) as the irreducible polynomial. /// Represented as ((a, b), (c, d)) of (a + bi) + (c + di)u. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct QM31(CM31, CM31); +pub struct QM31(pub CM31, pub CM31); pub type SecureField = QM31; impl_field!(QM31, P4);