Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow out-of-crate preprocessed column definitions #939

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use num_traits::Zero;

use super::logup::{LogupAtRow, LogupSums};
Expand Down
27 changes: 19 additions & 8 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
use std::sync::Arc;

use itertools::Itertools;
#[cfg(feature = "parallel")]
Expand All @@ -11,7 +12,7 @@ use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::preprocessed_columns::PreprocessedColumnOps;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
};
Expand Down Expand Up @@ -49,7 +50,8 @@ pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
/// Mapping of preprocessed columns to their index.
preprocessed_columns: HashMap<PreprocessedColumn, usize>,
/// A preprocessed column implementation is indicated by its TypeId
preprocessed_columns: HashMap<Arc<dyn PreprocessedColumnOps>, usize>,
/// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
}
Expand Down Expand Up @@ -81,30 +83,39 @@ impl TraceLocationAllocator {
}

/// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
pub fn new_with_preproccessed_columns(preprocessed_columns: &[PreprocessedColumn]) -> Self {
pub fn new_with_preproccessed_columns(
preprocessed_columns: &[Arc<dyn PreprocessedColumnOps>],
) -> Self {
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.map(|(i, col)| (col.clone(), i))
.collect(),
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
}
}

pub const fn preprocessed_columns(&self) -> &HashMap<PreprocessedColumn, usize> {
pub const fn preprocessed_columns(&self) -> &HashMap<Arc<dyn PreprocessedColumnOps>, usize> {
&self.preprocessed_columns
}

// validates that `self.preprocessed_columns` is consistent with
// `preprocessed_columns`.
// I.e. preprocessed_columns[i] == self.preprocessed_columns[i].
pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreprocessedColumn]) {
// The equality comparison uses the pointer comparison of the boxes.
pub fn validate_preprocessed_columns(
&self,
preprocessed_columns: &[Arc<dyn PreprocessedColumnOps>],
) {
assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len());

for (column, idx) in self.preprocessed_columns.iter() {
assert_eq!(Some(column), preprocessed_columns.get(*idx));
assert!(match preprocessed_columns.get(*idx) {
Some(preprocessed_column) => preprocessed_column == column,
None => false,
},)
}
}
}
Expand Down Expand Up @@ -146,7 +157,7 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
let next_column = location_allocator.preprocessed_columns.len();
*location_allocator
.preprocessed_columns
.entry(*col)
.entry(col.clone())
.or_insert_with(|| {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::Mul;
use std::sync::Arc;

use num_traits::Zero;

Expand Down
6 changes: 4 additions & 2 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::sync::Arc;

use num_traits::Zero;

use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumnOps;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31;
use crate::core::lookups::utils::Fraction;
Expand Down Expand Up @@ -174,7 +176,7 @@ impl EvalAtRow for ExprEvaluator {
intermediate
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Arc<dyn PreprocessedColumnOps>) -> Self::F {
BaseExpr::Param(column.name().to_string())
}

Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use std::array;
use std::cell::{RefCell, RefMut};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use std::rc::Rc;
use std::sync::Arc;

use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::preprocessed_columns::PreprocessedColumn;
use super::preprocessed_columns::PreprocessedColumnOps;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
Expand All @@ -22,14 +23,14 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
pub preprocessed_columns: Vec<Arc<dyn PreprocessedColumnOps>>,
pub logup: LogupAtRow<Self>,
pub arithmetic_counts: ArithmeticCounts,
}
impl InfoEvaluator {
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreprocessedColumn>,
preprocessed_columns: Vec<Arc<dyn PreprocessedColumnOps>>,
logup_sums: LogupSums,
) -> Self {
Self {
Expand Down Expand Up @@ -70,7 +71,7 @@ impl EvalAtRow for InfoEvaluator {
array::from_fn(|_| FieldCounter::one())
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Arc<dyn PreprocessedColumnOps>) -> Self::F {
self.preprocessed_columns.push(column);
FieldCounter::one()
}
Expand Down
18 changes: 8 additions & 10 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ mod simd_domain;
use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
use std::sync::Arc;

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
use preprocessed_columns::PreprocessedColumn;
use preprocessed_columns::PreprocessedColumnOps;
pub use simd_domain::SimdDomainEvaluator;

use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -87,7 +88,7 @@ pub trait EvalAtRow {
mask_item
}

fn get_preprocessed_column(&mut self, _column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, _column: Arc<dyn PreprocessedColumnOps>) -> Self::F {
let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
mask_item
}
Expand Down Expand Up @@ -165,18 +166,15 @@ pub trait EvalAtRow {
}
}

/// Default implementation for evaluators that have an element called "logup" that works like a
/// LogupAtRow, where the logup functionality can be proxied.
/// TODO(alont): Remove once LogupAtRow is no longer used.
macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
self.logup.log_size,
),
);
self.logup.is_first = self.get_preprocessed_column(Arc::new(
crate::constraint_framework::preprocessed_columns::IsFirst {
log_size: self.logup.log_size,
},
));
self.logup.is_finalized = false;
}
self.logup.fracs.push(fraction.clone());
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::Mul;
use std::sync::Arc;

