Skip to content

Commit

Permalink
Add CairoSerialize trait and implmentations (#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Dec 16, 2024
1 parent 7cda037 commit a7c7dc4
Show file tree
Hide file tree
Showing 74 changed files with 1,064 additions and 584 deletions.
23 changes: 23 additions & 0 deletions stwo_cairo_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions stwo_cairo_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
members = [
"crates/adapted_prover",
"crates/prover",
"crates/cairo-serialize",
"crates/cairo-serialize-derive",
"crates/utils",
"crates/vm_runner",
"crates/prover_types",
Expand Down
6 changes: 4 additions & 2 deletions stwo_cairo_prover/crates/adapted_prover/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use stwo_cairo_prover::input::vm_import::{import_from_vm_output, VmImportError};
use stwo_cairo_prover::input::CairoInput;
use stwo_cairo_utils::logging_utils::init_logging;
use stwo_prover::core::prover::ProvingError;
use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use stwo_prover::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};
use thiserror::Error;
use tracing::{span, Level};

Expand Down Expand Up @@ -71,7 +71,9 @@ fn run(args: impl Iterator<Item = String>) -> Result<CairoProof<Blake2sMerkleHas
let casm_states_by_opcode_count = &vm_output.state_transitions.casm_states_by_opcode.counts();
log::info!("Casm states by opcode count: {casm_states_by_opcode_count:?}");

let proof = prove_cairo(vm_output, args.debug_lookup, args.display_components)?;
// TODO(Ohad): Propogate hash from CLI args.
let proof =
prove_cairo::<Blake2sMerkleChannel>(vm_output, args.debug_lookup, args.display_components)?;

// TODO(yuval): This is just some serialization for the sake of serialization. Find the right
// way to serialize the proof.
Expand Down
12 changes: 12 additions & 0 deletions stwo_cairo_prover/crates/cairo-serialize-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "stwo-cairo-serialize-derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = "2.0"
54 changes: 54 additions & 0 deletions stwo_cairo_prover/crates/cairo-serialize-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields};

#[proc_macro_derive(CairoSerialize)]
pub fn derive_cairo_serialize(input: TokenStream) -> TokenStream {
// Parse the input tokens into a syntax tree.
let input = parse_macro_input!(input as DeriveInput);

let struct_name = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

// Extract the fields of the struct.
let fields = match input.data {
Data::Struct(ref data_struct) => match &data_struct.fields {
Fields::Named(ref fields_named) => &fields_named.named,
Fields::Unnamed(_) | Fields::Unit => {
return syn::Error::new_spanned(
struct_name,
"CairoSerialize can only be derived for structs with named fields.",
)
.to_compile_error()
.into();
}
},
_ => {
return syn::Error::new_spanned(
struct_name,
"CairoSerialize can only be derived for structs.",
)
.to_compile_error()
.into();
}
};

// Generate code to serialize each field in the order they appear.
let serialize_body = fields.iter().map(|f| {
let field_name = &f.ident;
quote! {
CairoSerialize::serialize(&self.#field_name, output);
}
});

// Implement `CairoSerialize` for the type.
let expanded = quote! {
impl #impl_generics ::stwo_cairo_serialize::CairoSerialize for #struct_name #ty_generics #where_clause {
fn serialize(&self, output: &mut Vec<::starknet_ff::FieldElement>) {
#(#serialize_body)*
}
}
};

TokenStream::from(expanded)
}
9 changes: 9 additions & 0 deletions stwo_cairo_prover/crates/cairo-serialize/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "stwo-cairo-serialize"
version = "0.1.0"
edition = "2021"

[dependencies]
starknet-ff.workspace = true
stwo-cairo-serialize-derive = { path = "../cairo-serialize-derive" }
stwo-prover.workspace = true
185 changes: 185 additions & 0 deletions stwo_cairo_prover/crates/cairo-serialize/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
use starknet_ff::FieldElement;
// Make derive macro available.
pub use stwo_cairo_serialize_derive::CairoSerialize;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fri::{FriLayerProof, FriProof};
use stwo_prover::core::pcs::CommitmentSchemeProof;
use stwo_prover::core::poly::line::LinePoly;
use stwo_prover::core::prover::StarkProof;
use stwo_prover::core::vcs::ops::MerkleHasher;
use stwo_prover::core::vcs::prover::MerkleDecommitment;

/// Serializes types into a format for deserialization by corresponding types in a Cairo program.
pub trait CairoSerialize {
fn serialize(&self, output: &mut Vec<FieldElement>);
}

impl CairoSerialize for u32 {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.push((*self).into());
}
}

impl CairoSerialize for u64 {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.push((*self).into());
}
}

impl CairoSerialize for usize {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.push((*self).into());
}
}

