diff --git a/stwo_cairo_prover/crates/prover/src/input/mod.rs b/stwo_cairo_prover/crates/prover/src/input/mod.rs index ba13e800..be6994cf 100644 --- a/stwo_cairo_prover/crates/prover/src/input/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/mod.rs @@ -1,5 +1,5 @@ +use cairo_vm::air_public_input::MemorySegmentAddresses; use mem::Memory; -use serde::{Deserialize, Serialize}; use state_transitions::StateTransitions; mod decode; @@ -19,17 +19,5 @@ pub struct CairoInput { pub public_mem_addresses: Vec, // Builtins. - pub range_check_builtin: SegmentAddrs, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct SegmentAddrs { - pub begin_addr: u32, - pub end_addr: u32, -} - -impl SegmentAddrs { - pub fn addresses(&self) -> Vec { - (self.begin_addr..self.end_addr).collect() - } + pub range_check_builtin: MemorySegmentAddresses, } diff --git a/stwo_cairo_prover/crates/prover/src/input/plain.rs b/stwo_cairo_prover/crates/prover/src/input/plain.rs index 8eadffbc..7f9b3d6e 100644 --- a/stwo_cairo_prover/crates/prover/src/input/plain.rs +++ b/stwo_cairo_prover/crates/prover/src/input/plain.rs @@ -9,7 +9,8 @@ use itertools::Itertools; use super::mem::{MemConfig, MemoryBuilder}; use super::state_transitions::StateTransitions; use super::vm_import::MemEntry; -use super::{CairoInput, SegmentAddrs}; +use super::CairoInput; +use crate::input::MemorySegmentAddresses; // TODO(Ohad): remove dev_mode after adding the rest of the opcodes. /// Translates a plain casm into a CairoInput by running the program and extracting the memory and @@ -83,9 +84,9 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI state_transitions, mem: mem.build(), public_mem_addresses, - range_check_builtin: SegmentAddrs { + range_check_builtin: MemorySegmentAddresses { begin_addr: 24, - end_addr: 64, + stop_ptr: 64, }, } } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/json.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/json.rs index e44001fe..6f068064 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/json.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/json.rs @@ -1,99 +1,10 @@ -use std::collections::BTreeMap; - -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PublicInput { - pub layout: String, - pub rc_min: u64, - pub rc_max: u64, - pub n_steps: u64, - pub memory_segments: BTreeMap, - pub public_memory: Vec, - pub dynamic_params: Option<()>, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Segment { - pub begin_addr: u64, - pub stop_ptr: u64, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PublicMemEntry { - pub address: u64, - pub value: FeltValue, - pub page: u64, -} +use serde::{Deserialize, Serialize}; +// TODO(Stav): Replace with original struct once fields are public. +/// Struct to store Cairo private input. +/// Replicated from `cairo_vm::air_private_input::AirPrivateInputSerializable`. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PrivateInput { pub trace_path: String, pub memory_path: String, - pub pedersen: Vec, - pub range_check: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PedersenData { - pub index: u64, - pub x: FeltValue, - pub y: FeltValue, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct RangeCheckData { - pub index: u64, - pub value: FeltValue, -} - -#[derive(Clone, Debug)] -pub struct FeltValue([u8; 32]); - -impl Serialize for FeltValue { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - // Convert the [u8; 32] to a hexadecimal string - let hex_string = format!("0x{}", hex::encode(self.0)); - serializer.serialize_str(&hex_string) - } -} - -impl<'de> Deserialize<'de> for FeltValue { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let hex_string = String::deserialize(deserializer)?; - - // Remove the "0x" prefix if present - let hex_str = hex_string.strip_prefix("0x").unwrap_or(&hex_string); - let hex_str = format!("{:0>64}", hex_str); - - // Convert the hexadecimal string back into a [u8; 32] - let mut bytes = [0u8; 32]; - hex::decode_to_slice(hex_str, &mut bytes).map_err(serde::de::Error::custom)?; - - Ok(FeltValue(bytes)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_felt_value_serde() { - let felt_value = FeltValue([0x12; 32]); - let json = sonic_rs::to_string(&felt_value).unwrap(); - assert_eq!( - json, - r#""0x1212121212121212121212121212121212121212121212121212121212121212""# - ); - - let deserialized: FeltValue = sonic_rs::from_str(&json).unwrap(); - assert_eq!(felt_value.0, deserialized.0); - } } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index c7f28f3e..dff25c1a 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -4,8 +4,9 @@ use std::io::Read; use std::path::Path; use bytemuck::{bytes_of_mut, Pod, Zeroable}; +use cairo_vm::air_public_input::PublicInput; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; -use json::{PrivateInput, PublicInput}; +use json::PrivateInput; use thiserror::Error; use tracing::{span, Level}; @@ -13,7 +14,7 @@ use super::mem::MemConfig; use super::state_transitions::StateTransitions; use super::CairoInput; use crate::input::mem::MemoryBuilder; -use crate::input::SegmentAddrs; +use crate::input::MemorySegmentAddresses; #[derive(Debug, Error)] pub enum VmImportError { @@ -32,7 +33,8 @@ pub fn import_from_vm_output( dev_mod: bool, ) -> Result { let _span = span!(Level::INFO, "import_from_vm_output").entered(); - let pub_data: PublicInput = sonic_rs::from_str(&std::fs::read_to_string(pub_json)?)?; + let pub_data_string = std::fs::read_to_string(pub_json)?; + let pub_data: PublicInput<'_> = sonic_rs::from_str(&pub_data_string)?; let priv_data: PrivateInput = sonic_rs::from_str(&std::fs::read_to_string(priv_json)?)?; let end_addr = pub_data @@ -63,9 +65,9 @@ pub fn import_from_vm_output( state_transitions, mem: mem.build(), public_mem_addresses, - range_check_builtin: SegmentAddrs { - begin_addr: pub_data.memory_segments["range_check"].begin_addr as u32, - end_addr: pub_data.memory_segments["range_check"].stop_ptr as u32, + range_check_builtin: MemorySegmentAddresses { + begin_addr: pub_data.memory_segments["range_check"].begin_addr as usize, + stop_ptr: pub_data.memory_segments["range_check"].stop_ptr as usize, }, }) }