Skip to content

Commit

Permalink
Create CairoAir.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 22, 2024
1 parent 8272f0b commit 5af696a
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 5 deletions.
141 changes: 141 additions & 0 deletions stwo_cairo_prover/src/air/air.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use std::collections::BTreeMap;

use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver};
use stwo_prover::core::backend::CpuBackend;
use stwo_prover::core::channel::{Blake2sChannel, Channel};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::poly::circle::CircleEvaluation;
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::prover::VerificationError;
use stwo_prover::core::{ColumnVec, InteractionElements, LookupValues};
use stwo_prover::trace_generation::registry::ComponentGenerationRegistry;
use stwo_prover::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

use crate::components::memory::component::{
MemoryComponent, MemoryTraceGenerator, MEMORY_ADDRESS_BOUND, MEMORY_ALPHA, MEMORY_COMPONENT_ID,
MEMORY_Z, N_MEMORY_COLUMNS,
};
use crate::components::range_check_unit::component::{
RangeCheckUnitComponent, RangeCheckUnitTraceGenerator, N_RC_COLUMNS, RC_COMPONENT_ID, RC_Z,
};

struct CairoAirGenerator {
pub registry: ComponentGenerationRegistry,
}

impl CairoAirGenerator {
pub fn new(path: String) -> Self {
let mut registry = ComponentGenerationRegistry::default();
registry.register(MEMORY_COMPONENT_ID, MemoryTraceGenerator::new(path));
registry.register(
RC_COMPONENT_ID,
RangeCheckUnitTraceGenerator::new(MEMORY_ADDRESS_BOUND),
);
Self { registry }
}
}

impl AirTraceVerifier for CairoAirGenerator {
fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements {
let elements = channel.draw_felts(3);
InteractionElements::new(BTreeMap::from_iter(vec![
(MEMORY_ALPHA.to_string(), elements[0]),
(MEMORY_Z.to_string(), elements[1]),
(RC_Z.to_string(), elements[2]),
]))
}
}

impl AirTraceGenerator<CpuBackend> for CairoAirGenerator {
fn write_trace(&mut self) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let mut trace = Vec::with_capacity(N_MEMORY_COLUMNS + N_RC_COLUMNS);
trace.extend(MemoryTraceGenerator::write_trace(
MEMORY_COMPONENT_ID,
&mut self.registry,
));
trace.extend(RangeCheckUnitTraceGenerator::write_trace(
RC_COMPONENT_ID,
&mut self.registry,
));
trace
}

fn interact(
&self,
trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> Vec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let mut interaction_trace = Vec::new();
let trace_iter = &mut trace.iter();
let memory_generator = self
.registry
.get_generator::<MemoryTraceGenerator>(MEMORY_COMPONENT_ID);
interaction_trace.extend(
memory_generator
.write_interaction_trace(&trace_iter.take(N_MEMORY_COLUMNS).collect(), elements),
);
let rc_generator = self
.registry
.get_generator::<RangeCheckUnitTraceGenerator>(RC_COMPONENT_ID);
interaction_trace.extend(
rc_generator
.write_interaction_trace(&trace_iter.take(N_RC_COLUMNS).collect(), elements),
);
interaction_trace
}

fn to_air_prover(&self) -> impl AirProver<CpuBackend> {
let memory = self
.registry
.get_generator::<MemoryTraceGenerator>(MEMORY_COMPONENT_ID);
let range_check_unit = self
.registry
.get_generator::<RangeCheckUnitTraceGenerator>(RC_COMPONENT_ID);
CairoAir {
memory: memory.component(),
range_check_unit: range_check_unit.component(),
}
}

fn composition_log_degree_bound(&self) -> u32 {
let component_generator = self
.registry
.get_generator::<MemoryTraceGenerator>(MEMORY_COMPONENT_ID);
component_generator
.component()
.max_constraint_log_degree_bound()
}
}

#[derive(Clone)]
pub struct CairoAir {
pub memory: MemoryComponent,
pub range_check_unit: RangeCheckUnitComponent,
}

impl Air for CairoAir {
fn components(&self) -> Vec<&dyn Component> {
vec![&self.memory, &self.range_check_unit]
}

fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> {
Ok(())
}
}

impl AirProver<CpuBackend> for CairoAir {
fn prover_components(&self) -> Vec<&dyn ComponentProver<CpuBackend>> {
vec![&self.memory, &self.range_check_unit]
}
}

