Skip to content

Commit

Permalink
delete redudent struct from json.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe committed Dec 19, 2024
1 parent 7b89d37 commit b366d04
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 113 deletions.
16 changes: 2 additions & 14 deletions stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use cairo_vm::air_public_input::MemorySegmentAddresses;
use mem::Memory;
use serde::{Deserialize, Serialize};
use state_transitions::StateTransitions;

mod decode;
Expand All @@ -19,17 +19,5 @@ pub struct CairoInput {
pub public_mem_addresses: Vec<u32>,

// Builtins.
pub range_check_builtin: SegmentAddrs,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SegmentAddrs {
pub begin_addr: u32,
pub end_addr: u32,
}

impl SegmentAddrs {
pub fn addresses(&self) -> Vec<u32> {
(self.begin_addr..self.end_addr).collect()
}
pub range_check_builtin: MemorySegmentAddresses,
}
7 changes: 4 additions & 3 deletions stwo_cairo_prover/crates/prover/src/input/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use itertools::Itertools;
use super::mem::{MemConfig, MemoryBuilder};
use super::state_transitions::StateTransitions;
use super::vm_import::MemEntry;
use super::{CairoInput, SegmentAddrs};
use super::CairoInput;
use crate::input::MemorySegmentAddresses;

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a plain casm into a CairoInput by running the program and extracting the memory and
Expand Down Expand Up @@ -83,9 +84,9 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: SegmentAddrs {
range_check_builtin: MemorySegmentAddresses {
begin_addr: 24,
end_addr: 64,
stop_ptr: 64,
},
}
}
96 changes: 6 additions & 90 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/json.rs
Original file line number Diff line number Diff line change
@@ -1,99 +1,15 @@
use std::collections::BTreeMap;
use cairo_vm::air_private_input::{PrivateInputPair, PrivateInputValue};
use serde::{Deserialize, Serialize};

use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PublicInput {
pub layout: String,
pub rc_min: u64,
pub rc_max: u64,
pub n_steps: u64,
pub memory_segments: BTreeMap<String, Segment>,
pub public_memory: Vec<PublicMemEntry>,
pub dynamic_params: Option<()>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Segment {
pub begin_addr: u64,
pub stop_ptr: u64,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PublicMemEntry {
pub address: u64,
pub value: FeltValue,
pub page: u64,
}
type PedersenData = PrivateInputPair;
type RangeCheckData = PrivateInputValue;

// Can't use cairo_vm::air_private_input::AirPrivateInputSerializable since trace_path and mem_path
// are private.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PrivateInput {
pub trace_path: String,
pub memory_path: String,
pub pedersen: Vec<PedersenData>,
pub range_check: Vec<RangeCheckData>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PedersenData {
pub index: u64,
pub x: FeltValue,
pub y: FeltValue,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RangeCheckData {
pub index: u64,
pub value: FeltValue,
}

#[derive(Clone, Debug)]
pub struct FeltValue([u8; 32]);

impl Serialize for FeltValue {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Convert the [u8; 32] to a hexadecimal string
let hex_string = format!("0x{}", hex::encode(self.0));
serializer.serialize_str(&hex_string)
}
}

impl<'de> Deserialize<'de> for FeltValue {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let hex_string = String::deserialize(deserializer)?;

// Remove the "0x" prefix if present
let hex_str = hex_string.strip_prefix("0x").unwrap_or(&hex_string);
let hex_str = format!("{:0>64}", hex_str);

// Convert the hexadecimal string back into a [u8; 32]
let mut bytes = [0u8; 32];
hex::decode_to_slice(hex_str, &mut bytes).map_err(serde::de::Error::custom)?;

Ok(FeltValue(bytes))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_felt_value_serde() {
let felt_value = FeltValue([0x12; 32]);
let json = sonic_rs::to_string(&felt_value).unwrap();
assert_eq!(
json,
r#""0x1212121212121212121212121212121212121212121212121212121212121212""#
);

let deserialized: FeltValue = sonic_rs::from_str(&json).unwrap();
assert_eq!(felt_value.0, deserialized.0);
}
}
14 changes: 8 additions & 6 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ use std::io::Read;
use std::path::Path;

use bytemuck::{bytes_of_mut, Pod, Zeroable};
use cairo_vm::air_public_input::PublicInput;
use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry;
use json::{PrivateInput, PublicInput};
use json::PrivateInput;
use thiserror::Error;
use tracing::{span, Level};

use super::mem::MemConfig;
use super::state_transitions::StateTransitions;
use super::CairoInput;
use crate::input::mem::MemoryBuilder;
use crate::input::SegmentAddrs;
use crate::input::MemorySegmentAddresses;

#[derive(Debug, Error)]
pub enum VmImportError {
Expand All @@ -30,7 +31,8 @@ pub fn import_from_vm_output(
priv_json: &Path,
) -> Result<CairoInput, VmImportError> {
let _span = span!(Level::INFO, "import_from_vm_output").entered();
let pub_data: PublicInput = sonic_rs::from_str(&std::fs::read_to_string(pub_json)?)?;
let pub_data_string = std::fs::read_to_string(pub_json)?;
let pub_data: PublicInput<'_> = sonic_rs::from_str(&pub_data_string)?;
let priv_data: PrivateInput = sonic_rs::from_str(&std::fs::read_to_string(priv_json)?)?;

let end_addr = pub_data
Expand Down Expand Up @@ -60,9 +62,9 @@ pub fn import_from_vm_output(
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: SegmentAddrs {
begin_addr: pub_data.memory_segments["range_check"].begin_addr as u32,
end_addr: pub_data.memory_segments["range_check"].stop_ptr as u32,
range_check_builtin: MemorySegmentAddresses {
begin_addr: pub_data.memory_segments["range_check"].begin_addr as usize,
stop_ptr: pub_data.memory_segments["range_check"].stop_ptr as usize,
},
})
}
Expand Down

0 comments on commit b366d04

Please sign in to comment.