Skip to content

Commit

Permalink
support builtins in the adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe committed Dec 22, 2024
1 parent b366d04 commit 7a75311
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 20 deletions.
3 changes: 2 additions & 1 deletion stwo_cairo_prover/crates/adapted_prover/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoProof<Blake2sMerkleHas
let _span = span!(Level::INFO, "run").entered();
let args = Args::try_parse_from(args)?;

// dev_mode = true
let vm_output: CairoInput =
import_from_vm_output(args.pub_json.as_path(), args.priv_json.as_path())?;
import_from_vm_output(args.pub_json.as_path(), args.priv_json.as_path(), true)?;

let casm_states_by_opcode_count = &vm_output.state_transitions.casm_states_by_opcode.counts();
log::info!("Casm states by opcode count: {casm_states_by_opcode_count:?}");
Expand Down
48 changes: 48 additions & 0 deletions stwo_cairo_prover/crates/prover/src/input/builtins_segments.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::collections::HashMap;

use cairo_vm::air_public_input::MemorySegmentAddresses;

// TODO(Stav): Understand if the adapter should pass builtins that won't be supported by stwo.
/// This struct holds the builtins used in a Cairo program.
/// Most of them are not implemented yet by Stwo.
#[derive(Debug, Default)]
pub struct BuiltinsSegments {
pub range_check_bits_128_builtin: Option<MemorySegmentAddresses>,
pub range_check_bits_96_builtin: Option<MemorySegmentAddresses>,
pub bitwise_builtin: Option<MemorySegmentAddresses>,
pub add_mod_builtin: Option<MemorySegmentAddresses>,
pub ec_op_builtin: Option<MemorySegmentAddresses>,
pub ecdsa_builtin: Option<MemorySegmentAddresses>,
pub keccak_builtin: Option<MemorySegmentAddresses>,
pub mul_mod_builtin: Option<MemorySegmentAddresses>,
pub pedersen_builtin: Option<MemorySegmentAddresses>,
pub poseidon_builtin: Option<MemorySegmentAddresses>,
}

impl BuiltinsSegments {
/// Create a new `BuiltinsSegments` struct from a map of memory MemorySegmentAddressess.
pub fn from_memory_segments(
memory_segments: &HashMap<&str, cairo_vm::air_public_input::MemorySegmentAddresses>,
) -> Self {
let mut res = BuiltinsSegments::default();
for (name, value) in memory_segments.iter() {
let value = Some((value.begin_addr, value.stop_ptr).into());
match *name {
"range_check" => res.range_check_bits_128_builtin = value,
"range_check96" => res.range_check_bits_96_builtin = value,
"bitwise" => res.bitwise_builtin = value,
"add_mod" => res.add_mod_builtin = value,
"ec_op" => res.ec_op_builtin = value,
"ecdsa" => res.ecdsa_builtin = value,
"keccak" => res.keccak_builtin = value,
"mul_mod" => res.mul_mod_builtin = value,
"pedersen" => res.pedersen_builtin = value,
"poseidon" => res.poseidon_builtin = value,
// Not builtins.
"output" | "execution" | "program" => {}
_ => panic!("Unknown memory segment: {name}"),
}
}
res
}
}
5 changes: 3 additions & 2 deletions stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use cairo_vm::air_public_input::MemorySegmentAddresses;
use builtins_segments::BuiltinsSegments;
use mem::Memory;
use state_transitions::StateTransitions;

pub mod builtins_segments;
mod decode;
pub mod mem;
pub mod plain;
Expand All @@ -19,5 +20,5 @@ pub struct CairoInput {
pub public_mem_addresses: Vec<u32>,

// Builtins.
pub range_check_builtin: MemorySegmentAddresses,
pub builtins_segments: BuiltinsSegments,
}
15 changes: 9 additions & 6 deletions stwo_cairo_prover/crates/prover/src/input/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ use itertools::Itertools;
use super::mem::{MemConfig, MemoryBuilder};
use super::state_transitions::StateTransitions;
use super::vm_import::MemEntry;
use super::CairoInput;
use crate::input::MemorySegmentAddresses;
use super::{BuiltinsSegments, CairoInput};

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a plain casm into a CairoInput by running the program and extracting the memory and
Expand Down Expand Up @@ -71,6 +70,13 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
val: bytemuck::cast(v.to_bytes_le()),
})
});

let memory_segments = &runner
.get_air_public_input()
.expect("Unable to get public input from the runner")
.memory_segments;
let builtins_segments = BuiltinsSegments::from_memory_segments(memory_segments);

let trace = runner.relocated_trace.unwrap();
let trace = trace.iter().map(|t| t.clone().into());

Expand All @@ -84,9 +90,6 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: MemorySegmentAddresses {
begin_addr: 24,
stop_ptr: 64,
},
builtins_segments,
}
}
115 changes: 104 additions & 11 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::mem::MemConfig;
use super::state_transitions::StateTransitions;
use super::CairoInput;
use crate::input::mem::MemoryBuilder;
use crate::input::MemorySegmentAddresses;
use crate::input::BuiltinsSegments;

#[derive(Debug, Error)]
pub enum VmImportError {
Expand All @@ -26,9 +26,11 @@ pub enum VmImportError {
NoMemorySegments,
}

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
pub fn import_from_vm_output(
pub_json: &Path,
priv_json: &Path,
dev_mode: bool,
) -> Result<CairoInput, VmImportError> {
let _span = span!(Level::INFO, "import_from_vm_output").entered();
let pub_data_string = std::fs::read_to_string(pub_json)?;
Expand All @@ -50,22 +52,22 @@ pub fn import_from_vm_output(
let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?);
let mut mem_file = std::io::BufReader::new(std::fs::File::open(mem_path)?);
let mut mem = MemoryBuilder::from_iter(mem_config, MemEntryIter(&mut mem_file));
let state_transitions = StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, true);
let state_transitions =
StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, dev_mode);

