Skip to content

Commit

Permalink
support builtins in the adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe committed Dec 24, 2024
1 parent 4659106 commit 870f3c6
Show file tree
Hide file tree
Showing 5 changed files with 565 additions and 13 deletions.
79 changes: 79 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/builtins_segments.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use std::collections::HashMap;

use cairo_vm::air_public_input::MemorySegmentAddresses;
use cairo_vm::types::builtin_name::BuiltinName;

/// This struct holds the builtins used in a Cairo program.
#[derive(Debug, Default)]
pub struct BuiltinsSegments {
pub range_check_bits_128: Option<MemorySegmentAddresses>,
pub range_check_bits_96: Option<MemorySegmentAddresses>,
pub bitwise: Option<MemorySegmentAddresses>,
pub add_mod: Option<MemorySegmentAddresses>,
pub ec_op: Option<MemorySegmentAddresses>,
pub ecdsa: Option<MemorySegmentAddresses>,
pub keccak: Option<MemorySegmentAddresses>,
pub mul_mod: Option<MemorySegmentAddresses>,
pub pedersen: Option<MemorySegmentAddresses>,
pub poseidon: Option<MemorySegmentAddresses>,
}

impl BuiltinsSegments {
/// Create a new `BuiltinsSegments` struct from a map of memory MemorySegmentAddressess.
pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self {
let mut res = BuiltinsSegments::default();
for (name, value) in memory_segments.iter() {
let value = Some((value.begin_addr, value.stop_ptr).into());
if let Some(builtin) = BuiltinName::from_str(name) {
match builtin {
BuiltinName::range_check => res.range_check_bits_128 = value,
BuiltinName::range_check96 => res.range_check_bits_96 = value,
BuiltinName::bitwise => res.bitwise = value,
BuiltinName::add_mod => res.add_mod = value,
BuiltinName::ec_op => res.ec_op = value,
BuiltinName::ecdsa => res.ecdsa = value,
BuiltinName::keccak => res.keccak = value,
BuiltinName::mul_mod => res.mul_mod = value,
BuiltinName::pedersen => res.pedersen = value,
BuiltinName::poseidon => res.poseidon = value,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => {}
}
};
}
res
}
}

#[cfg(test)]
mod test_builtins_segments {
use std::path::PathBuf;

use cairo_vm::air_public_input::PublicInput;

use crate::input::BuiltinsSegments;

#[test]
fn test_builtins_segments() {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("src/input/test_builtins_segments/air_pub.json");
let pub_data_string = std::fs::read_to_string(&d)
.unwrap_or_else(|_| panic!("Unable to read file: {}", d.display()));
let pub_data: PublicInput<'_> =
sonic_rs::from_str(&pub_data_string).expect("Unable to parse JSON");
let builtins_segments = BuiltinsSegments::from_memory_segments(&pub_data.memory_segments);
assert_eq!(
builtins_segments.range_check_bits_128,
Some((289, 289).into())
);
assert_eq!(builtins_segments.range_check_bits_96, None);
assert_eq!(builtins_segments.bitwise, None);
assert_eq!(builtins_segments.add_mod, None);
assert_eq!(builtins_segments.ec_op, None);
assert_eq!(builtins_segments.ecdsa, Some((353, 353).into()));
assert_eq!(builtins_segments.keccak, None);
assert_eq!(builtins_segments.mul_mod, None);
assert_eq!(builtins_segments.pedersen, Some((97, 97).into()));
assert_eq!(builtins_segments.poseidon, None);
}
}
5 changes: 3 additions & 2 deletions stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use cairo_vm::air_public_input::MemorySegmentAddresses;
use builtins_segments::BuiltinsSegments;
use mem::Memory;
use state_transitions::StateTransitions;

pub mod builtins_segments;
mod decode;
pub mod mem;
pub mod plain;
Expand All @@ -19,5 +20,5 @@ pub struct CairoInput {
pub public_mem_addresses: Vec<u32>,

// Builtins.
pub range_check_builtin: MemorySegmentAddresses,
pub builtins_segments: BuiltinsSegments,
}
15 changes: 9 additions & 6 deletions stwo_cairo_prover/crates/prover/src/input/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ use itertools::Itertools;
use super::mem::{MemConfig, MemoryBuilder};
use super::state_transitions::StateTransitions;
use super::vm_import::MemEntry;
use super::CairoInput;
use crate::input::MemorySegmentAddresses;
use super::{BuiltinsSegments, CairoInput};

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a plain casm into a CairoInput by running the program and extracting the memory and
Expand Down Expand Up @@ -71,6 +70,13 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
val: bytemuck::cast(v.to_bytes_le()),
})
});

let memory_segments = &runner
.get_air_public_input()
.expect("Unable to get public input from the runner")
.memory_segments;
let builtins_segments = BuiltinsSegments::from_memory_segments(memory_segments);

let trace = runner.relocated_trace.unwrap();
let trace = trace.iter().map(|t| t.clone().into());

Expand All @@ -84,9 +90,6 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: MemorySegmentAddresses {
begin_addr: 24,
stop_ptr: 64,
},
builtins_segments,
}
}
Loading

0 comments on commit 870f3c6

Please sign in to comment.