Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement memory component prover. #24

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions stwo_cairo_prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use stwo_prover::core::air::mask::fixed_mask_points;
use stwo_prover::core::air::Component;
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::circle::CirclePoint;
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use stwo_prover::core::fields::FieldExpOps;
Expand All @@ -17,24 +18,33 @@ use stwo_prover::core::utils::{
};
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{ComponentGen, ComponentTraceGenerator};
use stwo_prover::trace_generation::{
ComponentGen, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE,
};

pub const MEMORY_ALPHA: &str = "MEMORY_ALPHA";
pub const MEMORY_Z: &str = "MEMORY_Z";

const N_M31_IN_FELT252: usize = 21;
const MULTIPLICITY_COLUMN: usize = 22;
const LOG_MEMORY_ADDRESS_BOUND: u32 = 20;
const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;
pub const MEMORY_COMPONENT_ID: &str = "MEMORY";
pub const MEMORY_LOOKUP_VALUE_0: &str = "MEMORY_LOOKUP_0";
pub const MEMORY_LOOKUP_VALUE_1: &str = "MEMORY_LOOKUP_1";
pub const MEMORY_LOOKUP_VALUE_2: &str = "MEMORY_LOOKUP_2";
pub const MEMORY_LOOKUP_VALUE_3: &str = "MEMORY_LOOKUP_3";

pub const N_M31_IN_FELT252: usize = 21;
pub const MULTIPLICITY_COLUMN: usize = 22;
// TODO(AlonH): Make memory size configurable.
pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 3;
pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;

/// Addresses are continuous and start from 0.
/// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value contain 12 bits).
pub struct MemoryTraceGenerator {
// TODO(AlonH): Consider to change values to be Felt252.
pub values: Vec<[M31; N_M31_IN_FELT252]>,
pub values: Vec<[BaseField; N_M31_IN_FELT252]>,
pub multiplicities: Vec<u32>,
}