use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
Expand Down
154 changes: 126 additions & 28 deletions crates/prover/src/constraint_framework/preprocessed_columns.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,133 @@
use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use num_traits::One;

use crate::core::backend::{Backend, Col, Column};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Backend, Col, Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

// TODO(ilya): Where should this enum be placed?
/// XorTable, etc will be implementation of this trait.
pub trait PreprocessedColumnOps: Debug + Any {
fn get_type_id(&self) -> std::any::TypeId {
self.type_id()
}
fn name(&self) -> &'static str;
fn log_size(&self) -> u32;
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>;
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>;
fn as_bytes(&self) -> Vec<u8>;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IsFirst {
pub log_size: u32,
}

impl PreprocessedColumnOps for IsFirst {
fn name(&self) -> &'static str {
"preprocessed.is_first"
}
fn log_size(&self) -> u32 {
self.log_size
}
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder> {
gen_is_first(self.log_size)
}
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
gen_is_first(self.log_size)
}
fn as_bytes(&self) -> Vec<u8> {
self.log_size.to_le_bytes().to_vec()
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PreprocessedColumn {
XorTable(u32, u32, usize),
IsFirst(u32),
Plonk(usize),
pub struct XorTable {
pub elem_bits: u32,
pub expand_bits: u32,
pub kind: usize,
}

impl PreprocessedColumn {
pub const fn name(&self) -> &'static str {
match self {
PreprocessedColumn::XorTable(..) => "preprocessed.xor_table",
PreprocessedColumn::IsFirst(_) => "preprocessed.is_first",
PreprocessedColumn::Plonk(_) => "preprocessed.plonk",
}
impl PreprocessedColumnOps for XorTable {
fn name(&self) -> &'static str {
"preprocessed.xor_table"
}
fn log_size(&self) -> u32 {
assert!(self.elem_bits >= self.expand_bits);
2 * (self.elem_bits - self.expand_bits)
}
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder> {
unimplemented!("XorTable is not supported.")
}
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
unimplemented!("XorTable is not supported.")
}
fn as_bytes(&self) -> Vec<u8> {
let mut bytes = vec![];
bytes.extend_from_slice(&self.elem_bits.to_le_bytes());
bytes.extend_from_slice(&self.expand_bits.to_le_bytes());
bytes.extend_from_slice(&self.kind.to_le_bytes());
bytes
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Plonk {
pub kind: u32,
}

impl PreprocessedColumnOps for Plonk {
fn name(&self) -> &'static str {
"preprocessed.plonk"
}
fn log_size(&self) -> u32 {
unimplemented!("Plonk is not supported.")
}
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder> {
unimplemented!("Plonk is not supported.")
}
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
unimplemented!("Plonk is not supported.")
}
fn as_bytes(&self) -> Vec<u8> {
self.kind.to_le_bytes().to_vec()
}
}

impl PartialEq for dyn PreprocessedColumnOps {
fn eq(&self, other: &Self) -> bool {
self.get_type_id() == other.get_type_id() && self.as_bytes() == other.as_bytes()
}
}

impl Eq for dyn PreprocessedColumnOps {}

impl Hash for dyn PreprocessedColumnOps {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.get_type_id().hash(state);
self.as_bytes().hash(state);
}
}

Expand Down Expand Up @@ -54,19 +161,10 @@ pub fn gen_is_step_with_offset<B: Backend>(
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

pub fn gen_preprocessed_column<B: Backend>(
preprocessed_column: &PreprocessedColumn,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
match preprocessed_column {
PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size),
PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => {
unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.")
}
}
}

pub fn gen_preprocessed_columns<'a, B: Backend>(
columns: impl Iterator<Item = &'a PreprocessedColumn>,
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
columns.map(gen_preprocessed_column).collect()
pub fn gen_preprocessed_columns_simd<'a>(
columns: impl Iterator<Item = Arc<dyn PreprocessedColumnOps>>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
columns
.map(|col| col.gen_preprocessed_column_simd())
.collect()
}
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/simd_domain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::Mul;
use std::sync::Arc;

use num_traits::Zero;

Expand Down
Loading