let public_mem_addresses = pub_data
.public_memory
.iter()
.map(|entry| entry.address as u32)
.collect();

let builtins_segments = BuiltinsSegments::from_memory_segments(&pub_data.memory_segments);

Ok(CairoInput {
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: MemorySegmentAddresses {
begin_addr: pub_data.memory_segments["range_check"].begin_addr as usize,
stop_ptr: pub_data.memory_segments["range_check"].stop_ptr as usize,
},
builtins_segments,
})
}

Expand Down Expand Up @@ -135,7 +137,12 @@ pub mod tests {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/large_cairo_input");

import_from_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
import_from_vm_output(
d.join("pub.json").as_path(),
d.join("priv.json").as_path(),
false,
)
.expect(
"
Failed to read test files. Maybe git-lfs is not installed? Checkout README.md.",
)
Expand All @@ -144,19 +151,24 @@ pub mod tests {
pub fn small_cairo_input() -> CairoInput {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("test_data/small_cairo_input");
import_from_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect(
import_from_vm_output(
d.join("pub.json").as_path(),
d.join("priv.json").as_path(),
false,
)
.expect(
"
Failed to read test files. Maybe git-lfs is not installed? Checkout README.md.",
)
}

// TODO (Stav): Once all the components are in, verify the proof to ensure the sort was correct.
// TODO (Ohad): remove the following doc after deleting dev_mod.
/// When not ignored, the test passes only with dev_mod = false.
#[ignore]
#[test]
fn test_read_from_large_files() {
let input = large_cairo_input();

// Test opcode components.
let components = input.state_transitions.casm_states_by_opcode;
assert_eq!(components.generic_opcode.len(), 0);
assert_eq!(components.add_ap_opcode_is_imm_f_op_1_base_fp_f.len(), 0);
Expand Down Expand Up @@ -217,13 +229,57 @@ pub mod tests {
assert_eq!(components.mul_opcode_is_small_f_is_imm_f.len(), 4583);
assert_eq!(components.mul_opcode_is_small_f_is_imm_t.len(), 9047);
assert_eq!(components.ret_opcode.len(), 49472);

// Test builtins.
let builtins_segments = input.builtins_segments;
assert_eq!(
builtins_segments.range_check_bits_128_builtin,
Some((1715768, 1757348).into())
);
assert_eq!(
builtins_segments.range_check_bits_96_builtin,
Some((17706552, 17706552).into())
);
assert_eq!(
builtins_segments.bitwise_builtin,
Some((5942840, 5942840).into())
);
assert_eq!(
builtins_segments.add_mod_builtin,
Some((21900856, 21900856).into())
);
assert_eq!(
builtins_segments.ec_op_builtin,
Some((16428600, 16428747).into())
);
assert_eq!(
builtins_segments.ecdsa_builtin,
Some((5910072, 5910072).into())
);
assert_eq!(
builtins_segments.keccak_builtin,
Some((16657976, 16657976).into())
);
assert_eq!(
builtins_segments.mul_mod_builtin,
Some((23735864, 23735864).into())
);
assert_eq!(
builtins_segments.pedersen_builtin,
Some((1322552, 1337489).into())
);
assert_eq!(
builtins_segments.poseidon_builtin,
Some((16920120, 17444532).into())
);
}

// When not ignored, the test passes only with dev_mod = false.
#[ignore]
#[test]
fn test_read_from_small_files() {
let input = small_cairo_input();

// Test opcode components.
let components = input.state_transitions.casm_states_by_opcode;
assert_eq!(components.generic_opcode.len(), 0);
assert_eq!(components.add_ap_opcode_is_imm_f_op_1_base_fp_f.len(), 0);
Expand Down Expand Up @@ -281,5 +337,42 @@ pub mod tests {
assert_eq!(components.mul_opcode_is_small_f_is_imm_f.len(), 0);
assert_eq!(components.mul_opcode_is_small_f_is_imm_t.len(), 0);
assert_eq!(components.ret_opcode.len(), 462);

// Test builtins.
let builtins_segments = input.builtins_segments;
assert_eq!(
builtins_segments.range_check_bits_128_builtin,
Some((6000, 6050).into())
);
assert_eq!(
builtins_segments.range_check_bits_96_builtin,
Some((68464, 68514).into())
);
assert_eq!(
builtins_segments.bitwise_builtin,
Some((22512, 22762).into())
);
assert_eq!(
builtins_segments.add_mod_builtin,
Some((84848, 84848).into())
);
assert_eq!(builtins_segments.ec_op_builtin, Some((63472, 63822).into()));
assert_eq!(builtins_segments.ecdsa_builtin, Some((22384, 22484).into()));
assert_eq!(
builtins_segments.keccak_builtin,
Some((64368, 65168).into())
);
assert_eq!(
builtins_segments.mul_mod_builtin,
Some((92016, 92016).into())
);
assert_eq!(
builtins_segments.pedersen_builtin,
Some((4464, 4614).into())
);
assert_eq!(
builtins_segments.poseidon_builtin,
Some((65392, 65692).into())
);
}
}

0 comments on commit 7a75311

Please sign in to comment.