From dd827e75f9b025e23ffe43e9722e38c7d8bf3e6f Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Sun, 1 Dec 2024 17:30:05 +0200 Subject: [PATCH] migrated opcodes air out of cairo air --- .../crates/prover/src/cairo_air/air.rs | 38 +- .../crates/prover/src/cairo_air/mod.rs | 3 +- .../prover/src/cairo_air/opcodes_air.rs | 529 ++++++++++++++++++ .../crates/prover/src/input/vm_import/mod.rs | 3 +- 4 files changed, 537 insertions(+), 36 deletions(-) create mode 100644 stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs index 722c7f8b..e6a05ab2 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/air.rs @@ -15,17 +15,17 @@ use stwo_prover::core::prover::StarkProof; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; use stwo_prover::core::vcs::ops::MerkleHasher; +use super::opcodes_air::{ + OpcodeClaim, OpcodeComponents, OpcodeInteractionClaim, OpcodesClaimGenerator, + OpcodesInteractionClaimGenerator, +}; use super::IS_FIRST_LOG_SIZES; use crate::components::memory::{memory_address_to_id, memory_id_to_big}; use crate::components::range_check_vector::{ range_check_19, range_check_4_3, range_check_7_2_5, range_check_9_9, }; -use crate::components::{ - add_ap_opcode_is_imm_f_op1_base_fp_f, add_ap_opcode_is_imm_f_op1_base_fp_t, - add_ap_opcode_is_imm_t_op1_base_fp_f, generic_opcode, ret_opcode, verify_instruction, -}; +use crate::components::verify_instruction; use crate::felt::split_f252; -use crate::input::state_transitions::StateTransitions; use crate::input::CairoInput; use crate::relations; @@ -39,34 +39,6 @@ pub struct CairoProof { // (Address, Id, Value) pub type PublicMemory = Vec<(u32, u32, [u32; 8])>; -#[derive(Serialize, Deserialize)] -pub struct OpcodeClaim { - add_ap_f_f: Vec, - add_ap_f_t: Vec, - add_ap_t_f: Vec, - ret: Vec, - generic: Vec, -} -impl OpcodeClaim { - pub fn mix_into(&self, channel: &mut impl Channel) { - self.add_ap_f_f.iter().for_each(|c| c.mix_into(channel)); - self.add_ap_f_t.iter().for_each(|c| c.mix_into(channel)); - self.add_ap_t_f.iter().for_each(|c| c.mix_into(channel)); - self.ret.iter().for_each(|c| c.mix_into(channel)); - self.generic.iter().for_each(|c| c.mix_into(channel)); - } - - pub fn log_sizes(&self) -> TreeVec> { - TreeVec::concat_cols(chain!( - self.add_ap_f_f.iter().map(|c| c.log_sizes()), - self.add_ap_f_t.iter().map(|c| c.log_sizes()), - self.add_ap_t_f.iter().map(|c| c.log_sizes()), - self.generic.iter().map(|c| c.log_sizes()), - self.ret.iter().map(|c| c.log_sizes()), - )) - } -} - #[derive(Serialize, Deserialize)] pub struct CairoClaim { pub public_data: PublicData, diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs index 0eb66898..dfeb136a 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs @@ -1,4 +1,5 @@ pub mod air; +pub mod opcodes_air; use air::{lookup_sum, CairoClaimGenerator, CairoComponents, CairoInteractionElements, CairoProof}; use num_traits::Zero; @@ -17,7 +18,7 @@ use crate::input::CairoInput; const LOG_MAX_ROWS: u32 = 20; -const IS_FIRST_LOG_SIZES: [u32; 7] = [18, 4, 14, 19, 7, 6, 5]; +const IS_FIRST_LOG_SIZES: [u32; 16] = [19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4]; pub fn prove_cairo(input: CairoInput) -> Result, ProvingError> { let _span = span!(Level::INFO, "prove_cairo").entered(); let config = PcsConfig::default(); diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs new file mode 100644 index 00000000..953c04bc --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs @@ -0,0 +1,529 @@ +use itertools::{chain, Itertools}; +use num_traits::Zero; +use serde::{Deserialize, Serialize}; +use stwo_prover::constraint_framework::TraceLocationAllocator; +use stwo_prover::core::air::ComponentProver; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::qm31::{SecureField, QM31}; +use stwo_prover::core::pcs::{TreeBuilder, TreeVec}; +use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; + +use super::air::CairoInteractionElements; +use crate::components::{ + add_ap_opcode_is_imm_f_op1_base_fp_f, add_ap_opcode_is_imm_f_op1_base_fp_t, + add_ap_opcode_is_imm_t_op1_base_fp_f, generic_opcode, memory_address_to_id, memory_id_to_big, + range_check_19, range_check_9_9, ret_opcode, verify_instruction, +}; +use crate::input::state_transitions::StateTransitions; + +#[derive(Serialize, Deserialize)] +pub struct OpcodeClaim { + add_ap_f_f: Vec, + add_ap_f_t: Vec, + add_ap_t_f: Vec, + ret: Vec, + generic: Vec, +} +impl OpcodeClaim { + pub fn mix_into(&self, channel: &mut impl Channel) { + self.add_ap_f_f.iter().for_each(|c| c.mix_into(channel)); + self.add_ap_f_t.iter().for_each(|c| c.mix_into(channel)); + self.add_ap_t_f.iter().for_each(|c| c.mix_into(channel)); + self.ret.iter().for_each(|c| c.mix_into(channel)); + self.generic.iter().for_each(|c| c.mix_into(channel)); + } + + pub fn log_sizes(&self) -> TreeVec> { + TreeVec::concat_cols(chain!( + self.add_ap_f_f.iter().map(|c| c.log_sizes()), + self.add_ap_f_t.iter().map(|c| c.log_sizes()), + self.add_ap_t_f.iter().map(|c| c.log_sizes()), + self.generic.iter().map(|c| c.log_sizes()), + self.ret.iter().map(|c| c.log_sizes()), + )) + } +} + +pub struct OpcodesClaimGenerator { + add_ap_f_f: Vec, + add_ap_f_t: Vec, + add_ap_t_f: Vec, + generic: Vec, + ret: Vec, +} +impl OpcodesClaimGenerator { + pub fn new(input: StateTransitions) -> Self { + // TODO(Ohad): decide split sizes for opcode traces. + let mut add_ap_f_f = vec![]; + let mut add_ap_f_t = vec![]; + let mut add_ap_t_f = vec![]; + let mut generic = vec![]; + let mut ret = vec![]; + if !input + .casm_states_by_opcode + .add_ap_opcode_is_imm_f_op1_base_fp_f + .is_empty() + { + add_ap_f_f.push(add_ap_opcode_is_imm_f_op1_base_fp_f::ClaimGenerator::new( + input + .casm_states_by_opcode + .add_ap_opcode_is_imm_f_op1_base_fp_f, + )); + } + if !input + .casm_states_by_opcode + .add_ap_opcode_is_imm_f_op1_base_fp_t + .is_empty() + { + add_ap_f_t.push(add_ap_opcode_is_imm_f_op1_base_fp_t::ClaimGenerator::new( + input + .casm_states_by_opcode + .add_ap_opcode_is_imm_f_op1_base_fp_t, + )); + } + if !input + .casm_states_by_opcode + .add_ap_opcode_is_imm_t_op1_base_fp_f + .is_empty() + { + add_ap_t_f.push(add_ap_opcode_is_imm_t_op1_base_fp_f::ClaimGenerator::new( + input + .casm_states_by_opcode + .add_ap_opcode_is_imm_t_op1_base_fp_f, + )); + } + if !input.casm_states_by_opcode.generic_opcode.is_empty() { + generic.push(generic_opcode::ClaimGenerator::new( + input.casm_states_by_opcode.generic_opcode, + )); + } + if !input.casm_states_by_opcode.ret_opcode.is_empty() { + ret.push(ret_opcode::ClaimGenerator::new( + input.casm_states_by_opcode.ret_opcode, + )); + } + Self { + add_ap_f_f, + add_ap_f_t, + add_ap_t_f, + ret, + generic, + } + } + + pub fn write_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memory_memory_address_to_id_trace_generator: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_value_trace_generator: &mut memory_id_to_big::ClaimGenerator, + range_check_19_trace_generator: &mut range_check_19::ClaimGenerator, + range_check_9_9_trace_generator: &mut range_check_9_9::ClaimGenerator, + verify_instruction_trace_generator: &mut verify_instruction::ClaimGenerator, + ) -> (OpcodeClaim, OpcodesInteractionClaimGenerator) { + let (add_ap_f_f_claims, add_ap_f_f_interaction_gens) = self + .add_ap_f_f + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + let (add_ap_f_t_claims, add_ap_f_t_interaction_gens) = self + .add_ap_f_t + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + let (add_ap_t_f_claims, add_ap_t_f_interaction_gens) = self + .add_ap_t_f + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + let (generic_opcode_claims, generic_opcode_interaction_gens) = self + .generic + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + range_check_19_trace_generator, + range_check_9_9_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + let (ret_claims, ret_interaction_gens) = self + .ret + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + ( + OpcodeClaim { + add_ap_f_f: add_ap_f_f_claims, + add_ap_f_t: add_ap_f_t_claims, + add_ap_t_f: add_ap_t_f_claims, + generic: generic_opcode_claims, + ret: ret_claims, + }, + OpcodesInteractionClaimGenerator { + add_ap_f_f: add_ap_f_f_interaction_gens, + add_ap_f_t: add_ap_f_t_interaction_gens, + add_ap_t_f: add_ap_t_f_interaction_gens, + generic_opcode_interaction_gens, + ret_interaction_gens, + }, + ) + } +} + +#[derive(Serialize, Deserialize)] +pub struct OpcodeInteractionClaim { + add_ap_f_f: Vec, + add_ap_f_t: Vec, + add_ap_t_f: Vec, + generic: Vec, + ret: Vec, +} +impl OpcodeInteractionClaim { + pub fn mix_into(&self, channel: &mut impl Channel) { + self.add_ap_f_f.iter().for_each(|c| c.mix_into(channel)); + self.add_ap_f_t.iter().for_each(|c| c.mix_into(channel)); + self.add_ap_t_f.iter().for_each(|c| c.mix_into(channel)); + self.generic.iter().for_each(|c| c.mix_into(channel)); + self.ret.iter().for_each(|c| c.mix_into(channel)); + } + + pub fn sum(&self) -> SecureField { + let mut sum = QM31::zero(); + for interaction_claim in &self.add_ap_f_f { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + for interaction_claim in &self.add_ap_f_t { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + for interaction_claim in &self.add_ap_t_f { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + for interaction_claim in &self.generic { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + for interaction_claim in &self.ret { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + sum + } +} + +pub struct OpcodesInteractionClaimGenerator { + add_ap_f_f: Vec, + add_ap_f_t: Vec, + add_ap_t_f: Vec, + generic_opcode_interaction_gens: Vec, + ret_interaction_gens: Vec, +} +impl OpcodesInteractionClaimGenerator { + pub fn write_interaction_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + interaction_elements: &CairoInteractionElements, + ) -> OpcodeInteractionClaim { + let add_ap_f_f_interaction_claims = self + .add_ap_f_f + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_memory_address_to_id, + &interaction_elements.memory_id_to_value, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + let add_ap_f_t_interaction_claims = self + .add_ap_f_t + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_memory_address_to_id, + &interaction_elements.memory_id_to_value, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + let add_ap_t_f_interaction_claims = self + .add_ap_t_f + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_memory_address_to_id, + &interaction_elements.memory_id_to_value, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + let generic_opcode_interaction_claims = self + .generic_opcode_interaction_gens + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_memory_address_to_id, + &interaction_elements.memory_id_to_value, + &interaction_elements.opcodes, + &interaction_elements.range_check_19, + &interaction_elements.range_check_9_9, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + let ret_interaction_claims = self + .ret_interaction_gens + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_memory_address_to_id, + &interaction_elements.memory_id_to_value, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + OpcodeInteractionClaim { + add_ap_f_f: add_ap_f_f_interaction_claims, + add_ap_f_t: add_ap_f_t_interaction_claims, + add_ap_t_f: add_ap_t_f_interaction_claims, + generic: generic_opcode_interaction_claims, + ret: ret_interaction_claims, + } + } +} + +pub struct OpcodeComponents { + add_ap_f_f: Vec, + add_ap_f_t: Vec, + add_ap_t_f: Vec, + generic: Vec, + ret: Vec, +} +impl OpcodeComponents { + pub fn new( + tree_span_provider: &mut TraceLocationAllocator, + claim: &OpcodeClaim, + interaction_elements: &CairoInteractionElements, + interaction_claim: &OpcodeInteractionClaim, + ) -> Self { + let add_ap_f_f_components = claim + .add_ap_f_f + .iter() + .zip(interaction_claim.add_ap_f_f.iter()) + .map(|(&claim, &interaction_claim)| { + add_ap_opcode_is_imm_f_op1_base_fp_f::Component::new( + tree_span_provider, + add_ap_opcode_is_imm_f_op1_base_fp_f::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_memory_address_to_id + .clone(), + memoryidtobig_lookup_elements: interaction_elements + .memory_id_to_value + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + let add_ap_f_t_components = claim + .add_ap_f_t + .iter() + .zip(interaction_claim.add_ap_f_t.iter()) + .map(|(&claim, &interaction_claim)| { + add_ap_opcode_is_imm_f_op1_base_fp_t::Component::new( + tree_span_provider, + add_ap_opcode_is_imm_f_op1_base_fp_t::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_memory_address_to_id + .clone(), + memoryidtobig_lookup_elements: interaction_elements + .memory_id_to_value + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + let add_ap_t_f_components = claim + .add_ap_t_f + .iter() + .zip(interaction_claim.add_ap_t_f.iter()) + .map(|(&claim, &interaction_claim)| { + add_ap_opcode_is_imm_t_op1_base_fp_f::Component::new( + tree_span_provider, + add_ap_opcode_is_imm_t_op1_base_fp_f::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_memory_address_to_id + .clone(), + memoryidtobig_lookup_elements: interaction_elements + .memory_id_to_value + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + let generic_components = claim + .generic + .iter() + .zip(interaction_claim.generic.iter()) + .map(|(&claim, &interaction_claim)| { + generic_opcode::Component::new( + tree_span_provider, + generic_opcode::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_memory_address_to_id + .clone(), + memoryidtobig_lookup_elements: interaction_elements + .memory_id_to_value + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(), + rangecheck_9_9_lookup_elements: interaction_elements + .range_check_9_9 + .clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + let ret_components = claim + .ret + .iter() + .zip(interaction_claim.ret.iter()) + .map(|(&claim, &interaction_claim)| { + ret_opcode::Component::new( + tree_span_provider, + ret_opcode::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_memory_address_to_id + .clone(), + memoryidtobig_lookup_elements: interaction_elements + .memory_id_to_value + .clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + Self { + add_ap_f_f: add_ap_f_f_components, + add_ap_f_t: add_ap_f_t_components, + add_ap_t_f: add_ap_t_f_components, + generic: generic_components, + ret: ret_components, + } + } + + pub fn provers(&self) -> Vec<&dyn ComponentProver> { + let mut vec: Vec<&dyn ComponentProver> = vec![]; + vec.extend( + self.add_ap_f_f + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec.extend( + self.add_ap_f_t + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec.extend( + self.add_ap_t_f + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec.extend( + self.generic + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec.extend( + self.ret + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec + } +} diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index aa30c203..ef6d2472 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -48,8 +48,7 @@ pub fn import_from_vm_output( let mut trace_file = std::io::BufReader::new(std::fs::File::open(trace_path)?); let mut mem_file = std::io::BufReader::new(std::fs::File::open(mem_path)?); let mut mem = MemoryBuilder::from_iter(mem_config, MemEntryIter(&mut mem_file)); - let state_transitions = - StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, false); + let state_transitions = StateTransitions::from_iter(TraceIter(&mut trace_file), &mut mem, true); let public_mem_addresses = pub_data .public_memory