impl AirTraceVerifier for CairoAir {
fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements {
let elements = channel.draw_felts(3);
InteractionElements::new(BTreeMap::from_iter(vec![
(MEMORY_ALPHA.to_string(), elements[0]),
(MEMORY_Z.to_string(), elements[1]),
(RC_Z.to_string(), elements[2]),
]))
}
}
1 change: 1 addition & 0 deletions stwo_cairo_prover/src/air/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod air;
3 changes: 2 additions & 1 deletion stwo_cairo_prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub const MULTIPLICITY_COLUMN: usize = N_M31_IN_FELT252 + 1;
// TODO(AlonH): Make memory size configurable.
pub const LOG_MEMORY_ADDRESS_BOUND: u32 = 3;
pub const MEMORY_ADDRESS_BOUND: usize = 1 << LOG_MEMORY_ADDRESS_BOUND;
pub const N_MEMORY_COLUMNS: usize = N_M31_IN_FELT252 + 2;

/// Addresses are continuous and start from 0.
/// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value contain 12 bits).
Expand All @@ -51,7 +52,7 @@ pub struct MemoryComponent {

impl MemoryComponent {
pub const fn n_columns(&self) -> usize {
N_M31_IN_FELT252 + 2
N_MEMORY_COLUMNS
}
}

Expand Down
10 changes: 9 additions & 1 deletion stwo_cairo_prover/src/components/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ mod tests {
use std::collections::BTreeMap;

use component::{
MemoryComponent, MemoryTraceGenerator, MEMORY_ALPHA, MEMORY_COMPONENT_ID, MEMORY_Z,
MemoryComponent, MemoryTraceGenerator, MEMORY_ADDRESS_BOUND, MEMORY_ALPHA,
MEMORY_COMPONENT_ID, MEMORY_Z,
};
use stwo_prover::core::air::{Air, AirProver, Component, ComponentProver};
use stwo_prover::core::backend::CpuBackend;
Expand All @@ -26,12 +27,19 @@ mod tests {
};

use super::*;
use crate::components::range_check_unit::component::{
RangeCheckUnitTraceGenerator, RC_COMPONENT_ID,
};

pub fn register_test_memory(registry: &mut ComponentGenerationRegistry) {
registry.register(
MEMORY_COMPONENT_ID,
MemoryTraceGenerator::new("".to_string()),
);
registry.register(
RC_COMPONENT_ID,
RangeCheckUnitTraceGenerator::new(MEMORY_ADDRESS_BOUND),
);
vec![
vec![BaseField::from_u32_unchecked(0); 3],
vec![BaseField::from_u32_unchecked(1); 1],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub const RC_LOOKUP_VALUE_1: &str = "RC_UNIT_LOOKUP_1";
pub const RC_LOOKUP_VALUE_2: &str = "RC_UNIT_LOOKUP_2";
pub const RC_LOOKUP_VALUE_3: &str = "RC_UNIT_LOOKUP_3";

pub const N_RC_COLUMNS: usize = 2;

#[derive(Clone)]
pub struct RangeCheckUnitComponent {
pub log_n_rows: u32,
Expand All @@ -48,7 +50,7 @@ impl Component for RangeCheckUnitComponent {

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(vec![
vec![self.log_n_rows; 2],
vec![self.log_n_rows; N_RC_COLUMNS],
vec![self.log_n_rows; SECURE_EXTENSION_DEGREE],
])
}
Expand All @@ -59,7 +61,7 @@ impl Component for RangeCheckUnitComponent {
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let domain = CanonicCoset::new(self.log_n_rows);
TreeVec::new(vec![
fixed_mask_points(&vec![vec![0_usize]; 2], point),
fixed_mask_points(&vec![vec![0_usize]; N_RC_COLUMNS], point),
vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE],
])
}
Expand Down Expand Up @@ -134,7 +136,7 @@ impl ComponentTraceGenerator<CpuBackend> for RangeCheckUnitTraceGenerator {
registry.get_generator::<RangeCheckUnitTraceGenerator>(component_id);
let rc_max_value = rc_unit_trace_generator.max_value;

let mut trace = vec![vec![BaseField::zero(); rc_max_value]; 2];
let mut trace = vec![vec![BaseField::zero(); rc_max_value]; N_RC_COLUMNS];
for (i, multiplicity) in rc_unit_trace_generator.multiplicities.iter().enumerate() {
// TODO(AlonH): Either create a constant column for the addresses and remove it from
// here or add constraints to the column here.
Expand Down
1 change: 1 addition & 0 deletions stwo_cairo_prover/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod air;
pub mod components;

fn main() {
Expand Down

0 comments on commit 5af696a

Please sign in to comment.