#[derive(Clone)]
pub struct MemoryComponent {
pub log_n_rows: u32,
}
Expand All @@ -48,7 +58,7 @@ impl MemoryComponent {
impl MemoryTraceGenerator {
pub fn new(_path: String) -> Self {
// TODO(AlonH): change to read from file.
let values = vec![[M31::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND];
let values = vec![[BaseField::zero(); N_M31_IN_FELT252]; MEMORY_ADDRESS_BOUND];
let multiplicities = vec![0; MEMORY_ADDRESS_BOUND];
Self {
values,
Expand All @@ -61,7 +71,7 @@ impl ComponentGen for MemoryTraceGenerator {}

impl ComponentTraceGenerator<CpuBackend> for MemoryTraceGenerator {
type Component = MemoryComponent;
type Inputs = M31;
type Inputs = BaseField;

fn add_inputs(&mut self, inputs: &Self::Inputs) {
let input = inputs.0 as usize;
Expand Down Expand Up @@ -175,12 +185,46 @@ impl Component for MemoryComponent {

fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &TreeVec<Vec<Vec<SecureField>>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
point: CirclePoint<SecureField>,
mask: &TreeVec<Vec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
todo!()
// First lookup point boundary constraint.
let constraint_zero_domain = CanonicCoset::new(self.log_n_rows).coset;
let (alpha, z) = (
interaction_elements[MEMORY_ALPHA],
interaction_elements[MEMORY_Z],
);
let value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][0]));
let address_and_value: [SecureField; N_M31_IN_FELT252 + 1] =
std::array::from_fn(|i| mask[BASE_TRACE][i][0]);
let numerator = value * shifted_secure_combination(&address_and_value, alpha, z)
- mask[BASE_TRACE][MULTIPLICITY_COLUMN][0];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);

// Last lookup point boundary constraint.
let lookup_value = SecureField::from_m31(
lookup_values[MEMORY_LOOKUP_VALUE_0],
lookup_values[MEMORY_LOOKUP_VALUE_1],
lookup_values[MEMORY_LOOKUP_VALUE_2],
lookup_values[MEMORY_LOOKUP_VALUE_3],
);
let numerator = value - lookup_value;
let denom = point_vanishing(constraint_zero_domain.at(1), point);
evaluation_accumulator.accumulate(numerator / denom);

// Lookup step constraint.
let prev_value =
SecureField::from_partial_evals(std::array::from_fn(|i| mask[INTERACTION_TRACE][i][1]));
let numerator = (value - prev_value)
* shifted_secure_combination(&address_and_value, alpha, z)
- mask[BASE_TRACE][22][0];
let denom = coset_vanishing(constraint_zero_domain, point)
/ point_excluder(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}
}
137 changes: 137 additions & 0 deletions stwo_cairo_prover/src/components/memory/component_prover.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use std::collections::BTreeMap;

use itertools::izip;
use num_traits::Zero;
use stwo_prover::core::air::accumulation::DomainEvaluationAccumulator;
use stwo_prover::core::air::{Component, ComponentProver, ComponentTrace};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::constraints::{coset_vanishing, point_excluder};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fields::FieldExpOps;
use stwo_prover::core::poly::circle::CanonicCoset;
use stwo_prover::core::utils::{
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
shifted_secure_combination,
};
use stwo_prover::core::{InteractionElements, LookupValues};
use stwo_prover::trace_generation::{BASE_TRACE, INTERACTION_TRACE};

use super::component::{
MemoryComponent, MEMORY_ALPHA, MEMORY_LOOKUP_VALUE_0, MEMORY_LOOKUP_VALUE_1,
MEMORY_LOOKUP_VALUE_2, MEMORY_LOOKUP_VALUE_3, MEMORY_Z, MULTIPLICITY_COLUMN, N_M31_IN_FELT252,
};

impl ComponentProver<CpuBackend> for MemoryComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
let trace_evals = &trace.evals;
let zero_domain = CanonicCoset::new(self.log_n_rows).coset;
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);

// TODO(AlonH): Get all denominators in one loop and don't perform unnecessary inversions.
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(1));
let mut step_denoms = vec![];
for point in trace_eval_domain.iter() {
step_denoms.push(
coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point),
);
}
bit_reverse(&mut step_denoms);
let mut step_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&step_denoms, &mut step_denom_inverses);
let (alpha, z) = (
interaction_elements[MEMORY_ALPHA],
interaction_elements[MEMORY_Z],
);
let lookup_value = SecureField::from_m31(
lookup_values[MEMORY_LOOKUP_VALUE_0],
lookup_values[MEMORY_LOOKUP_VALUE_1],
lookup_values[MEMORY_LOOKUP_VALUE_2],
lookup_values[MEMORY_LOOKUP_VALUE_3],
);
for (i, (first_point_denom_inverse, last_point_denom_inverse, step_denom_inverse)) in izip!(
first_point_denom_inverses,
last_point_denom_inverses,
step_denom_inverses,
)
.enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
let prev_index = previous_bit_reversed_circle_domain_index(
i,
zero_domain.log_size,
trace_eval_domain.log_size(),
);
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
let address_and_value: [BaseField; N_M31_IN_FELT252 + 1] =
std::array::from_fn(|j| trace_evals[BASE_TRACE][j][i]);

let first_point_numerator = accum.random_coeff_powers[2]
* (value * shifted_secure_combination(&address_and_value, alpha, z)
- trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN][i]);

let last_point_numerator = accum.random_coeff_powers[1] * (value - lookup_value);
let step_numerator = accum.random_coeff_powers[0]
* ((value - prev_value) * shifted_secure_combination(&address_and_value, alpha, z)
- trace_evals[BASE_TRACE][MULTIPLICITY_COLUMN][i]);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse
+ step_numerator * step_denom_inverse,
);
}
}

fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues {
let domain = CanonicCoset::new(self.log_n_rows);
let trace_poly = &trace.polys[INTERACTION_TRACE];
let values = BTreeMap::from_iter([
(
MEMORY_LOOKUP_VALUE_0.to_string(),
trace_poly[0]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
MEMORY_LOOKUP_VALUE_1.to_string(),
trace_poly[1]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
MEMORY_LOOKUP_VALUE_2.to_string(),
trace_poly[2]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
(
MEMORY_LOOKUP_VALUE_3.to_string(),
trace_poly[3]
.eval_at_point(domain.at(1).into_ef())
.try_into()
.unwrap(),
),
]);
LookupValues::new(values)
}
}
Loading
Loading