diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index ec61d7615..fc7ebfa3e 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -332,6 +332,11 @@ impl ComponentProver for FrameworkComponen .step_by(CHUNK_SIZE) .zip(col.chunks_mut(CHUNK_SIZE)); + // Define any `self` values outside the loop to prevent the compiler thinking there is a + // `Sync` requirement on `Self`. + let self_eval = &self.eval; + let self_logup_sums = self.logup_sums; + iter.for_each(|(chunk_idx, mut chunk)| { let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref()); @@ -344,10 +349,10 @@ impl ComponentProver for FrameworkComponen &accum.random_coeff_powers, trace_domain.log_size(), eval_domain.log_size(), - self.eval.log_size(), - self.logup_sums, + self_eval.log_size(), + self_logup_sums, ); - let row_res = self.eval.evaluate(eval).row_res; + let row_res = self_eval.evaluate(eval).row_res; // Finalize row. unsafe { diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 41886f7c0..d83a5250e 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -145,6 +145,13 @@ impl FieldExpOps for Expr { } } +impl Add for Expr { + type Output = Self; + fn add(self, rhs: BaseField) -> Self { + self + Expr::from(rhs) + } +} + impl Mul for Expr { type Output = Self; fn mul(self, rhs: BaseField) -> Self { diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index 2d2831b44..f8a6257e3 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -1,6 +1,9 @@ -use std::ops::Mul; +use std::array; +use std::cell::{RefCell, RefMut}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; +use std::rc::Rc; -use num_traits::One; +use num_traits::{One, Zero}; use super::logup::{LogupAtRow, LogupSums}; use super::preprocessed_columns::PreprocessedColumn; @@ -8,17 +11,20 @@ use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::constraint_framework::PREPROCESSED_TRACE_IDX; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; use crate::core::pcs::TreeVec; /// Collects information about the constraints. -/// This includes mask offsets and columns at each interaction, and the number of constraints. +/// This includes mask offsets and columns at each interaction, the number of constraints and number +/// of arithmetic operations. #[derive(Default)] pub struct InfoEvaluator { pub mask_offsets: TreeVec>>, pub n_constraints: usize, pub preprocessed_columns: Vec, pub logup: LogupAtRow, + pub arithmetic_counts: ArithmeticCounts, } impl InfoEvaluator { pub fn new( @@ -31,6 +37,7 @@ impl InfoEvaluator { n_constraints: Default::default(), preprocessed_columns, logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), + arithmetic_counts: Default::default(), } } @@ -41,8 +48,8 @@ impl InfoEvaluator { } } impl EvalAtRow for InfoEvaluator { - type F = BaseField; - type EF = SecureField; + type F = FieldCounter; + type EF = ExtensionFieldCounter; fn next_interaction_mask( &mut self, @@ -60,24 +67,390 @@ impl EvalAtRow for InfoEvaluator { self.mask_offsets.resize(interaction + 1, vec![]); } self.mask_offsets[interaction].push(offsets.into_iter().collect()); - [BaseField::one(); N] + array::from_fn(|_| FieldCounter::one()) } fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { self.preprocessed_columns.push(column); - BaseField::one() + FieldCounter::one() } - fn add_constraint(&mut self, _constraint: G) + fn add_constraint(&mut self, constraint: G) where Self::EF: Mul, { + let lin_combination = + ExtensionFieldCounter::one() + ExtensionFieldCounter::one() * constraint; + self.arithmetic_counts.merge(lin_combination.drain()); self.n_constraints += 1; } - fn combine_ef(_values: [Self::F; 4]) -> Self::EF { - SecureField::one() + fn combine_ef(values: [Self::F; 4]) -> Self::EF { + let mut res = ExtensionFieldCounter::zero(); + values.map(|v| res.merge(v)); + res } super::logup_proxy!(); } + +/// Stores a count of field operations. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct ArithmeticCounts { + /// Number of [`EvalAtRow::EF`] * [`EvalAtRow::EF`] operations. + pub n_ef_mul_ef: usize, + /// Number of [`EvalAtRow::EF`] * [`EvalAtRow::F`] operations. + pub n_ef_mul_f: usize, + /// Number of [`EvalAtRow::F`] * [`BaseField`] operations. + pub n_ef_mul_base_field: usize, + /// Number of [`EvalAtRow::EF`] + [`EvalAtRow::EF`] operations. + pub n_ef_add_ef: usize, + /// Number of [`EvalAtRow::EF`] + [`EvalAtRow::F`] operations. + pub n_ef_add_f: usize, + /// Number of [`EvalAtRow::EF`] * [`BaseField`] operations. + pub n_ef_add_base_field: usize, + /// Number of [`EvalAtRow::F`] * [`EvalAtRow::F`] operations. + pub n_f_mul_f: usize, + /// Number of [`EvalAtRow::F`] * [`BaseField`] operations. + pub n_f_mul_base_field: usize, + /// Number of [`EvalAtRow::F`] + [`EvalAtRow::F`] operations. + pub n_f_add_f: usize, + /// Number of [`EvalAtRow::F`] + [`BaseField`] operations. + pub n_f_add_base_field: usize, +} + +impl ArithmeticCounts { + fn merge(&mut self, other: ArithmeticCounts) { + let Self { + n_ef_mul_ef, + n_ef_mul_f, + n_ef_mul_base_field, + n_ef_add_ef, + n_ef_add_f, + n_ef_add_base_field, + n_f_mul_f, + n_f_mul_base_field, + n_f_add_f, + n_f_add_base_field, + } = self; + + *n_ef_mul_ef += other.n_ef_mul_ef; + *n_ef_mul_f += other.n_ef_mul_f; + *n_ef_mul_base_field += other.n_ef_mul_base_field; + *n_ef_add_f += other.n_ef_add_f; + *n_ef_add_base_field += other.n_ef_add_base_field; + *n_ef_add_ef += other.n_ef_add_ef; + *n_f_mul_f += other.n_f_mul_f; + *n_f_mul_base_field += other.n_f_mul_base_field; + *n_f_add_f += other.n_f_add_f; + *n_f_add_base_field += other.n_f_add_base_field; + } +} + +#[derive(Debug, Default, Clone)] +pub struct ArithmeticCounter(Rc>); + +/// Counts operations on [`EvalAtRow::F`]. +pub type FieldCounter = ArithmeticCounter; + +/// Counts operations on [`EvalAtRow::EF`]. +pub type ExtensionFieldCounter = ArithmeticCounter; + +impl ArithmeticCounter { + fn merge( + &mut self, + other: ArithmeticCounter, + ) { + // Skip if they come from the same source. + if Rc::ptr_eq(&self.0, &other.0) { + return; + } + + self.counts().merge(other.drain()); + } + + fn drain(self) -> ArithmeticCounts { + self.0.take() + } + + fn counts(&mut self) -> RefMut<'_, ArithmeticCounts> { + self.0.borrow_mut() + } +} + +impl Zero for ArithmeticCounter { + fn zero() -> Self { + Self::default() + } + + fn is_zero(&self) -> bool { + // TODO(andrew): Consider removing Zero from EvalAtRow::F, EvalAtRow::EF since is_zero + // doesn't make sense. Creating zero elements does though. + panic!() + } +} + +impl One for ArithmeticCounter { + fn one() -> Self { + Self::default() + } +} + +impl Add for ArithmeticCounter { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self { + self.merge(rhs); + { + let mut counts = self.counts(); + match IS_EXT_FIELD { + true => counts.n_ef_add_ef += 1, + false => counts.n_f_add_f += 1, + } + } + self + } +} + +impl Sub for ArithmeticCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self { + // Treat as addition. + self + rhs + } +} + +impl Add for ExtensionFieldCounter { + type Output = Self; + + fn add(mut self, rhs: FieldCounter) -> Self { + self.merge(rhs); + self.counts().n_ef_add_f += 1; + self + } +} + +impl Mul for ArithmeticCounter { + type Output = Self; + + fn mul(mut self, rhs: Self) -> Self { + self.merge(rhs); + { + let mut counts = self.counts(); + match IS_EXT_FIELD { + true => counts.n_ef_mul_ef += 1, + false => counts.n_f_mul_f += 1, + } + } + self + } +} + +impl Mul for ExtensionFieldCounter { + type Output = ExtensionFieldCounter; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(mut self, rhs: FieldCounter) -> Self { + self.merge(rhs); + self.counts().n_ef_mul_f += 1; + self + } +} + +impl MulAssign for ArithmeticCounter { + fn mul_assign(&mut self, rhs: Self) { + *self = self.clone() * rhs + } +} + +impl AddAssign for ArithmeticCounter { + fn add_assign(&mut self, rhs: Self) { + *self = self.clone() + rhs + } +} + +impl AddAssign for FieldCounter { + fn add_assign(&mut self, _rhs: BaseField) { + self.counts().n_f_add_base_field += 1; + } +} + +impl Mul for FieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(mut self, _rhs: BaseField) -> Self { + self.counts().n_f_mul_base_field += 1; + self + } +} + +impl Mul for ExtensionFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(mut self, _rhs: BaseField) -> Self { + self.counts().n_ef_mul_base_field += 1; + self + } +} + +impl Mul for ExtensionFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn mul(self, _rhs: SecureField) -> Self { + self * ExtensionFieldCounter::zero() + } +} + +impl Add for FieldCounter { + type Output = ExtensionFieldCounter; + + fn add(self, _rhs: SecureField) -> ExtensionFieldCounter { + ExtensionFieldCounter::zero() + self + } +} + +impl Add for ExtensionFieldCounter { + type Output = Self; + + fn add(mut self, _rhs: BaseField) -> Self { + self.counts().n_ef_add_base_field += 1; + self + } +} + +impl Add for ExtensionFieldCounter { + type Output = Self; + + fn add(self, _rhs: SecureField) -> Self { + self + ExtensionFieldCounter::zero() + } +} + +impl Sub for ExtensionFieldCounter { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: SecureField) -> Self { + // Tread subtraction as addition + self + rhs + } +} + +impl Mul for FieldCounter { + type Output = ExtensionFieldCounter; + + fn mul(self, _rhs: SecureField) -> ExtensionFieldCounter { + ExtensionFieldCounter::zero() * self + } +} + +impl From for FieldCounter { + fn from(_value: BaseField) -> Self { + Self::one() + } +} + +impl From for ExtensionFieldCounter { + fn from(_value: SecureField) -> Self { + Self::one() + } +} + +impl From for ExtensionFieldCounter { + fn from(value: FieldCounter) -> Self { + Self(value.0) + } +} + +impl Neg for ArithmeticCounter { + type Output = Self; + + fn neg(self) -> Self { + // Treat as addition. + self + ArithmeticCounter::::zero() + } +} + +impl FieldExpOps for ArithmeticCounter { + fn inverse(&self) -> Self { + todo!() + } +} + +#[cfg(test)] +mod tests { + use num_traits::{One, Zero}; + + use super::ExtensionFieldCounter; + use crate::constraint_framework::info::{ArithmeticCounts, FieldCounter}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + + #[test] + fn test_arithmetic_counter() { + const N_EF_MUL_EF: usize = 1; + const N_EF_MUL_F: usize = 2; + const N_EF_MUL_BASE_FIELD: usize = 3; + const N_EF_MUL_ASSIGN_EF: usize = 4; + const N_EF_MUL_SECURE_FIELD: usize = 5; + const N_EF_ADD_EF: usize = 6; + const N_EF_ADD_ASSIGN_EF: usize = 7; + const N_EF_ADD_F: usize = 8; + const N_EF_NEG: usize = 9; + const N_EF_SUB_EF: usize = 10; + const N_EF_ADD_BASE_FIELD: usize = 11; + const N_F_MUL_F: usize = 12; + const N_F_MUL_ASSIGN_F: usize = 13; + const N_F_MUL_BASE_FIELD: usize = 14; + const N_F_ADD_F: usize = 15; + const N_F_ADD_ASSIGN_F: usize = 16; + const N_F_ADD_ASSIGN_BASE_FIELD: usize = 17; + const N_F_NEG: usize = 18; + const N_F_SUB_F: usize = 19; + let mut ef = ExtensionFieldCounter::zero(); + let mut f = FieldCounter::zero(); + + (0..N_EF_MUL_EF).for_each(|_| ef = ef.clone() * ef.clone()); + (0..N_EF_MUL_F).for_each(|_| ef = ef.clone() * f.clone()); + (0..N_EF_MUL_BASE_FIELD).for_each(|_| ef = ef.clone() * BaseField::one()); + (0..N_EF_MUL_SECURE_FIELD).for_each(|_| ef = ef.clone() * SecureField::one()); + (0..N_EF_MUL_ASSIGN_EF).for_each(|_| ef *= ef.clone()); + (0..N_EF_ADD_EF).for_each(|_| ef = ef.clone() + ef.clone()); + (0..N_EF_ADD_ASSIGN_EF).for_each(|_| ef += ef.clone()); + (0..N_EF_ADD_F).for_each(|_| ef = ef.clone() + f.clone()); + (0..N_EF_ADD_BASE_FIELD).for_each(|_| ef = ef.clone() + BaseField::one()); + (0..N_EF_NEG).for_each(|_| ef = -ef.clone()); + (0..N_EF_SUB_EF).for_each(|_| ef = ef.clone() - ef.clone()); + (0..N_F_MUL_F).for_each(|_| f = f.clone() * f.clone()); + (0..N_F_MUL_ASSIGN_F).for_each(|_| f *= f.clone()); + (0..N_F_MUL_BASE_FIELD).for_each(|_| f = f.clone() * BaseField::one()); + (0..N_F_ADD_F).for_each(|_| f = f.clone() + f.clone()); + (0..N_F_ADD_ASSIGN_F).for_each(|_| f += f.clone()); + (0..N_F_ADD_ASSIGN_BASE_FIELD).for_each(|_| f += BaseField::one()); + (0..N_F_NEG).for_each(|_| f = -f.clone()); + (0..N_F_SUB_F).for_each(|_| f = f.clone() - f.clone()); + let mut res = f.drain(); + res.merge(ef.drain()); + + assert_eq!( + res, + ArithmeticCounts { + n_ef_mul_ef: N_EF_MUL_EF + N_EF_MUL_SECURE_FIELD + N_EF_MUL_ASSIGN_EF, + n_ef_mul_base_field: N_EF_MUL_BASE_FIELD, + n_ef_mul_f: N_EF_MUL_F, + n_ef_add_ef: N_EF_ADD_EF + N_EF_NEG + N_EF_SUB_EF + N_EF_ADD_ASSIGN_EF, + n_ef_add_f: N_EF_ADD_F, + n_ef_add_base_field: N_EF_ADD_BASE_FIELD, + n_f_mul_f: N_F_MUL_F + N_F_MUL_ASSIGN_F, + n_f_mul_base_field: N_F_MUL_BASE_FIELD, + n_f_add_f: N_F_ADD_F + N_F_NEG + N_F_SUB_F + N_F_ADD_ASSIGN_F, + n_f_add_base_field: N_F_ADD_ASSIGN_BASE_FIELD, + } + ); + } +} diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 9675e4dbc..aa871b5bc 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -61,6 +61,8 @@ pub trait EvalAtRow { + From + Neg + AddAssign + + Add + + Mul + Add + Sub + Mul diff --git a/crates/prover/src/core/backend/simd/qm31.rs b/crates/prover/src/core/backend/simd/qm31.rs index 13d03ce39..078f6ef56 100644 --- a/crates/prover/src/core/backend/simd/qm31.rs +++ b/crates/prover/src/core/backend/simd/qm31.rs @@ -8,6 +8,7 @@ use rand::distributions::{Distribution, Standard}; use super::cm31::PackedCM31; use super::m31::{PackedM31, N_LANES}; +use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; use crate::core::fields::FieldExpOps; @@ -231,6 +232,24 @@ impl Mul for PackedQM31 { } } +impl Mul for PackedQM31 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: M31) -> Self::Output { + self * PackedM31::broadcast(rhs) + } +} + +impl Add for PackedQM31 { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: M31) -> Self::Output { + self + PackedM31::broadcast(rhs) + } +} + impl SubAssign for PackedQM31 { fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; diff --git a/crates/prover/src/core/circle.rs b/crates/prover/src/core/circle.rs index 0fbdd64b4..8cfe48ab8 100644 --- a/crates/prover/src/core/circle.rs +++ b/crates/prover/src/core/circle.rs @@ -104,7 +104,7 @@ impl + FieldExpOps + Sub + Neg } } - pub fn into_ef>(&self) -> CirclePoint { + pub fn into_ef>(self) -> CirclePoint { CirclePoint { x: self.x.clone().into(), y: self.y.clone().into(), diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index 035579e5f..ed67477f7 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -195,14 +195,12 @@ where } /// Projective fraction. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct Fraction { pub numerator: N, pub denominator: D, } -impl Copy for Fraction {} - impl Fraction { pub fn new(numerator: N, denominator: D) -> Self { Self { @@ -212,17 +210,15 @@ impl Fraction { } } -impl< - N: Clone, - D: Add + Add + Mul + Mul + Clone, - > Add for Fraction +impl + Add + Mul + Mul + Clone> Add + for Fraction { type Output = Fraction; fn add(self, rhs: Self) -> Fraction { Fraction { - numerator: rhs.denominator.clone() * self.numerator.clone() - + self.denominator.clone() * rhs.numerator.clone(), + numerator: rhs.denominator.clone() * self.numerator + + self.denominator.clone() * rhs.numerator, denominator: self.denominator * rhs.denominator, } } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index 70c937f43..ff62f9f7d 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -149,8 +149,7 @@ where + Sub + Mul, { - #[allow(clippy::wrong_self_convention)] - fn to_felts(self) -> [F; 2] { + fn into_felts(self) -> [F; 2] { [self.l, self.h] } } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index b4568db80..ada5fb287 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -69,9 +69,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { self.round_lookup_elements, -E::EF::one(), &chain![ - input_v.iter().cloned().flat_map(Fu32::to_felts), - v.iter().cloned().flat_map(Fu32::to_felts), - m.iter().cloned().flat_map(Fu32::to_felts) + input_v.iter().cloned().flat_map(Fu32::into_felts), + v.iter().cloned().flat_map(Fu32::into_felts), + m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), )]); diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index fb2f33849..1bf93d1aa 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -24,9 +24,9 @@ pub fn eval_blake_scheduler_constraints( let output_state = &states[idx + 1]; let round_messages = SIGMA[idx].map(|k| messages[k as usize].clone()); chain![ - input_state.iter().cloned().flat_map(Fu32::to_felts), - output_state.iter().cloned().flat_map(Fu32::to_felts), - round_messages.iter().cloned().flat_map(Fu32::to_felts) + input_state.iter().cloned().flat_map(Fu32::into_felts), + output_state.iter().cloned().flat_map(Fu32::into_felts), + round_messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec() }); @@ -44,9 +44,9 @@ pub fn eval_blake_scheduler_constraints( blake_lookup_elements, E::EF::zero(), &chain![ - input_state.iter().cloned().flat_map(Fu32::to_felts), - output_state.iter().cloned().flat_map(Fu32::to_felts), - messages.iter().cloned().flat_map(Fu32::to_felts) + input_state.iter().cloned().flat_map(Fu32::into_felts), + output_state.iter().cloned().flat_map(Fu32::into_felts), + messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), )]); diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 46e608bb4..5b1b78aeb 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -401,10 +401,10 @@ fn mle_eval_info(interaction: usize, n_variables: usize) -> InfoEvaluator { let mut eval = InfoEvaluator::empty(); let mle_eval_point = MleEvalPoint::new(&vec![SecureField::from(2); n_variables]); let mle_claim_shift = SecureField::zero(); - let mle_coeffs_col_eval = SecureField::zero(); - let carry_quotients_col_eval = SecureField::zero(); - let is_first = BaseField::zero(); - let is_second = BaseField::zero(); + let mle_coeffs_col_eval = SecureField::zero().into(); + let carry_quotients_col_eval = SecureField::zero().into(); + let is_first = BaseField::zero().into(); + let is_second = BaseField::zero().into(); eval_mle_eval_constraints( interaction, &mut eval,