impl CairoSerialize for BaseField {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.push(self.0.into());
}
}

impl CairoSerialize for SecureField {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.extend(self.to_m31_array().map(|c| FieldElement::from(c.0)));
}
}

impl<H: MerkleHasher> CairoSerialize for MerkleDecommitment<H>
where
H::Hash: CairoSerialize,
{
fn serialize(&self, output: &mut Vec<FieldElement>) {
let Self {
hash_witness,
column_witness,
} = self;
hash_witness.serialize(output);
column_witness.serialize(output);
}
}

impl CairoSerialize for LinePoly {
fn serialize(&self, output: &mut Vec<FieldElement>) {
(**self).serialize(output);
output.push((self.len().ilog2()).into());
}
}

impl<H: MerkleHasher> CairoSerialize for FriLayerProof<H>
where
H::Hash: CairoSerialize,
{
fn serialize(&self, output: &mut Vec<FieldElement>) {
let Self {
fri_witness,
decommitment,
commitment,
} = self;
fri_witness.serialize(output);
decommitment.serialize(output);
commitment.serialize(output);
}
}

impl<H: MerkleHasher> CairoSerialize for FriProof<H>
where
H::Hash: CairoSerialize,
{
fn serialize(&self, output: &mut Vec<FieldElement>) {
let Self {
first_layer,
inner_layers,
last_layer_poly,
} = self;
first_layer.serialize(output);
inner_layers.serialize(output);
last_layer_poly.serialize(output);
}
}

impl CairoSerialize for FieldElement {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.push(*self);
}
}

impl<H: MerkleHasher> CairoSerialize for CommitmentSchemeProof<H>
where
H::Hash: CairoSerialize,
{
fn serialize(&self, output: &mut Vec<FieldElement>) {
let Self {
commitments,
sampled_values,
decommitments,
queried_values,
proof_of_work,
fri_proof,
} = self;
commitments.serialize(output);
sampled_values.serialize(output);
decommitments.serialize(output);
queried_values.serialize(output);
output.push((*proof_of_work).into());
fri_proof.serialize(output);
}
}

impl<H: MerkleHasher> CairoSerialize for StarkProof<H>
where
H::Hash: CairoSerialize,
{
fn serialize(&self, output: &mut Vec<FieldElement>) {
let Self(commitment_scheme_proof) = self;
commitment_scheme_proof.serialize(output);
}
}

impl<T: CairoSerialize> CairoSerialize for Option<T> {
fn serialize(&self, output: &mut Vec<FieldElement>) {
match self {
Some(v) => {
output.push(FieldElement::ZERO);
v.serialize(output);
}
None => output.push(FieldElement::ONE),
}
}
}

impl<T: CairoSerialize> CairoSerialize for [T] {
fn serialize(&self, output: &mut Vec<FieldElement>) {
output.push(self.len().into());
self.iter().for_each(|v| v.serialize(output));
}
}

impl<T: CairoSerialize, const N: usize> CairoSerialize for [T; N] {
fn serialize(&self, output: &mut Vec<FieldElement>) {
self.iter().for_each(|v| v.serialize(output));
}
}

impl<T: CairoSerialize> CairoSerialize for Vec<T> {
fn serialize(&self, output: &mut Vec<FieldElement>) {
(**self).serialize(output);
}
}

impl<T0: CairoSerialize, T1: CairoSerialize> CairoSerialize for (T0, T1) {
fn serialize(&self, output: &mut Vec<FieldElement>) {
let (v0, v1) = self;
v0.serialize(output);
v1.serialize(output);
}
}

impl<T0: CairoSerialize, T1: CairoSerialize, T2: CairoSerialize> CairoSerialize for (T0, T1, T2) {
fn serialize(&self, output: &mut Vec<FieldElement>) {
let (v0, v1, v2) = self;
v0.serialize(output);
v1.serialize(output);
v2.serialize(output);
}
}
9 changes: 6 additions & 3 deletions stwo_cairo_prover/crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ cairo-vm.workspace = true
hex.workspace = true
itertools.workspace = true
num-traits.workspace = true
paste.workspace = true
prover_types = { path = "../prover_types" }
rayon = { version = "1.10.0", optional = true }
serde.workspace = true
sonic-rs.workspace = true
starknet-ff.workspace = true
stwo_cairo_utils = { path = "../utils" }
stwo-cairo-serialize = { path = "../cairo-serialize" }
stwo-prover.workspace = true
thiserror.workspace = true
tracing.workspace = true
paste.workspace = true
prover_types = { path = "../prover_types" }
rayon = { version = "1.10.0", optional = true }

[dev-dependencies]
cairo-lang-casm.workspace = true
Expand Down
Loading

0 comments on commit a7c7dc4

Please sign in to comment.