diff --git a/src/core/backend/avx512/cm31.rs b/src/core/backend/avx512/cm31.rs index e76fbddbb..b2a060a65 100644 --- a/src/core/backend/avx512/cm31.rs +++ b/src/core/backend/avx512/cm31.rs @@ -1,9 +1,9 @@ -use std::ops::{Add, Mul, MulAssign, Sub}; +use std::ops::{Add, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; use super::m31::{PackedBaseField, K_BLOCK_SIZE}; -use crate::core::fields::cm31::{CM31, P2}; +use crate::core::fields::cm31::CM31; use crate::core::fields::FieldExpOps; /// AVX implementation for the complex extension field of M31. @@ -69,10 +69,17 @@ impl MulAssign for PackedCM31 { *self = *self * rhs; } } +impl Neg for PackedCM31 { + type Output = Self; + fn neg(self) -> Self::Output { + Self([-self.a(), -self.b()]) + } +} impl FieldExpOps for PackedCM31 { fn inverse(&self) -> Self { assert!(!self.is_zero(), "0 has no inverse"); - self.pow((P2 - 2) as u128) + // 1 / (a + bi) = (a - bi) / (a^2 + b^2). + Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse() } } diff --git a/src/core/backend/avx512/m31.rs b/src/core/backend/avx512/m31.rs index af07610a6..fa965c4f9 100644 --- a/src/core/backend/avx512/m31.rs +++ b/src/core/backend/avx512/m31.rs @@ -254,6 +254,8 @@ impl One for PackedBaseField { impl FieldExpOps for PackedBaseField { fn inverse(&self) -> Self { + // TODO(andrew): Use a better multiplication tree. Also for other constant powers in the + // code. assert!(!self.is_zero(), "0 has no inverse"); self.pow((P - 2) as u128) } diff --git a/src/core/backend/avx512/qm31.rs b/src/core/backend/avx512/qm31.rs index e0d06ffd4..f63d8df61 100644 --- a/src/core/backend/avx512/qm31.rs +++ b/src/core/backend/avx512/qm31.rs @@ -6,7 +6,7 @@ use num_traits::{One, Zero}; use super::cm31::PackedCM31; use super::m31::K_BLOCK_SIZE; use super::PackedBaseField; -use crate::core::fields::qm31::{P4, QM31}; +use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; /// AVX implementation for an extension of CM31. @@ -125,10 +125,13 @@ impl MulAssign for PackedSecureField { } impl FieldExpOps for PackedSecureField { fn inverse(&self) -> Self { - // TODO(andrew): Use a better multiplication tree. Also for other constant powers in the - // code. assert!(!self.is_zero(), "0 has no inverse"); - self.pow(P4 - 2) + // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). + let b2 = self.b().square(); + let ib2 = PackedCM31([-b2.b(), b2.a()]); + let denom = self.a().square() - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + Self([self.a() * denom_inverse, -self.b() * denom_inverse]) } } diff --git a/src/core/backend/cpu/mod.rs b/src/core/backend/cpu/mod.rs index cee638a38..5edcfc469 100644 --- a/src/core/backend/cpu/mod.rs +++ b/src/core/backend/cpu/mod.rs @@ -57,6 +57,7 @@ mod tests { use rand::rngs::SmallRng; use crate::core::backend::{CPUBackend, Column, FieldOps}; + use crate::core::fields::m31::P; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; @@ -66,10 +67,10 @@ mod tests { let column: Vec = (0..16) .map(|_| { QM31::from_u32_unchecked( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), + rng.gen::() % P, + rng.gen::() % P, + rng.gen::() % P, + rng.gen::() % P, ) }) .collect(); @@ -88,10 +89,10 @@ mod tests { let column: Vec = (0..16) .map(|_| { QM31::from_u32_unchecked( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), + rng.gen::() % P, + rng.gen::() % P, + rng.gen::() % P, + rng.gen::() % P, ) }) .collect(); diff --git a/src/core/fields/cm31.rs b/src/core/fields/cm31.rs index c585855b6..b1f8e11fc 100644 --- a/src/core/fields/cm31.rs +++ b/src/core/fields/cm31.rs @@ -52,6 +52,14 @@ impl Mul for CM31 { } } +impl FieldExpOps for CM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // 1 / (a + bi) = (a - bi) / (a^2 + b^2). + Self(self.0, -self.1) * (self.0.square() + self.1.square()).inverse() + } +} + #[cfg(test)] #[macro_export] macro_rules! cm31 { @@ -66,9 +74,16 @@ mod tests { use super::CM31; use crate::core::fields::m31::P; - use crate::core::fields::IntoSlice; + use crate::core::fields::{FieldExpOps, IntoSlice}; use crate::m31; + #[test] + fn test_inverse() { + let cm = cm31!(1, 2); + let cm_inv = cm.inverse(); + assert_eq!(cm * cm_inv, cm31!(1, 0)); + } + #[test] fn test_ops() { let cm0 = cm31!(1, 2); diff --git a/src/core/fields/m31.rs b/src/core/fields/m31.rs index 20fb92424..37809cbc4 100644 --- a/src/core/fields/m31.rs +++ b/src/core/fields/m31.rs @@ -78,6 +78,13 @@ impl Mul for M31 { } } +impl FieldExpOps for M31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + self.pow(P as u128 - 2) + } +} + impl ComplexConjugate for M31 { fn complex_conjugate(&self) -> Self { *self diff --git a/src/core/fields/mod.rs b/src/core/fields/mod.rs index 2ab0c9c37..b35d49f80 100644 --- a/src/core/fields/mod.rs +++ b/src/core/fields/mod.rs @@ -231,13 +231,6 @@ macro_rules! impl_field { } } - impl FieldExpOps for $field_name { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - self.pow(($field_size - 2) as u128) - } - } - impl Product for $field_name { fn product(mut iter: I) -> Self where diff --git a/src/core/fields/qm31.rs b/src/core/fields/qm31.rs index 3862ba321..9b739b15f 100644 --- a/src/core/fields/qm31.rs +++ b/src/core/fields/qm31.rs @@ -66,6 +66,18 @@ impl Mul for QM31 { } } +impl FieldExpOps for QM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). + let b2 = self.1.square(); + let ib2 = CM31(-b2.1, b2.0); + let denom = self.0.square() - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + Self(self.0 * denom_inverse, -self.1 * denom_inverse) + } +} + #[cfg(test)] #[macro_export] macro_rules! qm31 { @@ -77,13 +89,21 @@ macro_rules! qm31 { #[cfg(test)] mod tests { + use num_traits::One; use rand::Rng; use super::QM31; use crate::core::fields::m31::P; - use crate::core::fields::IntoSlice; + use crate::core::fields::{FieldExpOps, IntoSlice}; use crate::m31; + #[test] + fn test_inverse() { + let qm = qm31!(1, 2, 3, 4); + let qm_inv = qm.inverse(); + assert_eq!(qm * qm_inv, QM31::one()); + } + #[test] fn test_ops() { let qm0 = qm31!(1, 2, 3, 4);