diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index 05da93f6f..d2f42091c 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -1,18 +1,24 @@ -use std::ops::Mul; +use std::array; +use std::cell::{RefCell, RefMut}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; +use std::rc::Rc; -use num_traits::One; +use num_traits::{One, Zero}; use super::EvalAtRow; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; use crate::core::pcs::TreeVec; /// Collects information about the constraints. -/// This includes mask offsets and columns at each interaction, and the number of constraints. +/// This includes mask offsets and columns at each interaction, the number of constraints and number +/// of arithmetic operations. #[derive(Default)] pub struct InfoEvaluator { pub mask_offsets: TreeVec>>, pub n_constraints: usize, + pub arithmetic_counts: ArithmeticCounts, } impl InfoEvaluator { pub fn new() -> Self { @@ -20,8 +26,8 @@ impl InfoEvaluator { } } impl EvalAtRow for InfoEvaluator { - type F = BaseField; - type EF = SecureField; + type F = BaseFieldCounter; + type EF = SecureFieldCounter; fn next_interaction_mask( &mut self, interaction: usize, @@ -33,16 +39,330 @@ impl EvalAtRow for InfoEvaluator { self.mask_offsets.resize(interaction + 1, vec![]); } self.mask_offsets[interaction].push(offsets.into_iter().collect()); - [BaseField::one(); N] + array::from_fn(|_| BaseFieldCounter::one()) } - fn add_constraint(&mut self, _constraint: G) + fn add_constraint(&mut self, constraint: G) where Self::EF: Mul, { + let lin_combination = SecureFieldCounter::one() + SecureFieldCounter::one() * constraint; + self.arithmetic_counts.merge(lin_combination.drain()); self.n_constraints += 1; } - fn combine_ef(_values: [Self::F; 4]) -> Self::EF { - SecureField::one() + fn combine_ef(values: [Self::F; 4]) -> Self::EF { + let mut res = SecureFieldCounter::zero(); + values.map(|v| res.merge(v)); + res + } +} + +/// Stores a count of field operations. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct ArithmeticCounts { + /// Number of `ExtensionField * ExtensionField` operations. + pub n_ef_mul_ef: usize, + /// Number of `ExtensionField * BaseField` operations. + pub n_ef_mul_f: usize, + /// Number of `ExtensionField + ExtensionField` operations. + pub n_ef_add_ef: usize, + /// Number of `ExtensionField + BaseField` operations. + pub n_ef_add_f: usize, + /// Number of `BaseField * BaseField` operations. + pub n_f_mul_f: usize, + /// Number of `BaseField + BaseField` operations. + pub n_f_add_f: usize, +} + +impl ArithmeticCounts { + fn merge(&mut self, other: ArithmeticCounts) { + self.n_ef_mul_ef += other.n_ef_mul_ef; + self.n_ef_mul_f += other.n_ef_mul_f; + self.n_ef_add_f += other.n_ef_add_f; + self.n_ef_add_ef += other.n_ef_add_ef; + self.n_f_mul_f += other.n_f_mul_f; + self.n_f_add_f += other.n_f_add_f; + } +} + +#[derive(Debug, Default, Clone)] +pub struct ArithmeticCounter(Rc>); + +pub type BaseFieldCounter = ArithmeticCounter; + +pub type SecureFieldCounter = ArithmeticCounter; + +impl ArithmeticCounter { + fn merge( + &mut self, + other: ArithmeticCounter, + ) { + // Skip if they come from the same source. + if Rc::ptr_eq(&self.0, &other.0) { + return; + } + + self.counts().merge(other.drain()); + } + + fn drain(self) -> ArithmeticCounts { + self.0.take() + } + + fn counts(&mut self) -> RefMut<'_, ArithmeticCounts> { + self.0.borrow_mut() + } +} + +impl Zero for ArithmeticCounter { + fn zero() -> Self { + Self::default() + } + + fn is_zero(&self) -> bool { + // TODO(andrew): Consider removing Zero from EvalAtRow::F, EvalAtRow::EF since is_zero + // doesn't make sense. Creating zero elements does though. + panic!() + } +} + +impl One for ArithmeticCounter { + fn one() -> Self { + Self::default() + } +} + +impl Add for ArithmeticCounter { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self { + self.merge(rhs); + { + let mut counts = self.counts(); + match IS_EXT_FIELD { + true => counts.n_ef_add_ef += 1, + false => counts.n_f_add_f += 1, + } + } + self + } +} + +impl Sub for ArithmeticCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self { + // Treat as addition. + self + rhs + } +} + +impl Add for SecureFieldCounter { + type Output = Self; + + fn add(mut self, rhs: BaseFieldCounter) -> Self { + self.merge(rhs); + self.counts().n_ef_add_f += 1; + self + } +} + +impl Mul for ArithmeticCounter { + type Output = Self; + + fn mul(mut self, rhs: Self) -> Self { + self.merge(rhs); + { + let mut counts = self.counts(); + match IS_EXT_FIELD { + true => counts.n_ef_mul_ef += 1, + false => counts.n_f_mul_f += 1, + } + } + self + } +} + +impl Mul for SecureFieldCounter { + type Output = SecureFieldCounter; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(mut self, rhs: BaseFieldCounter) -> Self { + self.merge(rhs); + self.counts().n_ef_mul_f += 1; + self + } +} + +impl MulAssign for ArithmeticCounter { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl AddAssign for ArithmeticCounter { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl AddAssign for BaseFieldCounter { + fn add_assign(&mut self, _rhs: BaseField) { + *self = self.clone() + BaseFieldCounter::zero() + } +} + +impl Mul for BaseFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, _rhs: BaseField) -> Self { + self * BaseFieldCounter::zero() + } +} + +impl Mul for SecureFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, _rhs: SecureField) -> Self { + self * SecureFieldCounter::zero() + } +} + +impl Add for BaseFieldCounter { + type Output = SecureFieldCounter; + + fn add(self, _rhs: SecureField) -> SecureFieldCounter { + SecureFieldCounter::zero() + self + } +} + +impl Add for SecureFieldCounter { + type Output = Self; + + fn add(self, _rhs: SecureField) -> Self { + self + SecureFieldCounter::zero() + } +} + +impl Sub for SecureFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: SecureField) -> Self { + // Tread subtraction as addition + self + rhs + } +} + +impl Mul for BaseFieldCounter { + type Output = SecureFieldCounter; + + fn mul(self, _rhs: SecureField) -> SecureFieldCounter { + SecureFieldCounter::zero() * self + } +} + +impl From for BaseFieldCounter { + fn from(_value: BaseField) -> Self { + Self::one() + } +} + +impl From for SecureFieldCounter { + fn from(_value: SecureField) -> Self { + Self::one() + } +} + +impl From for SecureFieldCounter { + fn from(value: BaseFieldCounter) -> Self { + Self(value.0) + } +} + +impl Neg for ArithmeticCounter { + type Output = Self; + + fn neg(self) -> Self { + // Treat as addition. + self + ArithmeticCounter::::zero() + } +} + +impl FieldExpOps for ArithmeticCounter { + fn inverse(&self) -> Self { + todo!() + } +} + +#[cfg(test)] +mod tests { + use num_traits::{One, Zero}; + + use super::SecureFieldCounter; + use crate::constraint_framework::info::{ArithmeticCounts, BaseFieldCounter}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + + #[test] + fn test_arithmetic_counter() { + const N_EF_MUL_EF: usize = 1; + const N_EF_MUL_F: usize = 2; + const N_EF_MUL_ASSIGN_EF: usize = 1; + const N_EF_MUL_SECURE_FIELD: usize = 3; + const N_EF_ADD_EF: usize = 4; + const N_EF_ADD_ASSIGN_EF: usize = 4; + const N_EF_ADD_F: usize = 5; + const N_EF_NEG: usize = 6; + const N_EF_SUB_EF: usize = 7; + const N_F_MUL_F: usize = 8; + const N_F_MUL_ASSIGN_F: usize = 8; + const N_F_MUL_BASE_FIELD: usize = 9; + const N_F_ADD_F: usize = 10; + const N_F_ADD_ASSIGN_F: usize = 4; + const N_F_ADD_ASSIGN_BASE_FIELD: usize = 4; + const N_F_NEG: usize = 11; + const N_F_SUB_F: usize = 12; + let mut ef = SecureFieldCounter::zero(); + let mut f = BaseFieldCounter::zero(); + + (0..N_EF_MUL_EF).for_each(|_| ef = ef.clone() * ef.clone()); + (0..N_EF_MUL_F).for_each(|_| ef = ef.clone() * f.clone()); + (0..N_EF_MUL_SECURE_FIELD).for_each(|_| ef = ef.clone() * SecureField::one()); + (0..N_EF_MUL_ASSIGN_EF).for_each(|_| ef *= ef.clone()); + (0..N_EF_ADD_EF).for_each(|_| ef = ef.clone() + ef.clone()); + (0..N_EF_ADD_ASSIGN_EF).for_each(|_| ef += ef.clone()); + (0..N_EF_ADD_F).for_each(|_| ef = ef.clone() + f.clone()); + (0..N_EF_NEG).for_each(|_| ef = -ef.clone()); + (0..N_EF_SUB_EF).for_each(|_| ef = ef.clone() - ef.clone()); + (0..N_F_MUL_F).for_each(|_| f = f.clone() * f.clone()); + (0..N_F_MUL_ASSIGN_F).for_each(|_| f *= f.clone()); + (0..N_F_MUL_BASE_FIELD).for_each(|_| f = f.clone() * BaseField::one()); + (0..N_F_ADD_F).for_each(|_| f = f.clone() + f.clone()); + (0..N_F_ADD_ASSIGN_F).for_each(|_| f += f.clone()); + (0..N_F_ADD_ASSIGN_BASE_FIELD).for_each(|_| f += BaseField::one()); + (0..N_F_NEG).for_each(|_| f = -f.clone()); + (0..N_F_SUB_F).for_each(|_| f = f.clone() - f.clone()); + let mut res = f.drain(); + res.merge(ef.drain()); + + assert_eq!( + res, + ArithmeticCounts { + n_ef_mul_ef: N_EF_MUL_EF + N_EF_MUL_SECURE_FIELD + N_EF_MUL_ASSIGN_EF, + n_ef_mul_f: N_EF_MUL_F, + n_ef_add_ef: N_EF_ADD_EF + N_EF_NEG + N_EF_SUB_EF + N_EF_ADD_ASSIGN_EF, + n_ef_add_f: N_EF_ADD_F, + n_f_mul_f: N_F_MUL_F + N_F_MUL_BASE_FIELD + N_F_MUL_ASSIGN_F, + n_f_add_f: N_F_ADD_F + + N_F_NEG + + N_F_SUB_F + + N_F_ADD_ASSIGN_BASE_FIELD + + N_F_ADD_ASSIGN_F, + } + ); } } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index a608d89b0..256228aa3 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -47,9 +47,9 @@ impl LogupAtRow { pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction) { // Add a constraint that num / denom = diff. - if let Some(cur_frac) = self.cur_frac { - let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; - let diff = cur_cumsum - self.prev_col_cumsum; + if let Some(cur_frac) = self.cur_frac.clone() { + let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]); + let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone(); self.prev_col_cumsum = cur_cumsum; eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); } @@ -59,12 +59,12 @@ impl LogupAtRow { pub fn finalize(mut self, eval: &mut E) { assert!(!self.is_finalized, "LogupAtRow was already finalized"); - let frac = self.cur_frac.unwrap(); + let frac = self.cur_frac.clone().unwrap(); let [cur_cumsum, prev_row_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0, -1]); - let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum; + let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum.clone(); // Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift. // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint // uniform - apply on all rows. @@ -105,12 +105,12 @@ impl LookupElements { alpha_powers, } } - pub fn combine(&self, values: &[F]) -> EF + pub fn combine(&self, values: &[F]) -> EF where - EF: Copy + Zero + From + From + Mul + Sub, + EF: Clone + Zero + From + From + Mul + Sub, { - zip_eq(values, self.alpha_powers).fold(EF::zero(), |acc, (&value, power)| { - acc + EF::from(power) * value + zip_eq(values, self.alpha_powers).fold(EF::zero(), |acc, (value, power)| { + acc + EF::from(power) * value.clone() }) - EF::from(self.z) } // TODO(spapini): Try to remove this. @@ -262,7 +262,7 @@ mod tests { let mut logup = LogupAtRow::::new(1, SecureField::one(), 7); logup.write_frac( &mut InfoEvaluator::default(), - Fraction::new(SecureField::one(), SecureField::one()), + Fraction::new(SecureField::one().into(), SecureField::one().into()), ); } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 87069d344..44ad3a167 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -2,7 +2,7 @@ mod assert; mod component; pub mod constant_columns; -mod info; +pub mod info; pub mod logup; mod point; mod simd_domain; @@ -30,7 +30,7 @@ pub trait EvalAtRow { /// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating /// the columns out of domain. type F: FieldExpOps - + Copy + + Clone + Debug + Zero + Neg @@ -47,7 +47,7 @@ pub trait EvalAtRow { /// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints /// usually get multiplied by [SecureField] values for security. type EF: One - + Copy + + Clone + Debug + Zero + From @@ -83,8 +83,12 @@ pub trait EvalAtRow { interaction: usize, offsets: [isize; N], ) -> [Self::EF; N] { - let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets)); - array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i]))) + let mut res_col_major = + array::from_fn(|_| self.next_interaction_mask(interaction, offsets)) + .map(|col| col.into_iter()); + array::from_fn(|_| { + Self::combine_ef(res_col_major.each_mut().map(|iter| iter.next().unwrap())) + }) } /// Adds a constraint to the component. diff --git a/crates/prover/src/core/backend/simd/very_packed_m31.rs b/crates/prover/src/core/backend/simd/very_packed_m31.rs index 2e344b8a6..3b3d065c3 100644 --- a/crates/prover/src/core/backend/simd/very_packed_m31.rs +++ b/crates/prover/src/core/backend/simd/very_packed_m31.rs @@ -212,7 +212,7 @@ impl One for Vectorized { } } -impl FieldExpOps for Vectorized { +impl FieldExpOps for Vectorized { fn inverse(&self) -> Self { Vectorized::from_fn(|i| { assert!(!self.0[i].is_zero(), "0 has no inverse"); diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 8804840fd..eb363414b 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -25,7 +25,7 @@ impl + FieldExpOps + Sub + Neg } pub fn double(&self) -> Self { - *self + *self + self.clone() + self.clone() } /// Applies the circle's x-coordinate doubling map. @@ -40,7 +40,7 @@ impl + FieldExpOps + Sub + Neg /// ``` pub fn double_x(x: F) -> F { let sx = x.square(); - sx + sx - F::one() + sx.clone() + sx - F::one() } /// Returns the log order of a point. @@ -61,7 +61,7 @@ impl + FieldExpOps + Sub + Neg // we only need the x-coordinate to check order since the only point // with x=1 is the circle's identity let mut res = 0; - let mut cur = self.x; + let mut cur = self.x.clone(); while cur != F::one() { cur = Self::double_x(cur); res += 1; @@ -71,10 +71,10 @@ impl + FieldExpOps + Sub + Neg pub fn mul(&self, mut scalar: u128) -> CirclePoint { let mut res = Self::zero(); - let mut cur = *self; + let mut cur = self.clone(); while scalar > 0 { if scalar & 1 == 1 { - res = res + cur; + res = res + cur.clone(); } cur = cur.double(); scalar >>= 1; @@ -83,7 +83,7 @@ impl + FieldExpOps + Sub + Neg } pub fn repeated_double(&self, n: u32) -> Self { - let mut res = *self; + let mut res = self.clone(); for _ in 0..n { res = res.double(); } @@ -92,19 +92,19 @@ impl + FieldExpOps + Sub + Neg pub fn conjugate(&self) -> CirclePoint { Self { - x: self.x, - y: -self.y, + x: self.x.clone(), + y: -self.y.clone(), } } pub fn antipode(&self) -> CirclePoint { Self { - x: -self.x, - y: -self.y, + x: -self.x.clone(), + y: -self.y.clone(), } } - pub fn into_ef>(&self) -> CirclePoint { + pub fn into_ef>(self) -> CirclePoint { CirclePoint { x: self.x.into(), y: self.y.into(), @@ -126,7 +126,7 @@ impl + FieldExpOps + Sub + Neg type Output = Self; fn add(self, rhs: Self) -> Self::Output { - let x = self.x * rhs.x - self.y * rhs.y; + let x = self.x.clone() * rhs.x.clone() - self.y.clone() * rhs.y.clone(); let y = self.x * rhs.y + self.y * rhs.x; Self { x, y } } diff --git a/crates/prover/src/core/fields/m31.rs b/crates/prover/src/core/fields/m31.rs index 852f95909..0ceb6c091 100644 --- a/crates/prover/src/core/fields/m31.rs +++ b/crates/prover/src/core/fields/m31.rs @@ -186,11 +186,11 @@ macro_rules! m31 { /// assert_eq!(pow2147483645(v), v.pow(2147483645)); /// ``` pub fn pow2147483645(v: T) -> T { - let t0 = sqn::<2, T>(v) * v; - let t1 = sqn::<1, T>(t0) * t0; - let t2 = sqn::<3, T>(t1) * t0; - let t3 = sqn::<1, T>(t2) * t0; - let t4 = sqn::<8, T>(t3) * t3; + let t0 = sqn::<2, T>(v.clone()) * v; + let t1 = sqn::<1, T>(t0.clone()) * t0.clone(); + let t2 = sqn::<3, T>(t1) * t0.clone(); + let t3 = sqn::<1, T>(t2.clone()) * t0; + let t4 = sqn::<8, T>(t3.clone()) * t3.clone(); let t5 = sqn::<8, T>(t4) * t3; sqn::<7, T>(t5) * t2 } diff --git a/crates/prover/src/core/fields/mod.rs b/crates/prover/src/core/fields/mod.rs index fbeefbb94..1717d3979 100644 --- a/crates/prover/src/core/fields/mod.rs +++ b/crates/prover/src/core/fields/mod.rs @@ -1,5 +1,6 @@ +use std::array; use std::fmt::{Debug, Display}; -use std::iter::{Product, Sum}; +use std::iter::{zip, Product, Sum}; use std::ops::{Mul, MulAssign, Neg}; use num_traits::{NumAssign, NumAssignOps, NumOps, One}; @@ -16,18 +17,18 @@ pub trait FieldOps: ColumnOps { fn batch_inverse(column: &Self::Column, dst: &mut Self::Column); } -pub trait FieldExpOps: Mul + MulAssign + Sized + One + Copy { +pub trait FieldExpOps: Mul + MulAssign + Sized + One + Clone { fn square(&self) -> Self { - (*self) * (*self) + self.clone() * self.clone() } fn pow(&self, exp: u128) -> Self { let mut res = Self::one(); - let mut base = *self; + let mut base = self.clone(); let mut exp = exp; while exp > 0 { if exp & 1 == 1 { - res *= base; + res *= base.clone(); } base = base.square(); exp >>= 1; @@ -50,24 +51,24 @@ pub trait FieldExpOps: Mul + MulAssign + Sized + One + Copy { // First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing // instruction dependency and allowing better pipelining. - let mut cum_prod: [Self; WIDTH] = [Self::one(); WIDTH]; - dst[..WIDTH].copy_from_slice(&cum_prod); + let mut cum_prod: [Self; WIDTH] = array::from_fn(|_| Self::one()); + zip(&mut dst[..WIDTH], cum_prod.clone()).for_each(|(dst, v)| *dst = v); for i in 0..n { - cum_prod[i % WIDTH] *= column[i]; - dst[i] = cum_prod[i % WIDTH]; + cum_prod[i % WIDTH] *= column[i].clone(); + dst[i] = cum_prod[i % WIDTH].clone(); } // Inverse cumulative products. // Use classic batch inversion. - let mut tail_inverses = [Self::one(); WIDTH]; + let mut tail_inverses: [Self; WIDTH] = array::from_fn(|_| Self::one()); batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses); // Second pass. for i in (WIDTH..n).rev() { - dst[i] = dst[i - WIDTH] * tail_inverses[i % WIDTH]; - tail_inverses[i % WIDTH] *= column[i]; + dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone(); + tail_inverses[i % WIDTH] *= column[i].clone(); } - dst[0..WIDTH].copy_from_slice(&tail_inverses); + zip(&mut dst[0..WIDTH], tail_inverses).for_each(|(dst, v)| *dst = v); } } @@ -76,10 +77,10 @@ fn batch_inverse_classic(column: &[T], dst: &mut [T]) { let n = column.len(); debug_assert!(dst.len() >= n); - dst[0] = column[0]; + dst[0] = column[0].clone(); // First pass. for i in 1..n { - dst[i] = dst[i - 1] * column[i]; + dst[i] = dst[i - 1].clone() * column[i].clone(); } // Inverse cumulative product. @@ -87,8 +88,8 @@ fn batch_inverse_classic(column: &[T], dst: &mut [T]) { // Second pass. for i in (1..n).rev() { - dst[i] = dst[i - 1] * curr_inverse; - curr_inverse *= column[i]; + dst[i] = dst[i - 1].clone() * curr_inverse.clone(); + curr_inverse *= column[i].clone(); } dst[0] = curr_inverse; } diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index 85ea4c32a..3224dd0ce 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -210,14 +210,15 @@ impl Fraction { } } -impl + Add + Mul + Mul + Copy> Add +impl + Add + Mul + Mul + Clone> Add for Fraction { type Output = Fraction; fn add(self, rhs: Self) -> Fraction { Fraction { - numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator, + numerator: rhs.denominator.clone() * self.numerator + + self.denominator.clone() * rhs.numerator, denominator: self.denominator * rhs.denominator, } } @@ -260,13 +261,25 @@ impl Reciprocal { } } -impl + Mul + Copy> Add for Reciprocal { +impl + Mul + Clone> Add for Reciprocal { type Output = Fraction; fn add(self, rhs: Self) -> Fraction { // `1/a + 1/b = (a + b)/(a * b)` Fraction { - numerator: self.x + rhs.x, + numerator: self.x.clone() + rhs.x.clone(), + denominator: self.x * rhs.x, + } + } +} + +impl + Mul + Clone> Sub for Reciprocal { + type Output = Fraction; + + fn sub(self, rhs: Self) -> Fraction { + // `1/a - 1/b = (a - b)/(a * b)` + Fraction { + numerator: self.x.clone() - rhs.x.clone(), denominator: self.x * rhs.x, } } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 6fbe6d81b..3e4ad042e 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -88,11 +88,11 @@ impl BlakeXorElements { } /// Utility for representing a u32 as two field elements, for constraint evaluation. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] struct Fu32 where F: FieldExpOps - + Copy + + Clone + Debug + AddAssign + Add @@ -105,14 +105,14 @@ where impl Fu32 where F: FieldExpOps - + Copy + + Clone + Debug + AddAssign + Add + Sub + Mul, { - fn to_felts(self) -> [F; 2] { + fn into_felts(self) -> [F; 2] { [self.l, self.h] } } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index 0a2732d12..1ab3b7166 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -5,6 +5,7 @@ use super::{BlakeXorElements, RoundElements}; use crate::constraint_framework::logup::LogupAtRow; use crate::constraint_framework::EvalAtRow; use crate::core::fields::m31::BaseField; +use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; use crate::examples::blake::{Fu32, STATE_SIZE}; @@ -20,17 +21,18 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> { impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { pub fn eval(mut self) -> E { let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); - let input_v = v; + let input_v = v.clone(); let m: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); - self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); - self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); - self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); - self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); - self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); - self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); - self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); - self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); + let [m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15] = m.clone(); + self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m0, m1); + self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m2, m3); + self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m4, m5); + self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m6, m7); + self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m8, m9); + self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m10, m11); + self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m12, m13); + self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m14, m15); // Yield `Round(input_v, output_v, message)`. self.logup.write_frac( @@ -39,9 +41,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { -E::EF::one(), self.round_lookup_elements.combine( &chain![ - input_v.iter().copied().flat_map(Fu32::to_felts), - v.iter().copied().flat_map(Fu32::to_felts), - m.iter().copied().flat_map(Fu32::to_felts) + input_v.iter().cloned().flat_map(Fu32::into_felts), + v.iter().cloned().flat_map(Fu32::into_felts), + m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), ), @@ -59,14 +61,14 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { fn g(&mut self, v: [&mut Fu32; 4], m0: Fu32, m1: Fu32) { let [a, b, c, d] = v; - *a = self.add3_u32_unchecked(*a, *b, m0); - *d = self.xor_rotr16_u32(*a, *d); - *c = self.add2_u32_unchecked(*c, *d); - *b = self.xor_rotr_u32(*b, *c, 12); - *a = self.add3_u32_unchecked(*a, *b, m1); - *d = self.xor_rotr_u32(*a, *d, 8); - *c = self.add2_u32_unchecked(*c, *d); - *b = self.xor_rotr_u32(*b, *c, 7); + *a = self.add3_u32_unchecked(a.clone(), b.clone(), m0); + *d = self.xor_rotr16_u32(a.clone(), d.clone()); + *c = self.add2_u32_unchecked(c.clone(), d.clone()); + *b = self.xor_rotr_u32(b.clone(), c.clone(), 12); + *a = self.add3_u32_unchecked(a.clone(), b.clone(), m1); + *d = self.xor_rotr_u32(a.clone(), d.clone(), 8); + *c = self.add2_u32_unchecked(c.clone(), d.clone()); + *b = self.xor_rotr_u32(b.clone(), c.clone(), 7); } /// Adds two u32s, returning the sum. @@ -77,11 +79,12 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { let sl = self.eval.next_trace_mask(); let sh = self.eval.next_trace_mask(); - let carry_l = (a.l + b.l - sl) * E::F::from(INV16); - self.eval.add_constraint(carry_l * carry_l - carry_l); + let carry_l = (a.l + b.l - sl.clone()) * E::F::from(INV16); + self.eval + .add_constraint(carry_l.clone().square() - carry_l.clone()); - let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16); - self.eval.add_constraint(carry_h * carry_h - carry_h); + let carry_h = (a.h + b.h + carry_l - sh.clone()) * E::F::from(INV16); + self.eval.add_constraint(carry_h.clone().square() - carry_h); Fu32 { l: sl, h: sh } } @@ -94,13 +97,15 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { let sl = self.eval.next_trace_mask(); let sh = self.eval.next_trace_mask(); - let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16); - self.eval - .add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO))); + let carry_l = (a.l + b.l + c.l - sl.clone()) * E::F::from(INV16); + self.eval.add_constraint( + carry_l.clone() * (carry_l.clone() - E::F::one()) * (carry_l.clone() - E::F::from(TWO)), + ); - let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16); - self.eval - .add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO))); + let carry_h = (a.h + b.h + c.h + carry_l - sh.clone()) * E::F::from(INV16); + self.eval.add_constraint( + carry_h.clone() * (carry_h.clone() - E::F::one()) * (carry_h - E::F::from(TWO)), + ); Fu32 { l: sl, h: sh } } @@ -109,7 +114,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { /// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap. fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) { let h = self.eval.next_trace_mask(); - let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r)); + let l = a - h.clone() * E::F::from(BaseField::from_u32_unchecked(1 << r)); (l, h) } @@ -154,10 +159,13 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { // TODO: Separate lookups by w. let c = [self.eval.next_trace_mask(), self.eval.next_trace_mask()]; let lookup_elements = self.xor_lookup_elements.get(w); - let comb0 = lookup_elements.combine::(&[a[0], b[0], c[0]]); - let comb1 = lookup_elements.combine::(&[a[1], b[1], c[1]]); + let [a0, a1] = a; + let [b0, b1] = b; + let [c0, c1] = c.clone(); + let comb0 = lookup_elements.combine::(&[a0, b0, c0]); + let comb1 = lookup_elements.combine::(&[a1, b1, c1]); let frac = Fraction { - numerator: comb0 + comb1, + numerator: comb0.clone() + comb1.clone(), denominator: comb0 * comb1, }; diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index ee9a1c654..a5b296d15 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -4,7 +4,7 @@ use num_traits::Zero; use super::BlakeElements; use crate::constraint_framework::logup::LogupAtRow; use crate::constraint_framework::EvalAtRow; -use crate::core::lookups::utils::Fraction; +use crate::core::lookups::utils::{Fraction, Reciprocal}; use crate::core::vcs::blake2s_ref::SIGMA; use crate::examples::blake::round::RoundElements; use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; @@ -25,17 +25,17 @@ pub fn eval_blake_scheduler_constraints( let [denom_i, denom_j] = [i, j].map(|idx| { let input_state = &states[idx]; let output_state = &states[idx + 1]; - let round_messages = SIGMA[idx].map(|k| messages[k as usize]); + let round_messages = SIGMA[idx].map(|k| messages[k as usize].clone()); round_lookup_elements.combine::( &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - round_messages.iter().copied().flat_map(Fu32::to_felts) + input_state.iter().cloned().flat_map(Fu32::into_felts), + output_state.iter().cloned().flat_map(Fu32::into_felts), + round_messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), ) }); - logup.write_frac(eval, Fraction::new(denom_i + denom_j, denom_i * denom_j)); + logup.write_frac(eval, Reciprocal::new(denom_i) + Reciprocal::new(denom_j)); } let input_state = &states[0]; @@ -49,9 +49,9 @@ pub fn eval_blake_scheduler_constraints( E::EF::zero(), blake_lookup_elements.combine( &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - messages.iter().copied().flat_map(Fu32::to_felts) + input_state.iter().cloned().flat_map(Fu32::into_felts), + output_state.iter().cloned().flat_map(Fu32::into_felts), + messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), ), diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index f43d0088b..f1737e3f8 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -28,15 +28,15 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); let multiplicity = self.eval.next_trace_mask(); - let a = al + let a = al.clone() + E::F::from(BaseField::from_u32_unchecked( i << limb_bits::(), )); - let b = bl + let b = bl.clone() + E::F::from(BaseField::from_u32_unchecked( j << limb_bits::(), )); - let c = cl + let c = cl.clone() + E::F::from(BaseField::from_u32_unchecked( (i ^ j) << limb_bits::(), )); @@ -49,7 +49,7 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> .collect_vec(); for frac_chunk in frac_chunks.chunks(2) { - let sum_frac: Fraction = frac_chunk.iter().copied().sum(); + let sum_frac: Fraction = frac_chunk.iter().cloned().sum(); self.logup.write_frac(&mut self.eval, sum_frac); } self.logup.finalize(&mut self.eval); diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index f2340e681..a46862cac 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -14,7 +14,7 @@ use crate::core::backend::Column; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::lookups::utils::Fraction; +use crate::core::lookups::utils::{Fraction, Reciprocal}; use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; @@ -58,14 +58,17 @@ impl FrameworkEval for PlonkEval { let b_val = eval.next_trace_mask(); let c_val = eval.next_trace_mask(); - eval.add_constraint(c_val - op * (a_val + b_val) + (E::F::one() - op) * a_val * b_val); + eval.add_constraint( + c_val.clone() - op.clone() * (a_val.clone() + b_val.clone()) + + (E::F::one() - op) * a_val.clone() * b_val.clone(), + ); let denom_a: E::EF = self.lookup_elements.combine(&[a_wire, a_val]); let denom_b: E::EF = self.lookup_elements.combine(&[b_wire, b_val]); logup.write_frac( &mut eval, - Fraction::new(denom_a + denom_b, denom_a * denom_b), + Reciprocal::new(denom_a) + Reciprocal::new(denom_b), ); logup.write_frac( &mut eval, diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index d25cc1865..2bb854626 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -18,7 +18,7 @@ use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; -use crate::core::lookups::utils::Fraction; +use crate::core::lookups::utils::Reciprocal; use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; @@ -26,7 +26,7 @@ use crate::core::prover::{prove, StarkProof}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; use crate::core::ColumnVec; -const N_LOG_INSTANCES_PER_ROW: usize = 3; +const N_LOG_INSTANCES_PER_ROW: usize = 0; const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; const N_STATE: usize = 16; const N_PARTIAL_ROUNDS: usize = 14; @@ -69,18 +69,18 @@ impl FrameworkEval for PoseidonEval { /// Applies the M4 MDS matrix described in 5.1. fn apply_m4(x: [F; 4]) -> [F; 4] where - F: Copy + AddAssign + Add + Sub + Mul, + F: Clone + AddAssign + Add + Sub + Mul, { - let t0 = x[0] + x[1]; - let t02 = t0 + t0; - let t1 = x[2] + x[3]; - let t12 = t1 + t1; - let t2 = x[1] + x[1] + t1; - let t3 = x[3] + x[3] + t0; - let t4 = t12 + t12 + t3; - let t5 = t02 + t02 + t2; - let t6 = t3 + t5; - let t7 = t2 + t4; + let t0 = x[0].clone() + x[1].clone(); + let t02 = t0.clone() + t0.clone(); + let t1 = x[2].clone() + x[3].clone(); + let t12 = t1.clone() + t1.clone(); + let t2 = x[1].clone() + x[1].clone() + t1; + let t3 = x[3].clone() + x[3].clone() + t0; + let t4 = t12.clone() + t12 + t3.clone(); + let t5 = t02.clone() + t02 + t2.clone(); + let t6 = t3 + t5.clone(); + let t7 = t2 + t4.clone(); [t6, t5, t7, t4] } @@ -88,7 +88,7 @@ where /// See 5.1 and Appendix B. fn apply_external_round_matrix(state: &mut [F; 16]) where - F: Copy + AddAssign + Add + Sub + Mul, + F: Clone + AddAssign + Add + Sub + Mul, { // Applies circ(2M4, M4, M4, M4). for i in 0..4 { @@ -98,16 +98,17 @@ where state[4 * i + 2], state[4 * i + 3], ] = apply_m4([ - state[4 * i], - state[4 * i + 1], - state[4 * i + 2], - state[4 * i + 3], + state[4 * i].clone(), + state[4 * i + 1].clone(), + state[4 * i + 2].clone(), + state[4 * i + 3].clone(), ]); } for j in 0..4 { - let s = state[j] + state[j + 4] + state[j + 8] + state[j + 12]; + let s = + state[j].clone() + state[j + 4].clone() + state[j + 8].clone() + state[j + 12].clone(); for i in 0..4 { - state[4 * i + j] += s; + state[4 * i + j] += s.clone(); } } } @@ -117,20 +118,22 @@ where // See 5.2. fn apply_internal_round_matrix(state: &mut [F; 16]) where - F: Copy + AddAssign + Add + Sub + Mul, + F: Clone + AddAssign + Add + Sub + Mul, { // TODO(spapini): Check that these coefficients are good according to section 5.3 of Poseidon2 // paper. - let sum = state[1..].iter().fold(state[0], |acc, s| acc + *s); + let sum = state[1..] + .iter() + .fold(state[0].clone(), |acc, s| acc + s.clone()); state.iter_mut().enumerate().for_each(|(i, s)| { // TODO(spapini): Change to rotations. - *s = *s * BaseField::from_u32_unchecked(1 << (i + 1)) + sum; + *s = s.clone() * BaseField::from_u32_unchecked(1 << (i + 1)) + sum.clone(); }); } fn pow5(x: F) -> F { - let x2 = x * x; - let x4 = x2 * x2; + let x2 = x.clone().square(); + let x4 = x2.square(); x4 * x } @@ -151,10 +154,10 @@ pub fn eval_poseidon_constraints( state[i] += EXTERNAL_ROUND_CONSTS[round][i]; }); apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); + state = std::array::from_fn(|i| pow5(state[i].clone())); state.iter_mut().for_each(|s| { let m = eval.next_trace_mask(); - eval.add_constraint(*s - m); + eval.add_constraint(s.clone() - m.clone()); *s = m; }); }); @@ -163,9 +166,9 @@ pub fn eval_poseidon_constraints( (0..N_PARTIAL_ROUNDS).for_each(|round| { state[0] += INTERNAL_ROUND_CONSTS[round]; apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); + state[0] = pow5(state[0].clone()); let m = eval.next_trace_mask(); - eval.add_constraint(state[0] - m); + eval.add_constraint(state[0].clone() - m.clone()); state[0] = m; }); @@ -175,23 +178,19 @@ pub fn eval_poseidon_constraints( state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; }); apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); + state = std::array::from_fn(|i| pow5(state[i].clone())); state.iter_mut().for_each(|s| { let m = eval.next_trace_mask(); - eval.add_constraint(*s - m); + eval.add_constraint(s.clone() - m.clone()); *s = m; }); }); // Provide state lookups. let final_state_denom: E::EF = lookup_elements.combine(&state); - // (1 / denom0) - (1 / denom1) = (denom1 - denom0) / (denom0 * denom1). logup.write_frac( eval, - Fraction::new( - final_state_denom - initial_state_denom, - initial_state_denom * final_state_denom, - ), + Reciprocal::new(final_state_denom) - Reciprocal::new(initial_state_denom), ); } diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 7b0e9b766..7fe9bd093 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -35,7 +35,7 @@ impl FrameworkEval for WideFibonacciEval { let mut b = eval.next_trace_mask(); for _ in 2..N { let c = eval.next_trace_mask(); - eval.add_constraint(c - (a.square() + b.square())); + eval.add_constraint(c.clone() - (a.square() + b.square())); a = b; b = c; } diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 72dc133a5..8aaa00ba1 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -108,12 +108,12 @@ fn eval_eq_constraints( // Check the initial value on half_coset0 and final value on half_coset1. // Combining these constraints is safe because `is_first` and `is_second` are never // non-zero at the same time on the trace. - let half_coset0_initial_check = (curr - mle_eval_point.eq_0_p) * is_first; - let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second; + let half_coset0_initial_check = (curr.clone() - mle_eval_point.eq_0_p) * is_first; + let half_coset1_final_check = (curr.clone() - mle_eval_point.eq_1_p) * is_second; eval.add_constraint(half_coset0_initial_check + half_coset1_final_check); // Check all the steps. - eval.add_constraint(curr - next_next * carry_quotients_col_eval); + eval.add_constraint(curr.clone() - next_next * carry_quotients_col_eval); curr }