diff --git a/stwo_cairo_prover/crates/adapted_prover/src/main.rs b/stwo_cairo_prover/crates/adapted_prover/src/main.rs index b5ed0bfe..7ac99fad 100644 --- a/stwo_cairo_prover/crates/adapted_prover/src/main.rs +++ b/stwo_cairo_prover/crates/adapted_prover/src/main.rs @@ -65,8 +65,9 @@ fn run(args: impl Iterator) -> Result, + pub range_check_bits_96_builtin: Option, + pub bitwise_builtin: Option, + pub add_mod_builtin: Option, + pub ec_op_builtin: Option, + pub ecdsa_builtin: Option, + pub keccak_builtin: Option, + pub mul_mod_builtin: Option, + pub pedersen_builtin: Option, + pub poseidon_builtin: Option, +} + +impl BuiltinsSegments { + /// Create a new `BuiltinsSegments` struct from a map of memory MemorySegmentAddressess. + pub fn from_memory_segments( + memory_segments: &HashMap<&str, cairo_vm::air_public_input::MemorySegmentAddresses>, + ) -> Self { + let mut res = BuiltinsSegments::default(); + for (name, value) in memory_segments.iter() { + let value = Some((value.begin_addr, value.stop_ptr).into()); + match *name { + "range_check" => res.range_check_bits_128_builtin = value, + "range_check96" => res.range_check_bits_96_builtin = value, + "bitwise" => res.bitwise_builtin = value, + "add_mod" => res.add_mod_builtin = value, + "ec_op" => res.ec_op_builtin = value, + "ecdsa" => res.ecdsa_builtin = value, + "keccak" => res.keccak_builtin = value, + "mul_mod" => res.mul_mod_builtin = value, + "pedersen" => res.pedersen_builtin = value, + "poseidon" => res.poseidon_builtin = value, + // Not builtins. + "output" | "execution" | "program" => {} + _ => panic!("Unknown memory segment: {name}"), + } + } + res + } +} diff --git a/stwo_cairo_prover/crates/prover/src/input/mod.rs b/stwo_cairo_prover/crates/prover/src/input/mod.rs index be6994cf..36a35567 100644 --- a/stwo_cairo_prover/crates/prover/src/input/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/mod.rs @@ -1,7 +1,8 @@ -use cairo_vm::air_public_input::MemorySegmentAddresses; +use builtins_segments::BuiltinsSegments; use mem::Memory; use state_transitions::StateTransitions; +pub mod builtins_segments; mod decode; pub mod mem; pub mod plain; @@ -19,5 +20,5 @@ pub struct CairoInput { pub public_mem_addresses: Vec, // Builtins. - pub range_check_builtin: MemorySegmentAddresses, + pub builtins_segments: BuiltinsSegments, } diff --git a/stwo_cairo_prover/crates/prover/src/input/plain.rs b/stwo_cairo_prover/crates/prover/src/input/plain.rs index 7f9b3d6e..8d13a6f8 100644 --- a/stwo_cairo_prover/crates/prover/src/input/plain.rs +++ b/stwo_cairo_prover/crates/prover/src/input/plain.rs @@ -9,8 +9,7 @@ use itertools::Itertools; use super::mem::{MemConfig, MemoryBuilder}; use super::state_transitions::StateTransitions; use super::vm_import::MemEntry; -use super::CairoInput; -use crate::input::MemorySegmentAddresses; +use super::{BuiltinsSegments, CairoInput}; // TODO(Ohad): remove dev_mode after adding the rest of the opcodes. /// Translates a plain casm into a CairoInput by running the program and extracting the memory and @@ -71,6 +70,13 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI val: bytemuck::cast(v.to_bytes_le()), }) }); + + let memory_segments = &runner + .get_air_public_input() + .expect("Unable to get public input from the runner") + .memory_segments; + let builtins_segments = BuiltinsSegments::from_memory_segments(memory_segments); + let trace = runner.relocated_trace.unwrap(); let trace = trace.iter().map(|t| t.clone().into()); @@ -84,9 +90,6 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI state_transitions, mem: mem.build(), public_mem_addresses, - range_check_builtin: MemorySegmentAddresses { - begin_addr: 24, - stop_ptr: 64, - }, + builtins_segments, } } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index d1c098ff..7a3a2b82 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -14,7 +14,7 @@ use super::mem::MemConfig; use super::state_transitions::StateTransitions; use super::CairoInput; use crate::input::mem::MemoryBuilder; -use crate::input::MemorySegmentAddresses; +use crate::input::BuiltinsSegments; #[derive(Debug, Error)] pub enum VmImportError { @@ -26,9 +26,11 @@ pub enum VmImportError { NoMemorySegments, } +// TODO(Ohad): remove dev_mode after adding the rest of the opcodes. pub fn import_from_vm_output( pub_json: &Path, priv_json: &Path, + dev_mode: bool, ) -> Result { let _span = span!(Level::INFO, "import_from_vm_output").entered(); let pub_data_string = std::fs::read_to_string(pub_json)?; @@ -50,7 +52,8 @@ pub fn import_from_vm_output( let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?); let mut mem_file = std::io::BufReader::new(std::fs::File::open(mem_path)?); let mut mem = MemoryBuilder::from_iter(mem_config, MemEntryIter(&mut mem_file)); - let state_transitions = StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, true); + let state_transitions = + StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, dev_mode); let public_mem_addresses = pub_data .public_memory @@ -58,14 +61,13 @@ pub fn import_from_vm_output( .map(|entry| entry.address as u32) .collect(); + let builtins_segments = BuiltinsSegments::from_memory_segments(&pub_data.memory_segments); + Ok(CairoInput { state_transitions, mem: mem.build(), public_mem_addresses, - range_check_builtin: MemorySegmentAddresses { - begin_addr: pub_data.memory_segments["range_check"].begin_addr as usize, - stop_ptr: pub_data.memory_segments["range_check"].stop_ptr as usize, - }, + builtins_segments, }) } @@ -135,7 +137,12 @@ pub mod tests { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); d.push("test_data/large_cairo_input"); - import_from_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect( + import_from_vm_output( + d.join("pub.json").as_path(), + d.join("priv.json").as_path(), + false, + ) + .expect( " Failed to read test files. Maybe git-lfs is not installed? Checkout README.md.", ) @@ -144,19 +151,24 @@ pub mod tests { pub fn small_cairo_input() -> CairoInput { let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); d.push("test_data/small_cairo_input"); - import_from_vm_output(d.join("pub.json").as_path(), d.join("priv.json").as_path()).expect( + import_from_vm_output( + d.join("pub.json").as_path(), + d.join("priv.json").as_path(), + false, + ) + .expect( " Failed to read test files. Maybe git-lfs is not installed? Checkout README.md.", ) } // TODO (Stav): Once all the components are in, verify the proof to ensure the sort was correct. - // TODO (Ohad): remove the following doc after deleting dev_mod. - /// When not ignored, the test passes only with dev_mod = false. #[ignore] #[test] fn test_read_from_large_files() { let input = large_cairo_input(); + + // Test opcode components. let components = input.state_transitions.casm_states_by_opcode; assert_eq!(components.generic_opcode.len(), 0); assert_eq!(components.add_ap_opcode_is_imm_f_op_1_base_fp_f.len(), 0); @@ -217,13 +229,57 @@ pub mod tests { assert_eq!(components.mul_opcode_is_small_f_is_imm_f.len(), 4583); assert_eq!(components.mul_opcode_is_small_f_is_imm_t.len(), 9047); assert_eq!(components.ret_opcode.len(), 49472); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!( + builtins_segments.range_check_bits_128_builtin, + Some((1715768, 1757348).into()) + ); + assert_eq!( + builtins_segments.range_check_bits_96_builtin, + Some((17706552, 17706552).into()) + ); + assert_eq!( + builtins_segments.bitwise_builtin, + Some((5942840, 5942840).into()) + ); + assert_eq!( + builtins_segments.add_mod_builtin, + Some((21900856, 21900856).into()) + ); + assert_eq!( + builtins_segments.ec_op_builtin, + Some((16428600, 16428747).into()) + ); + assert_eq!( + builtins_segments.ecdsa_builtin, + Some((5910072, 5910072).into()) + ); + assert_eq!( + builtins_segments.keccak_builtin, + Some((16657976, 16657976).into()) + ); + assert_eq!( + builtins_segments.mul_mod_builtin, + Some((23735864, 23735864).into()) + ); + assert_eq!( + builtins_segments.pedersen_builtin, + Some((1322552, 1337489).into()) + ); + assert_eq!( + builtins_segments.poseidon_builtin, + Some((16920120, 17444532).into()) + ); } - // When not ignored, the test passes only with dev_mod = false. #[ignore] #[test] fn test_read_from_small_files() { let input = small_cairo_input(); + + // Test opcode components. let components = input.state_transitions.casm_states_by_opcode; assert_eq!(components.generic_opcode.len(), 0); assert_eq!(components.add_ap_opcode_is_imm_f_op_1_base_fp_f.len(), 0); @@ -281,5 +337,42 @@ pub mod tests { assert_eq!(components.mul_opcode_is_small_f_is_imm_f.len(), 0); assert_eq!(components.mul_opcode_is_small_f_is_imm_t.len(), 0); assert_eq!(components.ret_opcode.len(), 462); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!( + builtins_segments.range_check_bits_128_builtin, + Some((6000, 6050).into()) + ); + assert_eq!( + builtins_segments.range_check_bits_96_builtin, + Some((68464, 68514).into()) + ); + assert_eq!( + builtins_segments.bitwise_builtin, + Some((22512, 22762).into()) + ); + assert_eq!( + builtins_segments.add_mod_builtin, + Some((84848, 84848).into()) + ); + assert_eq!(builtins_segments.ec_op_builtin, Some((63472, 63822).into())); + assert_eq!(builtins_segments.ecdsa_builtin, Some((22384, 22484).into())); + assert_eq!( + builtins_segments.keccak_builtin, + Some((64368, 65168).into()) + ); + assert_eq!( + builtins_segments.mul_mod_builtin, + Some((92016, 92016).into()) + ); + assert_eq!( + builtins_segments.pedersen_builtin, + Some((4464, 4614).into()) + ); + assert_eq!( + builtins_segments.poseidon_builtin, + Some((65392, 65692).into()) + ); } }