Skip to content

Commit

Permalink
QM31 (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Jul 7, 2024
1 parent d0e10c2 commit 7af2aaf
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 6 deletions.
1 change: 1 addition & 0 deletions stwo_cairo_verifier/src/fields.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod m31;
pub mod cm31;
pub mod qm31;

pub type BaseField = m31::M31;
20 changes: 15 additions & 5 deletions stwo_cairo_verifier/src/fields/cm31.cairo
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
use super::m31::{M31, m31};
use core::num::traits::{One, Zero};
use super::m31::{M31, m31, M31Trait};

#[derive(Copy, Drop, Debug, PartialEq, Eq)]
pub struct CM31 {
a: M31,
b: M31,
pub a: M31,
pub b: M31,
}

#[generate_trait]
pub impl CM31Impl of CM31Trait {
fn inverse(self: CM31) -> CM31 {
assert_ne!(self, Zero::zero());
let denom_inverse: M31 = (self.a * self.a + self.b * self.b).inverse();
CM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}
}

pub impl CM31Add of core::traits::Add<CM31> {
Expand All @@ -21,7 +31,7 @@ pub impl CM31Mul of core::traits::Mul<CM31> {
CM31 { a: lhs.a * rhs.a - lhs.b * rhs.b, b: lhs.a * rhs.b + lhs.b * rhs.a }
}
}
pub impl CM31Zero of core::num::traits::Zero<CM31> {
pub impl CM31Zero of Zero<CM31> {
fn zero() -> CM31 {
cm31(0, 0)
}
Expand All @@ -32,7 +42,7 @@ pub impl CM31Zero of core::num::traits::Zero<CM31> {
(*self).a.is_non_zero() || (*self).b.is_non_zero()
}
}
pub impl CM31One of core::num::traits::One<CM31> {
pub impl CM31One of One<CM31> {
fn one() -> CM31 {
cm31(1, 0)
}
Expand Down
28 changes: 27 additions & 1 deletion stwo_cairo_verifier/src/fields/m31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ pub impl M31Impl of M31Trait {
let (_, res) = core::integer::u64_safe_divmod(val, P64NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn sqn(v: M31, n: usize) -> M31 {
if n == 0 {
return v;
}
Self::sqn(v * v, n - 1)
}

fn inverse(self: M31) -> M31 {
assert_ne!(self, core::num::traits::Zero::zero());
let t0 = Self::sqn(self, 2) * self;
let t1 = Self::sqn(t0, 1) * t0;
let t2 = Self::sqn(t1, 3) * t0;
let t3 = Self::sqn(t2, 1) * t0;
let t4 = Self::sqn(t3, 8) * t3;
let t5 = Self::sqn(t4, 8) * t3;
Self::sqn(t5, 7) * t2
}
}
pub impl M31Add of core::traits::Add<M31> {
fn add(lhs: M31, rhs: M31) -> M31 {
Expand Down Expand Up @@ -78,7 +97,9 @@ pub fn m31(val: u32) -> M31 {

#[cfg(test)]
mod tests {
use super::{m31, P};
use super::{m31, P, M31, M31Trait};
const POW2_15: u32 = 0b1000000000000000;
const POW2_16: u32 = 0b10000000000000000;

#[test]
fn test_m31() {
Expand All @@ -90,4 +111,9 @@ mod tests {
assert_eq!(m31(0) - m31(1), m31(P - 1));
assert_eq!(m31(0) - m31(P - 1), m31(1));
}

#[test]
fn test_m31_inv() {
assert_eq!(m31(POW2_15).inverse(), m31(POW2_16));
}
}
108 changes: 108 additions & 0 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use super::m31::{M31, m31};
use super::cm31::{CM31, cm31, CM31Trait};
use core::num::traits::zero::Zero;
use core::num::traits::one::One;

pub const R: CM31 = CM31 { a: M31 { inner: 2 }, b: M31 { inner: 1 } };

#[derive(Copy, Drop, Debug, PartialEq, Eq)]
pub struct QM31 {
a: CM31,
b: CM31,
}

#[generate_trait]
impl QM31Impl of QM31Trait {
fn inverse(self: QM31) -> QM31 {
assert_ne!(self, Zero::zero());
let b2 = self.b * self.b;
let ib2 = CM31 { a: -b2.b, b: b2.a };
let denom = self.a * self.a - (b2 + b2 + ib2);
let denom_inverse = denom.inverse();
QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}
}

pub impl QM31Add of core::traits::Add<QM31> {
fn add(lhs: QM31, rhs: QM31) -> QM31 {
QM31 { a: lhs.a + rhs.a, b: lhs.b + rhs.b }
}
}
pub impl QM31Sub of core::traits::Sub<QM31> {
fn sub(lhs: QM31, rhs: QM31) -> QM31 {
QM31 { a: lhs.a - rhs.a, b: lhs.b - rhs.b }
}
}
pub impl QM31Mul of core::traits::Mul<QM31> {
fn mul(lhs: QM31, rhs: QM31) -> QM31 {
// (a + bu) * (c + du) = (ac + rbd) + (ad + bc)u.
QM31 { a: lhs.a * rhs.a + R * lhs.b * rhs.b, b: lhs.a * rhs.b + lhs.b * rhs.a }
}
}
pub impl QM31Zero of Zero<QM31> {
fn zero() -> QM31 {
QM31 { a: Zero::zero(), b: Zero::zero() }
}
fn is_zero(self: @QM31) -> bool {
(*self).a.is_zero() && (*self).b.is_zero()
}
fn is_non_zero(self: @QM31) -> bool {
(*self).a.is_non_zero() || (*self).b.is_non_zero()
}
}
pub impl QM31One of One<QM31> {
fn one() -> QM31 {
QM31 { a: One::one(), b: Zero::zero() }
}
fn is_one(self: @QM31) -> bool {
(*self).a.is_one() && (*self).b.is_zero()
}
fn is_non_one(self: @QM31) -> bool {
(*self).a.is_non_one() || (*self).b.is_non_zero()
}
}
pub impl M31IntoQM31 of core::traits::Into<M31, QM31> {
fn into(self: M31) -> QM31 {
QM31 { a: self.into(), b: Zero::zero() }
}
}
pub impl CM31IntoQM31 of core::traits::Into<CM31, QM31> {
fn into(self: CM31) -> QM31 {
QM31 { a: self, b: Zero::zero() }
}
}
pub impl QM31Neg of Neg<QM31> {
fn neg(a: QM31) -> QM31 {
QM31 { a: -a.a, b: -a.b }
}
}

pub fn qm31(a: u32, b: u32, c: u32, d: u32) -> QM31 {
QM31 { a: cm31(a, b), b: cm31(c, d) }
}


#[cfg(test)]
mod tests {
use super::{QM31, qm31, QM31Trait};
use super::super::m31::{m31, P, M31Trait};

#[test]
fn test_QM31() {
let qm0 = qm31(1, 2, 3, 4);
let qm1 = qm31(4, 5, 6, 7);
let m = m31(8);
let qm = Into::<_, QM31>::into(m);
let qm0_x_qm1 = qm31(P - 71, 93, P - 16, 50);

assert_eq!(qm0 + qm1, qm31(5, 7, 9, 11));
assert_eq!(qm1 + m.into(), qm1 + qm);
assert_eq!(qm0 * qm1, qm0_x_qm1);
assert_eq!(qm1 * m.into(), qm1 * qm);
assert_eq!(-qm0, qm31(P - 1, P - 2, P - 3, P - 4));
assert_eq!(qm0 - qm1, qm31(P - 3, P - 3, P - 3, P - 3));
assert_eq!(qm1 - m.into(), qm1 - qm);
assert_eq!(qm0_x_qm1 * qm1.inverse(), qm31(1, 2, 3, 4));
assert_eq!(qm1 * m.inverse().into(), qm1 * qm.inverse());
}
}

0 comments on commit 7af2aaf

Please sign in to comment.