From cec714bb3e4292db20a8b4875c973baeb071d28d Mon Sep 17 00:00:00 2001 From: Yoichi Hirai Date: Wed, 18 Dec 2024 13:39:21 +0000 Subject: [PATCH] Allow preprocessed column definitions out of crate --- .../prover/src/constraint_framework/assert.rs | 2 + .../src/constraint_framework/component.rs | 27 ++- .../src/constraint_framework/cpu_domain.rs | 1 + .../constraint_framework/expr/evaluator.rs | 6 +- .../prover/src/constraint_framework/info.rs | 9 +- crates/prover/src/constraint_framework/mod.rs | 18 +- .../prover/src/constraint_framework/point.rs | 1 + .../preprocessed_columns.rs | 154 ++++++++++++++--- .../src/constraint_framework/simd_domain.rs | 1 + crates/prover/src/examples/blake/air.rs | 159 +++++++++++++----- .../examples/blake/xor_table/constraints.rs | 42 ++--- .../src/examples/blake/xor_table/mod.rs | 7 +- crates/prover/src/examples/plonk/mod.rs | 15 +- .../prover/src/examples/state_machine/mod.rs | 22 ++- 14 files changed, 333 insertions(+), 131 deletions(-) diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index 376ff80b1..6dba5c10b 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use num_traits::Zero; use super::logup::{LogupAtRow, LogupSums}; diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 86f00609c..cfe4d6cbe 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::fmt::{self, Display, Formatter}; use std::iter::zip; use std::ops::Deref; +use std::sync::Arc; use itertools::Itertools; #[cfg(feature = "parallel")] @@ -11,7 +12,7 @@ use tracing::{span, Level}; use super::cpu_domain::CpuDomainEvaluator; use super::logup::LogupSums; -use super::preprocessed_columns::PreprocessedColumn; +use super::preprocessed_columns::PreprocessedColumnOps; use super::{ EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX, }; @@ -49,7 +50,8 @@ pub struct TraceLocationAllocator { /// Mapping of tree index to next available column offset. next_tree_offsets: TreeVec, /// Mapping of preprocessed columns to their index. - preprocessed_columns: HashMap, + /// A preprocessed column implementation is indicated by its TypeId + preprocessed_columns: HashMap, usize>, /// Controls whether the preprocessed columns are dynamic or static (default=Dynamic). preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode, } @@ -81,30 +83,39 @@ impl TraceLocationAllocator { } /// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup. - pub fn new_with_preproccessed_columns(preprocessed_columns: &[PreprocessedColumn]) -> Self { + pub fn new_with_preproccessed_columns( + preprocessed_columns: &[Arc], + ) -> Self { Self { next_tree_offsets: Default::default(), preprocessed_columns: preprocessed_columns .iter() .enumerate() - .map(|(i, &col)| (col, i)) + .map(|(i, col)| (col.clone(), i)) .collect(), preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static, } } - pub const fn preprocessed_columns(&self) -> &HashMap { + pub const fn preprocessed_columns(&self) -> &HashMap, usize> { &self.preprocessed_columns } // validates that `self.preprocessed_columns` is consistent with // `preprocessed_columns`. // I.e. preprocessed_columns[i] == self.preprocessed_columns[i]. - pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreprocessedColumn]) { + // The equality comparison uses the pointer comparison of the boxes. + pub fn validate_preprocessed_columns( + &self, + preprocessed_columns: &[Arc], + ) { assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len()); for (column, idx) in self.preprocessed_columns.iter() { - assert_eq!(Some(column), preprocessed_columns.get(*idx)); + assert!(match preprocessed_columns.get(*idx) { + Some(preprocessed_column) => preprocessed_column == column, + None => false, + },) } } } @@ -146,7 +157,7 @@ impl FrameworkComponent { let next_column = location_allocator.preprocessed_columns.len(); *location_allocator .preprocessed_columns - .entry(*col) + .entry(col.clone()) .or_insert_with(|| { if matches!( location_allocator.preprocessed_columns_allocation_mode, diff --git a/crates/prover/src/constraint_framework/cpu_domain.rs b/crates/prover/src/constraint_framework/cpu_domain.rs index 03089bd17..fa9bd55fd 100644 --- a/crates/prover/src/constraint_framework/cpu_domain.rs +++ b/crates/prover/src/constraint_framework/cpu_domain.rs @@ -1,4 +1,5 @@ use std::ops::Mul; +use std::sync::Arc; use num_traits::Zero; diff --git a/crates/prover/src/constraint_framework/expr/evaluator.rs b/crates/prover/src/constraint_framework/expr/evaluator.rs index 7fef20254..5fd60ce1e 100644 --- a/crates/prover/src/constraint_framework/expr/evaluator.rs +++ b/crates/prover/src/constraint_framework/expr/evaluator.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use num_traits::Zero; use super::{BaseExpr, ExtExpr}; use crate::constraint_framework::expr::ColumnExpr; -use crate::constraint_framework::preprocessed_columns::PreprocessedColumn; +use crate::constraint_framework::preprocessed_columns::PreprocessedColumnOps; use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; use crate::core::fields::m31; use crate::core::lookups::utils::Fraction; @@ -174,7 +176,7 @@ impl EvalAtRow for ExprEvaluator { intermediate } - fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, column: Arc) -> Self::F { BaseExpr::Param(column.name().to_string()) } diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index f8a6257e3..2acb44659 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -2,11 +2,12 @@ use std::array; use std::cell::{RefCell, RefMut}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use std::rc::Rc; +use std::sync::Arc; use num_traits::{One, Zero}; use super::logup::{LogupAtRow, LogupSums}; -use super::preprocessed_columns::PreprocessedColumn; +use super::preprocessed_columns::PreprocessedColumnOps; use super::{EvalAtRow, INTERACTION_TRACE_IDX}; use crate::constraint_framework::PREPROCESSED_TRACE_IDX; use crate::core::fields::m31::BaseField; @@ -22,14 +23,14 @@ use crate::core::pcs::TreeVec; pub struct InfoEvaluator { pub mask_offsets: TreeVec>>, pub n_constraints: usize, - pub preprocessed_columns: Vec, + pub preprocessed_columns: Vec>, pub logup: LogupAtRow, pub arithmetic_counts: ArithmeticCounts, } impl InfoEvaluator { pub fn new( log_size: u32, - preprocessed_columns: Vec, + preprocessed_columns: Vec>, logup_sums: LogupSums, ) -> Self { Self { @@ -70,7 +71,7 @@ impl EvalAtRow for InfoEvaluator { array::from_fn(|_| FieldCounter::one()) } - fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, column: Arc) -> Self::F { self.preprocessed_columns.push(column); FieldCounter::one() } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 37baa6167..963f3f79e 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -13,13 +13,14 @@ mod simd_domain; use std::array; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; +use std::sync::Arc; pub use assert::{assert_constraints, AssertEvaluator}; pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator}; pub use info::InfoEvaluator; use num_traits::{One, Zero}; pub use point::PointEvaluator; -use preprocessed_columns::PreprocessedColumn; +use preprocessed_columns::PreprocessedColumnOps; pub use simd_domain::SimdDomainEvaluator; use crate::core::fields::m31::BaseField; @@ -87,7 +88,7 @@ pub trait EvalAtRow { mask_item } - fn get_preprocessed_column(&mut self, _column: PreprocessedColumn) -> Self::F { + fn get_preprocessed_column(&mut self, _column: Arc) -> Self::F { let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]); mask_item } @@ -165,18 +166,15 @@ pub trait EvalAtRow { } } -/// Default implementation for evaluators that have an element called "logup" that works like a -/// LogupAtRow, where the logup functionality can be proxied. -/// TODO(alont): Remove once LogupAtRow is no longer used. macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { if self.logup.fracs.is_empty() { - self.logup.is_first = self.get_preprocessed_column( - crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst( - self.logup.log_size, - ), - ); + self.logup.is_first = self.get_preprocessed_column(Arc::new( + crate::constraint_framework::preprocessed_columns::IsFirst { + log_size: self.logup.log_size, + }, + )); self.logup.is_finalized = false; } self.logup.fracs.push(fraction.clone()); diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index ea01c647d..5b3012900 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -1,4 +1,5 @@ use std::ops::Mul; +use std::sync::Arc; use super::logup::{LogupAtRow, LogupSums}; use super::{EvalAtRow, INTERACTION_TRACE_IDX}; diff --git a/crates/prover/src/constraint_framework/preprocessed_columns.rs b/crates/prover/src/constraint_framework/preprocessed_columns.rs index f196567dd..d74483f48 100644 --- a/crates/prover/src/constraint_framework/preprocessed_columns.rs +++ b/crates/prover/src/constraint_framework/preprocessed_columns.rs @@ -1,26 +1,133 @@ +use std::any::Any; +use std::fmt::Debug; +use std::hash::Hash; +use std::sync::Arc; + use num_traits::One; -use crate::core::backend::{Backend, Col, Column}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Backend, Col, Column, CpuBackend}; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; -// TODO(ilya): Where should this enum be placed? +/// XorTable, etc will be implementation of this trait. +pub trait PreprocessedColumnOps: Debug + Any { + fn get_type_id(&self) -> std::any::TypeId { + self.type_id() + } + fn name(&self) -> &'static str; + fn log_size(&self) -> u32; + fn gen_preprocessed_column_cpu( + &self, + ) -> CircleEvaluation; + fn gen_preprocessed_column_simd( + &self, + ) -> CircleEvaluation; + fn as_bytes(&self) -> Vec; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct IsFirst { + pub log_size: u32, +} + +impl PreprocessedColumnOps for IsFirst { + fn name(&self) -> &'static str { + "preprocessed.is_first" + } + fn log_size(&self) -> u32 { + self.log_size + } + fn gen_preprocessed_column_cpu( + &self, + ) -> CircleEvaluation { + gen_is_first(self.log_size) + } + fn gen_preprocessed_column_simd( + &self, + ) -> CircleEvaluation { + gen_is_first(self.log_size) + } + fn as_bytes(&self) -> Vec { + self.log_size.to_le_bytes().to_vec() + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum PreprocessedColumn { - XorTable(u32, u32, usize), - IsFirst(u32), - Plonk(usize), +pub struct XorTable { + pub elem_bits: u32, + pub expand_bits: u32, + pub kind: usize, } -impl PreprocessedColumn { - pub const fn name(&self) -> &'static str { - match self { - PreprocessedColumn::XorTable(..) => "preprocessed.xor_table", - PreprocessedColumn::IsFirst(_) => "preprocessed.is_first", - PreprocessedColumn::Plonk(_) => "preprocessed.plonk", - } +impl PreprocessedColumnOps for XorTable { + fn name(&self) -> &'static str { + "preprocessed.xor_table" + } + fn log_size(&self) -> u32 { + assert!(self.elem_bits >= self.expand_bits); + 2 * (self.elem_bits - self.expand_bits) + } + fn gen_preprocessed_column_cpu( + &self, + ) -> CircleEvaluation { + unimplemented!("XorTable is not supported.") + } + fn gen_preprocessed_column_simd( + &self, + ) -> CircleEvaluation { + unimplemented!("XorTable is not supported.") + } + fn as_bytes(&self) -> Vec { + let mut bytes = vec![]; + bytes.extend_from_slice(&self.elem_bits.to_le_bytes()); + bytes.extend_from_slice(&self.expand_bits.to_le_bytes()); + bytes.extend_from_slice(&self.kind.to_le_bytes()); + bytes + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Plonk { + pub kind: u32, +} + +impl PreprocessedColumnOps for Plonk { + fn name(&self) -> &'static str { + "preprocessed.plonk" + } + fn log_size(&self) -> u32 { + unimplemented!("Plonk is not supported.") + } + fn gen_preprocessed_column_cpu( + &self, + ) -> CircleEvaluation { + unimplemented!("Plonk is not supported.") + } + fn gen_preprocessed_column_simd( + &self, + ) -> CircleEvaluation { + unimplemented!("Plonk is not supported.") + } + fn as_bytes(&self) -> Vec { + self.kind.to_le_bytes().to_vec() + } +} + +impl PartialEq for dyn PreprocessedColumnOps { + fn eq(&self, other: &Self) -> bool { + self.get_type_id() == other.get_type_id() && self.as_bytes() == other.as_bytes() + } +} + +impl Eq for dyn PreprocessedColumnOps {} + +impl Hash for dyn PreprocessedColumnOps { + fn hash(&self, state: &mut H) { + self.get_type_id().hash(state); + self.as_bytes().hash(state); } } @@ -54,19 +161,10 @@ pub fn gen_is_step_with_offset( CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } -pub fn gen_preprocessed_column( - preprocessed_column: &PreprocessedColumn, -) -> CircleEvaluation { - match preprocessed_column { - PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size), - PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => { - unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.") - } - } -} - -pub fn gen_preprocessed_columns<'a, B: Backend>( - columns: impl Iterator, -) -> Vec> { - columns.map(gen_preprocessed_column).collect() +pub fn gen_preprocessed_columns_simd<'a>( + columns: impl Iterator>, +) -> Vec> { + columns + .map(|col| col.gen_preprocessed_column_simd()) + .collect() } diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs index 65c52708c..76c7fd40c 100644 --- a/crates/prover/src/constraint_framework/simd_domain.rs +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -1,4 +1,5 @@ use std::ops::Mul; +use std::sync::Arc; use num_traits::Zero; diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 424e34f13..f08dac616 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -1,4 +1,5 @@ use std::simd::u32x16; +use std::sync::Arc; use itertools::{chain, multiunzip, Itertools}; use num_traits::Zero; @@ -8,7 +9,9 @@ use tracing::{span, Level}; use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; use super::xor_table::{xor12, xor4, xor7, xor8, xor9}; -use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use crate::constraint_framework::preprocessed_columns::{ + self, gen_is_first, PreprocessedColumnOps, +}; use crate::constraint_framework::{TraceLocationAllocator, PREPROCESSED_TRACE_IDX}; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; @@ -26,28 +29,100 @@ use crate::examples::blake::{ round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, }; -const PREPROCESSED_XOR_COLUMNS: [PreprocessedColumn; 20] = [ - PreprocessedColumn::XorTable(12, 4, 0), - PreprocessedColumn::XorTable(12, 4, 1), - PreprocessedColumn::XorTable(12, 4, 2), - PreprocessedColumn::IsFirst(xor12::column_bits::<12, 4>()), - PreprocessedColumn::XorTable(9, 2, 0), - PreprocessedColumn::XorTable(9, 2, 1), - PreprocessedColumn::XorTable(9, 2, 2), - PreprocessedColumn::IsFirst(xor9::column_bits::<9, 2>()), - PreprocessedColumn::XorTable(8, 2, 0), - PreprocessedColumn::XorTable(8, 2, 1), - PreprocessedColumn::XorTable(8, 2, 2), - PreprocessedColumn::IsFirst(xor8::column_bits::<8, 2>()), - PreprocessedColumn::XorTable(7, 2, 0), - PreprocessedColumn::XorTable(7, 2, 1), - PreprocessedColumn::XorTable(7, 2, 2), - PreprocessedColumn::IsFirst(xor7::column_bits::<7, 2>()), - PreprocessedColumn::XorTable(4, 0, 0), - PreprocessedColumn::XorTable(4, 0, 1), - PreprocessedColumn::XorTable(4, 0, 2), - PreprocessedColumn::IsFirst(xor4::column_bits::<4, 0>()), -]; +fn preprocessed_xor_columns() -> [Arc; 20] { + [ + Arc::new(preprocessed_columns::XorTable { + elem_bits: 12, + expand_bits: 4, + kind: 0, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 12, + expand_bits: 4, + kind: 1, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 12, + expand_bits: 4, + kind: 2, + }), + Arc::new(preprocessed_columns::IsFirst { + log_size: xor12::column_bits::<12, 4>(), + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 9, + expand_bits: 2, + kind: 0, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 9, + expand_bits: 2, + kind: 1, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 9, + expand_bits: 2, + kind: 2, + }), + Arc::new(preprocessed_columns::IsFirst { + log_size: xor9::column_bits::<9, 2>(), + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 8, + expand_bits: 2, + kind: 0, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 8, + expand_bits: 2, + kind: 1, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 8, + expand_bits: 2, + kind: 2, + }), + Arc::new(preprocessed_columns::IsFirst { + log_size: xor8::column_bits::<8, 2>(), + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 7, + expand_bits: 2, + kind: 0, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 7, + expand_bits: 2, + kind: 1, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 7, + expand_bits: 2, + kind: 2, + }), + Arc::new(preprocessed_columns::IsFirst { + log_size: xor7::column_bits::<7, 2>(), + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 4, + expand_bits: 0, + kind: 0, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 4, + expand_bits: 0, + kind: 1, + }), + Arc::new(preprocessed_columns::XorTable { + elem_bits: 4, + expand_bits: 0, + kind: 2, + }), + Arc::new(preprocessed_columns::IsFirst { + log_size: xor4::column_bits::<4, 0>(), + }), + ] +} #[derive(Serialize)] pub struct BlakeStatement0 { @@ -86,12 +161,7 @@ impl BlakeStatement0 { log_sizes[PREPROCESSED_TRACE_IDX] = chain!( [scheduler_is_first_column_log_size], blake_round_is_first_column_log_sizes, - PREPROCESSED_XOR_COLUMNS.map(|column| match column { - PreprocessedColumn::XorTable(elem_bits, expand_bits, _) => - 2 * (elem_bits - expand_bits), - PreprocessedColumn::IsFirst(log_size) => log_size, - _ => panic!("Unexpected column"), - }), + preprocessed_xor_columns().map(|column| column.log_size()), ) .collect_vec(); @@ -164,19 +234,24 @@ impl BlakeComponents { fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { let log_size = stmt0.log_size; - let scheduler_is_first_column = PreprocessedColumn::IsFirst(log_size); - let blake_round_is_first_columns_iter = ROUND_LOG_SPLIT - .iter() - .map(|l| PreprocessedColumn::IsFirst(log_size + l)); - - let tree_span_provider = &mut TraceLocationAllocator::new_with_preproccessed_columns( - &chain!( - [scheduler_is_first_column], - blake_round_is_first_columns_iter, - PREPROCESSED_XOR_COLUMNS, - ) - .collect_vec()[..], - ); + let scheduler_is_first_column: Arc = + Arc::new(preprocessed_columns::IsFirst { log_size }); + let blake_round_is_first_columns_iter = ROUND_LOG_SPLIT.iter().map(|l| { + let b: Arc = Arc::new(preprocessed_columns::IsFirst { + log_size: log_size + l, + }); + b + }); + + let columns: Vec> = chain!( + [scheduler_is_first_column], + blake_round_is_first_columns_iter, + preprocessed_xor_columns(), + ) + .collect_vec(); + + let tree_span_provider = + &mut TraceLocationAllocator::new_with_preproccessed_columns(&columns[..]); Self { scheduler_component: BlakeSchedulerComponent::new( diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 60fef8bfe..25f0e031f 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -16,29 +16,29 @@ macro_rules! xor_table_eval { // al, bl are the constant columns for the inputs: All pairs of elements in [0, // 2^LIMB_BITS). // cl is the constant column for the xor: al ^ bl. - let al = self - .eval - .get_preprocessed_column(PreprocessedColumn::XorTable( - ELEM_BITS, - EXPAND_BITS, - 0, - )); + let al = + self.eval + .get_preprocessed_column(Arc::new(preprocessed_columns::XorTable { + elem_bits: ELEM_BITS, + expand_bits: EXPAND_BITS, + kind: 0, + })); - let bl = self - .eval - .get_preprocessed_column(PreprocessedColumn::XorTable( - ELEM_BITS, - EXPAND_BITS, - 1, - )); + let bl = + self.eval + .get_preprocessed_column(Arc::new(preprocessed_columns::XorTable { + elem_bits: ELEM_BITS, + expand_bits: EXPAND_BITS, + kind: 1, + })); - let cl = self - .eval - .get_preprocessed_column(PreprocessedColumn::XorTable( - ELEM_BITS, - EXPAND_BITS, - 2, - )); + let cl = + self.eval + .get_preprocessed_column(Arc::new(preprocessed_columns::XorTable { + elem_bits: ELEM_BITS, + expand_bits: EXPAND_BITS, + kind: 2, + })); for i in (0..(1 << (2 * EXPAND_BITS))) { let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index e653e0bb9..bc5092c4a 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -14,16 +14,17 @@ mod constraints; mod gen; use std::simd::u32x16; +use std::sync::Arc; use itertools::Itertools; use num_traits::Zero; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator}; -use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use crate::constraint_framework::preprocessed_columns::gen_is_first; use crate::constraint_framework::{ - relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, Relation, RelationEntry, - INTERACTION_TRACE_IDX, PREPROCESSED_TRACE_IDX, + preprocessed_columns, relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, + Relation, RelationEntry, INTERACTION_TRACE_IDX, PREPROCESSED_TRACE_IDX, }; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index a1e0362c9..78e88f670 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use itertools::Itertools; use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn}; +use crate::constraint_framework::preprocessed_columns::{self, gen_is_first}; use crate::constraint_framework::{ assert_constraints, relation, EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, TraceLocationAllocator, @@ -49,12 +51,15 @@ impl FrameworkEval for PlonkEval { } fn evaluate(&self, mut eval: E) -> E { - let a_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(0)); - let b_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(1)); + let a_wire = + eval.get_preprocessed_column(Arc::new(preprocessed_columns::Plonk { kind: 0 })); + let b_wire = + eval.get_preprocessed_column(Arc::new(preprocessed_columns::Plonk { kind: 1 })); // Note: c_wire could also be implicit: (self.eval.point() - M31_CIRCLE_GEN.into_ef()).x. // A constant column is easier though. - let c_wire = eval.get_preprocessed_column(PreprocessedColumn::Plonk(2)); - let op = eval.get_preprocessed_column(PreprocessedColumn::Plonk(3)); + let c_wire = + eval.get_preprocessed_column(Arc::new(preprocessed_columns::Plonk { kind: 2 })); + let op = eval.get_preprocessed_column(Arc::new(preprocessed_columns::Plonk { kind: 3 })); let mult = eval.next_trace_mask(); let a_val = eval.next_trace_mask(); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 684f04f76..3efe30d44 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -1,5 +1,10 @@ +use std::sync::Arc; + +use crate::constraint_framework::preprocessed_columns::{ + gen_preprocessed_columns_simd, PreprocessedColumnOps, +}; use crate::constraint_framework::relation_tracker::RelationSummary; -use crate::constraint_framework::Relation; +use crate::constraint_framework::{preprocessed_columns, Relation}; pub mod components; pub mod gen; @@ -11,9 +16,6 @@ use components::{ use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; -use crate::constraint_framework::preprocessed_columns::{ - gen_preprocessed_columns, PreprocessedColumn, -}; use crate::constraint_framework::TraceLocationAllocator; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::SimdBackend; @@ -59,13 +61,17 @@ pub fn prove_state_machine( let mut commitment_scheme = CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); - let preprocessed_columns = [ - PreprocessedColumn::IsFirst(x_axis_log_rows), - PreprocessedColumn::IsFirst(y_axis_log_rows), + let preprocessed_columns: [Arc; 2] = [ + Arc::new(preprocessed_columns::IsFirst { + log_size: x_axis_log_rows, + }), + Arc::new(preprocessed_columns::IsFirst { + log_size: y_axis_log_rows, + }), ]; // Preprocessed trace. - let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter()); + let preprocessed_trace = gen_preprocessed_columns_simd(preprocessed_columns.iter().cloned()); // Trace. let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0);