From 91e9a7ce58ad284978aecfdc54fe390bed4d79fa Mon Sep 17 00:00:00 2001 From: Ohad Agadi Date: Sun, 1 Dec 2024 18:04:50 +0200 Subject: [PATCH] assert eq gen --- .../prover/src/cairo_air/opcodes_air.rs | 265 ++++++++- .../component.rs | 172 ++++++ .../mod.rs | 5 + .../prover.rs | 420 ++++++++++++++ .../component.rs | 161 ++++++ .../mod.rs | 5 + .../prover.rs | 383 +++++++++++++ .../component.rs | 198 +++++++ .../mod.rs | 5 + .../prover.rs | 525 ++++++++++++++++++ .../crates/prover/src/components/mod.rs | 3 + .../prover/src/input/state_transitions.rs | 2 +- 12 files changed, 2139 insertions(+), 5 deletions(-) create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/component.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/mod.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/prover.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/component.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/mod.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/prover.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/component.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/mod.rs create mode 100644 stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/prover.rs diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs index 7cbedfcf..2bfeb660 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs @@ -12,10 +12,11 @@ use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; use super::air::CairoInteractionElements; use crate::components::{ add_ap_opcode_is_imm_f_op1_base_fp_f, add_ap_opcode_is_imm_f_op1_base_fp_t, - add_ap_opcode_is_imm_t_op1_base_fp_f, generic_opcode, jnz_opcode_is_taken_f_dst_base_fp_f, - jnz_opcode_is_taken_f_dst_base_fp_t, jnz_opcode_is_taken_t_dst_base_fp_f, - jnz_opcode_is_taken_t_dst_base_fp_t, memory_address_to_id, memory_id_to_big, range_check_19, - range_check_9_9, ret_opcode, verify_instruction, + add_ap_opcode_is_imm_t_op1_base_fp_f, assert_eq_opcode_is_double_deref_f_is_imm_f, + assert_eq_opcode_is_double_deref_f_is_imm_t, assert_eq_opcode_is_double_deref_t_is_imm_f, + generic_opcode, jnz_opcode_is_taken_f_dst_base_fp_f, jnz_opcode_is_taken_f_dst_base_fp_t, + jnz_opcode_is_taken_t_dst_base_fp_f, jnz_opcode_is_taken_t_dst_base_fp_t, memory_address_to_id, + memory_id_to_big, range_check_19, range_check_9_9, ret_opcode, verify_instruction, }; use crate::input::state_transitions::StateTransitions; @@ -24,6 +25,9 @@ pub struct OpcodeClaim { add_ap_f_f: Vec, add_ap_f_t: Vec, add_ap_t_f: Vec, + assert_eq_f_f: Vec, + assert_eq_f_t: Vec, + assert_eq_t_f: Vec, generic: Vec, jnz_f_f: Vec, jnz_f_t: Vec, @@ -36,6 +40,9 @@ impl OpcodeClaim { self.add_ap_f_f.iter().for_each(|c| c.mix_into(channel)); self.add_ap_f_t.iter().for_each(|c| c.mix_into(channel)); self.add_ap_t_f.iter().for_each(|c| c.mix_into(channel)); + self.assert_eq_f_f.iter().for_each(|c| c.mix_into(channel)); + self.assert_eq_f_t.iter().for_each(|c| c.mix_into(channel)); + self.assert_eq_t_f.iter().for_each(|c| c.mix_into(channel)); self.generic.iter().for_each(|c| c.mix_into(channel)); self.jnz_f_f.iter().for_each(|c| c.mix_into(channel)); self.jnz_f_t.iter().for_each(|c| c.mix_into(channel)); @@ -49,6 +56,9 @@ impl OpcodeClaim { self.add_ap_f_f.iter().map(|c| c.log_sizes()), self.add_ap_f_t.iter().map(|c| c.log_sizes()), self.add_ap_t_f.iter().map(|c| c.log_sizes()), + self.assert_eq_f_f.iter().map(|c| c.log_sizes()), + self.assert_eq_f_t.iter().map(|c| c.log_sizes()), + self.assert_eq_t_f.iter().map(|c| c.log_sizes()), self.generic.iter().map(|c| c.log_sizes()), self.jnz_f_f.iter().map(|c| c.log_sizes()), self.jnz_f_t.iter().map(|c| c.log_sizes()), @@ -63,6 +73,9 @@ pub struct OpcodesClaimGenerator { add_ap_f_f: Vec, add_ap_f_t: Vec, add_ap_t_f: Vec, + assert_eq_f_f: Vec, + assert_eq_f_t: Vec, + assert_eq_t_f: Vec, generic: Vec, jnz_f_f: Vec, jnz_f_t: Vec, @@ -76,6 +89,9 @@ impl OpcodesClaimGenerator { let mut add_ap_f_f = vec![]; let mut add_ap_f_t = vec![]; let mut add_ap_t_f = vec![]; + let mut assert_eq_f_f = vec![]; + let mut assert_eq_f_t = vec![]; + let mut assert_eq_t_f = vec![]; let mut generic = vec![]; let mut jnz_f_f = vec![]; let mut jnz_f_t = vec![]; @@ -115,6 +131,45 @@ impl OpcodesClaimGenerator { .add_ap_opcode_is_imm_t_op1_base_fp_f, )); } + if !input + .casm_states_by_opcode + .assert_eq_opcode_is_double_deref_f_is_imm_f + .is_empty() + { + assert_eq_f_f.push( + assert_eq_opcode_is_double_deref_f_is_imm_f::ClaimGenerator::new( + input + .casm_states_by_opcode + .assert_eq_opcode_is_double_deref_f_is_imm_f, + ), + ); + } + if !input + .casm_states_by_opcode + .assert_eq_opcode_is_double_deref_f_is_imm_t + .is_empty() + { + assert_eq_f_t.push( + assert_eq_opcode_is_double_deref_f_is_imm_t::ClaimGenerator::new( + input + .casm_states_by_opcode + .assert_eq_opcode_is_double_deref_f_is_imm_t, + ), + ); + } + if !input + .casm_states_by_opcode + .assert_eq_opcode_is_double_deref_t_is_imm_f + .is_empty() + { + assert_eq_t_f.push( + assert_eq_opcode_is_double_deref_t_is_imm_f::ClaimGenerator::new( + input + .casm_states_by_opcode + .assert_eq_opcode_is_double_deref_t_is_imm_f, + ), + ); + } if !input.casm_states_by_opcode.generic_opcode.is_empty() { generic.push(generic_opcode::ClaimGenerator::new( input.casm_states_by_opcode.generic_opcode, @@ -173,6 +228,9 @@ impl OpcodesClaimGenerator { add_ap_f_f, add_ap_f_t, add_ap_t_f, + assert_eq_f_f, + assert_eq_f_t, + assert_eq_t_f, generic, jnz_f_f, jnz_f_t, @@ -227,6 +285,42 @@ impl OpcodesClaimGenerator { ) }) .unzip(); + let (assert_eq_f_f_claims, assert_eq_f_f_interaction_gens) = self + .assert_eq_f_f + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + let (assert_eq_f_t_claims, assert_eq_f_t_interaction_gens) = self + .assert_eq_f_t + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); + let (assert_eq_t_f_claims, assert_eq_t_f_interaction_gens) = self + .assert_eq_t_f + .into_iter() + .map(|gen| { + gen.write_trace( + tree_builder, + memory_address_to_id_trace_generator, + memory_id_to_value_trace_generator, + verify_instruction_trace_generator, + ) + }) + .unzip(); let (generic_opcode_claims, generic_opcode_interaction_gens) = self .generic .into_iter() @@ -307,6 +401,9 @@ impl OpcodesClaimGenerator { add_ap_f_f: add_ap_f_f_claims, add_ap_f_t: add_ap_f_t_claims, add_ap_t_f: add_ap_t_f_claims, + assert_eq_f_f: assert_eq_f_f_claims, + assert_eq_f_t: assert_eq_f_t_claims, + assert_eq_t_f: assert_eq_t_f_claims, generic: generic_opcode_claims, jnz_f_f: jnz_f_f_claims, jnz_f_t: jnz_f_t_claims, @@ -318,6 +415,9 @@ impl OpcodesClaimGenerator { add_ap_f_f: add_ap_f_f_interaction_gens, add_ap_f_t: add_ap_f_t_interaction_gens, add_ap_t_f: add_ap_t_f_interaction_gens, + assert_eq_f_f: assert_eq_f_f_interaction_gens, + assert_eq_f_t: assert_eq_f_t_interaction_gens, + assert_eq_t_f: assert_eq_t_f_interaction_gens, generic_opcode_interaction_gens, jnz_f_f: jnz_f_f_interaction_gens, jnz_f_t: jnz_f_t_interaction_gens, @@ -334,6 +434,9 @@ pub struct OpcodeInteractionClaim { add_ap_f_f: Vec, add_ap_f_t: Vec, add_ap_t_f: Vec, + assert_eq_f_f: Vec, + assert_eq_f_t: Vec, + assert_eq_t_f: Vec, generic: Vec, jnz_f_f: Vec, jnz_f_t: Vec, @@ -346,6 +449,9 @@ impl OpcodeInteractionClaim { self.add_ap_f_f.iter().for_each(|c| c.mix_into(channel)); self.add_ap_f_t.iter().for_each(|c| c.mix_into(channel)); self.add_ap_t_f.iter().for_each(|c| c.mix_into(channel)); + self.assert_eq_f_f.iter().for_each(|c| c.mix_into(channel)); + self.assert_eq_f_t.iter().for_each(|c| c.mix_into(channel)); + self.assert_eq_t_f.iter().for_each(|c| c.mix_into(channel)); self.generic.iter().for_each(|c| c.mix_into(channel)); self.jnz_f_f.iter().for_each(|c| c.mix_into(channel)); self.jnz_f_t.iter().for_each(|c| c.mix_into(channel)); @@ -377,6 +483,27 @@ impl OpcodeInteractionClaim { None => total_sum, }; } + for interaction_claim in &self.assert_eq_f_f { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + for interaction_claim in &self.assert_eq_f_t { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } + for interaction_claim in &self.assert_eq_t_f { + let (total_sum, claimed_sum) = interaction_claim.logup_sums; + sum += match claimed_sum { + Some((claimed_sum, ..)) => claimed_sum, + None => total_sum, + }; + } for interaction_claim in &self.generic { let (total_sum, claimed_sum) = interaction_claim.logup_sums; sum += match claimed_sum { @@ -427,6 +554,9 @@ pub struct OpcodesInteractionClaimGenerator { add_ap_f_f: Vec, add_ap_f_t: Vec, add_ap_t_f: Vec, + assert_eq_f_f: Vec, + assert_eq_f_t: Vec, + assert_eq_t_f: Vec, generic_opcode_interaction_gens: Vec, jnz_f_f: Vec, jnz_f_t: Vec, @@ -479,6 +609,43 @@ impl OpcodesInteractionClaimGenerator { ) }) .collect(); + let assert_eq_f_f_interaction_claims = self + .assert_eq_f_f + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_address_to_id, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + let assert_eq_f_t_interaction_claims = self + .assert_eq_f_t + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_address_to_id, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); + let assert_eq_t_f_interaction_claims = self + .assert_eq_t_f + .into_iter() + .map(|gen| { + gen.write_interaction_trace( + tree_builder, + &interaction_elements.memory_address_to_id, + &interaction_elements.memory_id_to_value, + &interaction_elements.opcodes, + &interaction_elements.verify_instruction, + ) + }) + .collect(); let generic_opcode_interaction_claims = self .generic_opcode_interaction_gens .into_iter() @@ -563,6 +730,9 @@ impl OpcodesInteractionClaimGenerator { add_ap_f_f: add_ap_f_f_interaction_claims, add_ap_f_t: add_ap_f_t_interaction_claims, add_ap_t_f: add_ap_t_f_interaction_claims, + assert_eq_f_f: assert_eq_f_f_interaction_claims, + assert_eq_f_t: assert_eq_f_t_interaction_claims, + assert_eq_t_f: assert_eq_t_f_interaction_claims, generic: generic_opcode_interaction_claims, jnz_f_f: jnz_f_f_interaction_claims, jnz_f_t: jnz_f_t_interaction_claims, @@ -577,6 +747,9 @@ pub struct OpcodeComponents { add_ap_f_f: Vec, add_ap_f_t: Vec, add_ap_t_f: Vec, + assert_eq_f_f: Vec, + assert_eq_f_t: Vec, + assert_eq_t_f: Vec, generic: Vec, jnz_f_f: Vec, jnz_f_t: Vec, @@ -663,6 +836,72 @@ impl OpcodeComponents { ) }) .collect_vec(); + let assert_eq_f_f_components = claim + .assert_eq_f_f + .iter() + .zip(interaction_claim.assert_eq_f_f.iter()) + .map(|(&claim, &interaction_claim)| { + assert_eq_opcode_is_double_deref_f_is_imm_f::Component::new( + tree_span_provider, + assert_eq_opcode_is_double_deref_f_is_imm_f::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_address_to_id + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + let assert_eq_f_t_components = claim + .assert_eq_f_t + .iter() + .zip(interaction_claim.assert_eq_f_t.iter()) + .map(|(&claim, &interaction_claim)| { + assert_eq_opcode_is_double_deref_f_is_imm_t::Component::new( + tree_span_provider, + assert_eq_opcode_is_double_deref_f_is_imm_t::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_address_to_id + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); + let assert_eq_t_f_components = claim + .assert_eq_t_f + .iter() + .zip(interaction_claim.assert_eq_t_f.iter()) + .map(|(&claim, &interaction_claim)| { + assert_eq_opcode_is_double_deref_t_is_imm_f::Component::new( + tree_span_provider, + assert_eq_opcode_is_double_deref_t_is_imm_f::Eval { + claim, + memoryaddresstoid_lookup_elements: interaction_elements + .memory_address_to_id + .clone(), + memoryidtobig_lookup_elements: interaction_elements + .memory_id_to_value + .clone(), + opcodes_lookup_elements: interaction_elements.opcodes.clone(), + verifyinstruction_lookup_elements: interaction_elements + .verify_instruction + .clone(), + }, + interaction_claim.logup_sums, + ) + }) + .collect_vec(); let generic_components = claim .generic .iter() @@ -815,6 +1054,9 @@ impl OpcodeComponents { add_ap_f_f: add_ap_f_f_components, add_ap_f_t: add_ap_f_t_components, add_ap_t_f: add_ap_t_f_components, + assert_eq_f_f: assert_eq_f_f_components, + assert_eq_f_t: assert_eq_f_t_components, + assert_eq_t_f: assert_eq_t_f_components, generic: generic_components, jnz_f_f: jnz_f_f_components, jnz_f_t: jnz_f_t_components, @@ -841,6 +1083,21 @@ impl OpcodeComponents { .iter() .map(|component| component as &dyn ComponentProver), ); + vec.extend( + self.assert_eq_f_f + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec.extend( + self.assert_eq_f_t + .iter() + .map(|component| component as &dyn ComponentProver), + ); + vec.extend( + self.assert_eq_t_f + .iter() + .map(|component| component as &dyn ComponentProver), + ); vec.extend( self.generic .iter() diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/component.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/component.rs new file mode 100644 index 00000000..e2d9cf59 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/component.rs @@ -0,0 +1,172 @@ +#![allow(non_camel_case_types)] +#![allow(unused_imports)] +use num_traits::{One, Zero}; +use serde::{Deserialize, Serialize}; +use stwo_prover::constraint_framework::logup::{LogupAtRow, LogupSums, LookupElements}; +use stwo_prover::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, +}; +use stwo_prover::core::backend::simd::m31::LOG_N_LANES; +use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use stwo_prover::core::pcs::TreeVec; + +use crate::relations; + +pub struct Eval { + pub claim: Claim, + pub memoryaddresstoid_lookup_elements: relations::MemoryAddressToId, + pub opcodes_lookup_elements: relations::Opcodes, + pub verifyinstruction_lookup_elements: relations::VerifyInstruction, +} + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct Claim { + pub n_calls: usize, +} +impl Claim { + pub fn log_sizes(&self) -> TreeVec> { + let log_size = std::cmp::max(self.n_calls.next_power_of_two().ilog2(), LOG_N_LANES); + let trace_log_sizes = vec![log_size; 10]; + let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 5]; + let preprocessed_log_sizes = vec![log_size]; + TreeVec::new(vec![ + preprocessed_log_sizes, + trace_log_sizes, + interaction_log_sizes, + ]) + } + + pub fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_u64(self.n_calls as u64); + } +} + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct InteractionClaim { + pub logup_sums: LogupSums, +} +impl InteractionClaim { + pub fn mix_into(&self, channel: &mut impl Channel) { + let (total_sum, claimed_sum) = self.logup_sums; + channel.mix_felts(&[total_sum]); + if let Some(claimed_sum) = claimed_sum { + channel.mix_felts(&[claimed_sum.0]); + channel.mix_u64(claimed_sum.1 as u64); + } + } +} + +pub type Component = FrameworkComponent; + +impl FrameworkEval for Eval { + fn log_size(&self) -> u32 { + std::cmp::max(self.claim.n_calls.next_power_of_two().ilog2(), LOG_N_LANES) + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + #[allow(unused_parens)] + #[allow(clippy::double_parens)] + #[allow(non_snake_case)] + fn evaluate(&self, mut eval: E) -> E { + let M31_0 = E::F::from(M31::from(0)); + let M31_1 = E::F::from(M31::from(1)); + let M31_32767 = E::F::from(M31::from(32767)); + let M31_32768 = E::F::from(M31::from(32768)); + let input_pc_col0 = eval.next_trace_mask(); + let input_ap_col1 = eval.next_trace_mask(); + let input_fp_col2 = eval.next_trace_mask(); + let offset0_col3 = eval.next_trace_mask(); + let offset2_col4 = eval.next_trace_mask(); + let dst_base_fp_col5 = eval.next_trace_mask(); + let op1_base_fp_col6 = eval.next_trace_mask(); + let op1_base_ap_col7 = eval.next_trace_mask(); + let ap_update_add_1_col8 = eval.next_trace_mask(); + let dst_id_col9 = eval.next_trace_mask(); + + // decode_instruction_dc55adb272664963. + + eval.add_to_relation(&[RelationEntry::new( + &self.verifyinstruction_lookup_elements, + E::EF::one(), + &[ + input_pc_col0.clone(), + offset0_col3.clone(), + M31_32767.clone(), + offset2_col4.clone(), + dst_base_fp_col5.clone(), + M31_1.clone(), + M31_0.clone(), + op1_base_fp_col6.clone(), + op1_base_ap_col7.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + ap_update_add_1_col8.clone(), + M31_0.clone(), + M31_0.clone(), + M31_1.clone(), + ], + )]); + + // Either flag op1_base_fp is on or flag op1_base_ap is on. + eval.add_constraint( + ((op1_base_fp_col6.clone() + op1_base_ap_col7.clone()) - M31_1.clone()), + ); + + // mem_verify_equal. + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[ + (((dst_base_fp_col5.clone() * input_fp_col2.clone()) + + ((M31_1.clone() - dst_base_fp_col5.clone()) * input_ap_col1.clone())) + + (offset0_col3.clone() - M31_32768.clone())), + dst_id_col9.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[ + (((op1_base_fp_col6.clone() * input_fp_col2.clone()) + + (op1_base_ap_col7.clone() * input_ap_col1.clone())) + + (offset2_col4.clone() - M31_32768.clone())), + dst_id_col9.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.opcodes_lookup_elements, + E::EF::one(), + &[ + input_pc_col0.clone(), + input_ap_col1.clone(), + input_fp_col2.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.opcodes_lookup_elements, + -E::EF::one(), + &[ + (input_pc_col0.clone() + M31_1.clone()), + (input_ap_col1.clone() + ap_update_add_1_col8.clone()), + input_fp_col2.clone(), + ], + )]); + + eval.finalize_logup(); + eval + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/mod.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/mod.rs new file mode 100644 index 00000000..3f7a8d74 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/mod.rs @@ -0,0 +1,5 @@ +pub mod component; +pub mod prover; + +pub use component::{Claim, Component, Eval, InteractionClaim}; +pub use prover::{ClaimGenerator, InputType, InteractionClaimGenerator}; diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/prover.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/prover.rs new file mode 100644 index 00000000..a70f8203 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_f/prover.rs @@ -0,0 +1,420 @@ +#![allow(unused_parens)] +#![allow(unused_imports)] +use itertools::{chain, zip_eq, Itertools}; +use num_traits::{One, Zero}; +use prover_types::cpu::*; +use prover_types::simd::*; +use stwo_prover::constraint_framework::logup::LogupTraceGenerator; +use stwo_prover::constraint_framework::Relation; +use stwo_prover::core::air::Component; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::conversion::Unpack; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::qm31::PackedQM31; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::{Col, Column}; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::pcs::TreeBuilder; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; +use stwo_prover::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; + +use super::component::{Claim, InteractionClaim}; +use crate::components::{memory_address_to_id, memory_id_to_big, pack_values, verify_instruction}; +use crate::relations; + +pub type InputType = CasmState; +pub type PackedInputType = PackedCasmState; +const N_TRACE_COLUMNS: usize = 10; + +#[derive(Default)] +pub struct ClaimGenerator { + pub inputs: Vec, +} +impl ClaimGenerator { + pub fn new(inputs: Vec) -> Self { + Self { inputs } + } + + pub fn write_trace( + mut self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memory_address_to_id_state: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_big_state: &mut memory_id_to_big::ClaimGenerator, + verify_instruction_state: &mut verify_instruction::ClaimGenerator, + ) -> (Claim, InteractionClaimGenerator) { + let n_calls = self.inputs.len(); + assert_ne!(n_calls, 0); + let size = std::cmp::max(n_calls.next_power_of_two(), N_LANES); + let need_padding = n_calls != size; + + if need_padding { + self.inputs.resize(size, *self.inputs.first().unwrap()); + bit_reverse_coset_to_circle_domain_order(&mut self.inputs); + } + + let packed_inputs = pack_values(&self.inputs); + let (trace, mut sub_components_inputs, lookup_data) = write_trace_simd( + packed_inputs, + memory_address_to_id_state, + memory_id_to_big_state, + ); + + if need_padding { + sub_components_inputs.bit_reverse_coset_to_circle_domain_order(); + } + sub_components_inputs + .memory_address_to_id_inputs + .iter() + .for_each(|inputs| { + memory_address_to_id_state.add_inputs(&inputs[..n_calls]); + }); + sub_components_inputs + .verify_instruction_inputs + .iter() + .for_each(|inputs| { + verify_instruction_state.add_inputs(&inputs[..n_calls]); + }); + + tree_builder.extend_evals( + trace + .into_iter() + .map(|eval| { + let domain = CanonicCoset::new( + eval.len() + .checked_ilog2() + .expect("Input is not a power of 2!"), + ) + .circle_domain(); + CircleEvaluation::::new(domain, eval) + }) + .collect_vec(), + ); + + ( + Claim { n_calls }, + InteractionClaimGenerator { + n_calls, + lookup_data, + }, + ) + } + + pub fn add_inputs(&mut self, inputs: &[InputType]) { + self.inputs.extend(inputs); + } +} + +pub struct SubComponentInputs { + pub memory_address_to_id_inputs: [Vec; 2], + pub verify_instruction_inputs: [Vec; 1], +} +impl SubComponentInputs { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memory_address_to_id_inputs: [ + Vec::with_capacity(capacity), + Vec::with_capacity(capacity), + ], + verify_instruction_inputs: [Vec::with_capacity(capacity)], + } + } + + fn bit_reverse_coset_to_circle_domain_order(&mut self) { + self.memory_address_to_id_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + self.verify_instruction_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + } +} + +#[allow(clippy::useless_conversion)] +#[allow(unused_variables)] +#[allow(clippy::double_parens)] +#[allow(non_snake_case)] +pub fn write_trace_simd( + inputs: Vec, + memory_address_to_id_state: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_big_state: &mut memory_id_to_big::ClaimGenerator, +) -> ( + [BaseColumn; N_TRACE_COLUMNS], + SubComponentInputs, + LookupData, +) { + const N_TRACE_COLUMNS: usize = 10; + let mut trace: [_; N_TRACE_COLUMNS] = + std::array::from_fn(|_| Col::::zeros(inputs.len() * N_LANES)); + + let mut lookup_data = LookupData::with_capacity(inputs.len()); + #[allow(unused_mut)] + let mut sub_components_inputs = SubComponentInputs::with_capacity(inputs.len()); + + let M31_0 = PackedM31::broadcast(M31::from(0)); + let M31_1 = PackedM31::broadcast(M31::from(1)); + let M31_32767 = PackedM31::broadcast(M31::from(32767)); + let M31_32768 = PackedM31::broadcast(M31::from(32768)); + let UInt16_0 = PackedUInt16::broadcast(UInt16::from(0)); + let UInt16_1 = PackedUInt16::broadcast(UInt16::from(1)); + let UInt16_11 = PackedUInt16::broadcast(UInt16::from(11)); + let UInt16_127 = PackedUInt16::broadcast(UInt16::from(127)); + let UInt16_13 = PackedUInt16::broadcast(UInt16::from(13)); + let UInt16_3 = PackedUInt16::broadcast(UInt16::from(3)); + let UInt16_4 = PackedUInt16::broadcast(UInt16::from(4)); + let UInt16_5 = PackedUInt16::broadcast(UInt16::from(5)); + let UInt16_6 = PackedUInt16::broadcast(UInt16::from(6)); + let UInt16_7 = PackedUInt16::broadcast(UInt16::from(7)); + let UInt16_9 = PackedUInt16::broadcast(UInt16::from(9)); + + inputs.into_iter().enumerate().for_each( + |(row_index, assert_eq_opcode_is_double_deref_f_is_imm_f_input)| { + let input_tmp_1110 = assert_eq_opcode_is_double_deref_f_is_imm_f_input; + let input_pc_col0 = input_tmp_1110.pc; + trace[0].data[row_index] = input_pc_col0; + let input_ap_col1 = input_tmp_1110.ap; + trace[1].data[row_index] = input_ap_col1; + let input_fp_col2 = input_tmp_1110.fp; + trace[2].data[row_index] = input_fp_col2; + + // decode_instruction_dc55adb272664963. + + let memory_address_to_id_value_tmp_1120 = + memory_address_to_id_state.deduce_output(input_pc_col0); + let memory_id_to_big_value_tmp_1121 = + memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_1120); + let offset0_tmp_1122 = + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(0))) + + (((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(1))) + & (UInt16_127)) + << (UInt16_9))); + let offset0_col3 = offset0_tmp_1122.as_m31(); + trace[3].data[row_index] = offset0_col3; + let offset2_tmp_1123 = + ((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(3))) + >> (UInt16_5)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(4))) + << (UInt16_4))) + + (((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(5))) + & (UInt16_7)) + << (UInt16_13))); + let offset2_col4 = offset2_tmp_1123.as_m31(); + trace[4].data[row_index] = offset2_col4; + let dst_base_fp_tmp_1124 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(6))) + << (UInt16_6))) + >> (UInt16_0)) + & (UInt16_1)); + let dst_base_fp_col5 = dst_base_fp_tmp_1124.as_m31(); + trace[5].data[row_index] = dst_base_fp_col5; + let op1_base_fp_tmp_1125 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(6))) + << (UInt16_6))) + >> (UInt16_3)) + & (UInt16_1)); + let op1_base_fp_col6 = op1_base_fp_tmp_1125.as_m31(); + trace[6].data[row_index] = op1_base_fp_col6; + let op1_base_ap_tmp_1126 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(6))) + << (UInt16_6))) + >> (UInt16_4)) + & (UInt16_1)); + let op1_base_ap_col7 = op1_base_ap_tmp_1126.as_m31(); + trace[7].data[row_index] = op1_base_ap_col7; + let ap_update_add_1_tmp_1127 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1121.get_m31(6))) + << (UInt16_6))) + >> (UInt16_11)) + & (UInt16_1)); + let ap_update_add_1_col8 = ap_update_add_1_tmp_1127.as_m31(); + trace[8].data[row_index] = ap_update_add_1_col8; + + sub_components_inputs.verify_instruction_inputs[0].extend( + ( + input_pc_col0, + [offset0_col3, M31_32767, offset2_col4], + [ + dst_base_fp_col5, + M31_1, + M31_0, + op1_base_fp_col6, + op1_base_ap_col7, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ap_update_add_1_col8, + M31_0, + M31_0, + M31_1, + ], + ) + .unpack(), + ); + + lookup_data.verifyinstruction[0].push([ + input_pc_col0, + offset0_col3, + M31_32767, + offset2_col4, + dst_base_fp_col5, + M31_1, + M31_0, + op1_base_fp_col6, + op1_base_ap_col7, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ap_update_add_1_col8, + M31_0, + M31_0, + M31_1, + ]); + + // mem_verify_equal. + + let memory_address_to_id_value_tmp_1130 = memory_address_to_id_state.deduce_output( + ((((dst_base_fp_col5) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col5)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))), + ); + let dst_id_col9 = memory_address_to_id_value_tmp_1130; + trace[9].data[row_index] = dst_id_col9; + sub_components_inputs.memory_address_to_id_inputs[0].extend( + ((((dst_base_fp_col5) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col5)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))) + .unpack(), + ); + + lookup_data.memoryaddresstoid[0].push([ + ((((dst_base_fp_col5) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col5)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))), + dst_id_col9, + ]); + sub_components_inputs.memory_address_to_id_inputs[1].extend( + ((((op1_base_fp_col6) * (input_fp_col2)) + ((op1_base_ap_col7) * (input_ap_col1))) + + ((offset2_col4) - (M31_32768))) + .unpack(), + ); + + lookup_data.memoryaddresstoid[1].push([ + ((((op1_base_fp_col6) * (input_fp_col2)) + ((op1_base_ap_col7) * (input_ap_col1))) + + ((offset2_col4) - (M31_32768))), + dst_id_col9, + ]); + + lookup_data.opcodes[0].push([input_pc_col0, input_ap_col1, input_fp_col2]); + lookup_data.opcodes[1].push([ + ((input_pc_col0) + (M31_1)), + ((input_ap_col1) + (ap_update_add_1_col8)), + input_fp_col2, + ]); + }, + ); + + (trace, sub_components_inputs, lookup_data) +} + +pub struct LookupData { + pub memoryaddresstoid: [Vec<[PackedM31; 2]>; 2], + pub opcodes: [Vec<[PackedM31; 3]>; 2], + pub verifyinstruction: [Vec<[PackedM31; 19]>; 1], +} +impl LookupData { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memoryaddresstoid: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)], + opcodes: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)], + verifyinstruction: [Vec::with_capacity(capacity)], + } + } +} + +pub struct InteractionClaimGenerator { + pub n_calls: usize, + pub lookup_data: LookupData, +} +impl InteractionClaimGenerator { + pub fn write_interaction_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memoryaddresstoid_lookup_elements: &relations::MemoryAddressToId, + opcodes_lookup_elements: &relations::Opcodes, + verifyinstruction_lookup_elements: &relations::VerifyInstruction, + ) -> InteractionClaim { + let log_size = std::cmp::max(self.n_calls.next_power_of_two().ilog2(), LOG_N_LANES); + let mut logup_gen = LogupTraceGenerator::new(log_size); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.verifyinstruction[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = verifyinstruction_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[1]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.opcodes[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = opcodes_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.opcodes[1]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = opcodes_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, -PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let (trace, total_sum, claimed_sum) = if self.n_calls == 1 << log_size { + let (trace, claimed_sum) = logup_gen.finalize_last(); + (trace, claimed_sum, None) + } else { + let (trace, [total_sum, claimed_sum]) = + logup_gen.finalize_at([(1 << log_size) - 1, self.n_calls - 1]); + (trace, total_sum, Some((claimed_sum, self.n_calls - 1))) + }; + tree_builder.extend_evals(trace); + + InteractionClaim { + logup_sums: (total_sum, claimed_sum), + } + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/component.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/component.rs new file mode 100644 index 00000000..7eac9f31 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/component.rs @@ -0,0 +1,161 @@ +#![allow(non_camel_case_types)] +#![allow(unused_imports)] +use num_traits::{One, Zero}; +use serde::{Deserialize, Serialize}; +use stwo_prover::constraint_framework::logup::{LogupAtRow, LogupSums, LookupElements}; +use stwo_prover::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, +}; +use stwo_prover::core::backend::simd::m31::LOG_N_LANES; +use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use stwo_prover::core::pcs::TreeVec; + +use crate::relations; + +pub struct Eval { + pub claim: Claim, + pub memoryaddresstoid_lookup_elements: relations::MemoryAddressToId, + pub opcodes_lookup_elements: relations::Opcodes, + pub verifyinstruction_lookup_elements: relations::VerifyInstruction, +} + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct Claim { + pub n_calls: usize, +} +impl Claim { + pub fn log_sizes(&self) -> TreeVec> { + let log_size = std::cmp::max(self.n_calls.next_power_of_two().ilog2(), LOG_N_LANES); + let trace_log_sizes = vec![log_size; 7]; + let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 5]; + let preprocessed_log_sizes = vec![log_size]; + TreeVec::new(vec![ + preprocessed_log_sizes, + trace_log_sizes, + interaction_log_sizes, + ]) + } + + pub fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_u64(self.n_calls as u64); + } +} + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct InteractionClaim { + pub logup_sums: LogupSums, +} +impl InteractionClaim { + pub fn mix_into(&self, channel: &mut impl Channel) { + let (total_sum, claimed_sum) = self.logup_sums; + channel.mix_felts(&[total_sum]); + if let Some(claimed_sum) = claimed_sum { + channel.mix_felts(&[claimed_sum.0]); + channel.mix_u64(claimed_sum.1 as u64); + } + } +} + +pub type Component = FrameworkComponent; + +impl FrameworkEval for Eval { + fn log_size(&self) -> u32 { + std::cmp::max(self.claim.n_calls.next_power_of_two().ilog2(), LOG_N_LANES) + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + #[allow(unused_parens)] + #[allow(clippy::double_parens)] + #[allow(non_snake_case)] + fn evaluate(&self, mut eval: E) -> E { + let M31_0 = E::F::from(M31::from(0)); + let M31_1 = E::F::from(M31::from(1)); + let M31_2 = E::F::from(M31::from(2)); + let M31_32767 = E::F::from(M31::from(32767)); + let M31_32768 = E::F::from(M31::from(32768)); + let M31_32769 = E::F::from(M31::from(32769)); + let input_pc_col0 = eval.next_trace_mask(); + let input_ap_col1 = eval.next_trace_mask(); + let input_fp_col2 = eval.next_trace_mask(); + let offset0_col3 = eval.next_trace_mask(); + let dst_base_fp_col4 = eval.next_trace_mask(); + let ap_update_add_1_col5 = eval.next_trace_mask(); + let dst_id_col6 = eval.next_trace_mask(); + + // decode_instruction_684cf7138ce526e3. + + eval.add_to_relation(&[RelationEntry::new( + &self.verifyinstruction_lookup_elements, + E::EF::one(), + &[ + input_pc_col0.clone(), + offset0_col3.clone(), + M31_32767.clone(), + M31_32769.clone(), + dst_base_fp_col4.clone(), + M31_1.clone(), + M31_1.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + ap_update_add_1_col5.clone(), + M31_0.clone(), + M31_0.clone(), + M31_1.clone(), + ], + )]); + + // mem_verify_equal. + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[ + (((dst_base_fp_col4.clone() * input_fp_col2.clone()) + + ((M31_1.clone() - dst_base_fp_col4.clone()) * input_ap_col1.clone())) + + (offset0_col3.clone() - M31_32768.clone())), + dst_id_col6.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[(input_pc_col0.clone() + M31_1.clone()), dst_id_col6.clone()], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.opcodes_lookup_elements, + E::EF::one(), + &[ + input_pc_col0.clone(), + input_ap_col1.clone(), + input_fp_col2.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.opcodes_lookup_elements, + -E::EF::one(), + &[ + (input_pc_col0.clone() + M31_2.clone()), + (input_ap_col1.clone() + ap_update_add_1_col5.clone()), + input_fp_col2.clone(), + ], + )]); + + eval.finalize_logup(); + eval + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/mod.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/mod.rs new file mode 100644 index 00000000..3f7a8d74 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/mod.rs @@ -0,0 +1,5 @@ +pub mod component; +pub mod prover; + +pub use component::{Claim, Component, Eval, InteractionClaim}; +pub use prover::{ClaimGenerator, InputType, InteractionClaimGenerator}; diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/prover.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/prover.rs new file mode 100644 index 00000000..68c4b325 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_f_is_imm_t/prover.rs @@ -0,0 +1,383 @@ +#![allow(unused_parens)] +#![allow(unused_imports)] +use itertools::{chain, zip_eq, Itertools}; +use num_traits::{One, Zero}; +use prover_types::cpu::*; +use prover_types::simd::*; +use stwo_prover::constraint_framework::logup::LogupTraceGenerator; +use stwo_prover::constraint_framework::Relation; +use stwo_prover::core::air::Component; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::conversion::Unpack; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::qm31::PackedQM31; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::{Col, Column}; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::pcs::TreeBuilder; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; +use stwo_prover::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; + +use super::component::{Claim, InteractionClaim}; +use crate::components::{memory_address_to_id, memory_id_to_big, pack_values, verify_instruction}; +use crate::relations; + +pub type InputType = CasmState; +pub type PackedInputType = PackedCasmState; +const N_TRACE_COLUMNS: usize = 7; + +#[derive(Default)] +pub struct ClaimGenerator { + pub inputs: Vec, +} +impl ClaimGenerator { + pub fn new(inputs: Vec) -> Self { + Self { inputs } + } + + pub fn write_trace( + mut self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memory_address_to_id_state: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_big_state: &mut memory_id_to_big::ClaimGenerator, + verify_instruction_state: &mut verify_instruction::ClaimGenerator, + ) -> (Claim, InteractionClaimGenerator) { + let n_calls = self.inputs.len(); + assert_ne!(n_calls, 0); + let size = std::cmp::max(n_calls.next_power_of_two(), N_LANES); + let need_padding = n_calls != size; + + if need_padding { + self.inputs.resize(size, *self.inputs.first().unwrap()); + bit_reverse_coset_to_circle_domain_order(&mut self.inputs); + } + + let packed_inputs = pack_values(&self.inputs); + let (trace, mut sub_components_inputs, lookup_data) = write_trace_simd( + packed_inputs, + memory_address_to_id_state, + memory_id_to_big_state, + ); + + if need_padding { + sub_components_inputs.bit_reverse_coset_to_circle_domain_order(); + } + sub_components_inputs + .memory_address_to_id_inputs + .iter() + .for_each(|inputs| { + memory_address_to_id_state.add_inputs(&inputs[..n_calls]); + }); + sub_components_inputs + .verify_instruction_inputs + .iter() + .for_each(|inputs| { + verify_instruction_state.add_inputs(&inputs[..n_calls]); + }); + + tree_builder.extend_evals( + trace + .into_iter() + .map(|eval| { + let domain = CanonicCoset::new( + eval.len() + .checked_ilog2() + .expect("Input is not a power of 2!"), + ) + .circle_domain(); + CircleEvaluation::::new(domain, eval) + }) + .collect_vec(), + ); + + ( + Claim { n_calls }, + InteractionClaimGenerator { + n_calls, + lookup_data, + }, + ) + } + + pub fn add_inputs(&mut self, inputs: &[InputType]) { + self.inputs.extend(inputs); + } +} + +pub struct SubComponentInputs { + pub memory_address_to_id_inputs: [Vec; 2], + pub verify_instruction_inputs: [Vec; 1], +} +impl SubComponentInputs { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memory_address_to_id_inputs: [ + Vec::with_capacity(capacity), + Vec::with_capacity(capacity), + ], + verify_instruction_inputs: [Vec::with_capacity(capacity)], + } + } + + fn bit_reverse_coset_to_circle_domain_order(&mut self) { + self.memory_address_to_id_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + self.verify_instruction_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + } +} + +#[allow(clippy::useless_conversion)] +#[allow(unused_variables)] +#[allow(clippy::double_parens)] +#[allow(non_snake_case)] +pub fn write_trace_simd( + inputs: Vec, + memory_address_to_id_state: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_big_state: &mut memory_id_to_big::ClaimGenerator, +) -> ( + [BaseColumn; N_TRACE_COLUMNS], + SubComponentInputs, + LookupData, +) { + const N_TRACE_COLUMNS: usize = 7; + let mut trace: [_; N_TRACE_COLUMNS] = + std::array::from_fn(|_| Col::::zeros(inputs.len() * N_LANES)); + + let mut lookup_data = LookupData::with_capacity(inputs.len()); + #[allow(unused_mut)] + let mut sub_components_inputs = SubComponentInputs::with_capacity(inputs.len()); + + let M31_0 = PackedM31::broadcast(M31::from(0)); + let M31_1 = PackedM31::broadcast(M31::from(1)); + let M31_2 = PackedM31::broadcast(M31::from(2)); + let M31_32767 = PackedM31::broadcast(M31::from(32767)); + let M31_32768 = PackedM31::broadcast(M31::from(32768)); + let M31_32769 = PackedM31::broadcast(M31::from(32769)); + let UInt16_0 = PackedUInt16::broadcast(UInt16::from(0)); + let UInt16_1 = PackedUInt16::broadcast(UInt16::from(1)); + let UInt16_11 = PackedUInt16::broadcast(UInt16::from(11)); + let UInt16_127 = PackedUInt16::broadcast(UInt16::from(127)); + let UInt16_3 = PackedUInt16::broadcast(UInt16::from(3)); + let UInt16_6 = PackedUInt16::broadcast(UInt16::from(6)); + let UInt16_9 = PackedUInt16::broadcast(UInt16::from(9)); + + inputs.into_iter().enumerate().for_each( + |(row_index, assert_eq_opcode_is_double_deref_f_is_imm_t_input)| { + let input_tmp_1155 = assert_eq_opcode_is_double_deref_f_is_imm_t_input; + let input_pc_col0 = input_tmp_1155.pc; + trace[0].data[row_index] = input_pc_col0; + let input_ap_col1 = input_tmp_1155.ap; + trace[1].data[row_index] = input_ap_col1; + let input_fp_col2 = input_tmp_1155.fp; + trace[2].data[row_index] = input_fp_col2; + + // decode_instruction_684cf7138ce526e3. + + let memory_address_to_id_value_tmp_1162 = + memory_address_to_id_state.deduce_output(input_pc_col0); + let memory_id_to_big_value_tmp_1163 = + memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_1162); + let offset0_tmp_1164 = + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1163.get_m31(0))) + + (((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1163.get_m31(1))) + & (UInt16_127)) + << (UInt16_9))); + let offset0_col3 = offset0_tmp_1164.as_m31(); + trace[3].data[row_index] = offset0_col3; + let dst_base_fp_tmp_1165 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1163.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1163.get_m31(6))) + << (UInt16_6))) + >> (UInt16_0)) + & (UInt16_1)); + let dst_base_fp_col4 = dst_base_fp_tmp_1165.as_m31(); + trace[4].data[row_index] = dst_base_fp_col4; + let ap_update_add_1_tmp_1166 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1163.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1163.get_m31(6))) + << (UInt16_6))) + >> (UInt16_11)) + & (UInt16_1)); + let ap_update_add_1_col5 = ap_update_add_1_tmp_1166.as_m31(); + trace[5].data[row_index] = ap_update_add_1_col5; + + sub_components_inputs.verify_instruction_inputs[0].extend( + ( + input_pc_col0, + [offset0_col3, M31_32767, M31_32769], + [ + dst_base_fp_col4, + M31_1, + M31_1, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ap_update_add_1_col5, + M31_0, + M31_0, + M31_1, + ], + ) + .unpack(), + ); + + lookup_data.verifyinstruction[0].push([ + input_pc_col0, + offset0_col3, + M31_32767, + M31_32769, + dst_base_fp_col4, + M31_1, + M31_1, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ap_update_add_1_col5, + M31_0, + M31_0, + M31_1, + ]); + + // mem_verify_equal. + + let memory_address_to_id_value_tmp_1168 = memory_address_to_id_state.deduce_output( + ((((dst_base_fp_col4) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col4)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))), + ); + let dst_id_col6 = memory_address_to_id_value_tmp_1168; + trace[6].data[row_index] = dst_id_col6; + sub_components_inputs.memory_address_to_id_inputs[0].extend( + ((((dst_base_fp_col4) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col4)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))) + .unpack(), + ); + + lookup_data.memoryaddresstoid[0].push([ + ((((dst_base_fp_col4) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col4)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))), + dst_id_col6, + ]); + sub_components_inputs.memory_address_to_id_inputs[1] + .extend(((input_pc_col0) + (M31_1)).unpack()); + + lookup_data.memoryaddresstoid[1].push([((input_pc_col0) + (M31_1)), dst_id_col6]); + + lookup_data.opcodes[0].push([input_pc_col0, input_ap_col1, input_fp_col2]); + lookup_data.opcodes[1].push([ + ((input_pc_col0) + (M31_2)), + ((input_ap_col1) + (ap_update_add_1_col5)), + input_fp_col2, + ]); + }, + ); + + (trace, sub_components_inputs, lookup_data) +} + +pub struct LookupData { + pub memoryaddresstoid: [Vec<[PackedM31; 2]>; 2], + pub opcodes: [Vec<[PackedM31; 3]>; 2], + pub verifyinstruction: [Vec<[PackedM31; 19]>; 1], +} +impl LookupData { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memoryaddresstoid: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)], + opcodes: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)], + verifyinstruction: [Vec::with_capacity(capacity)], + } + } +} + +pub struct InteractionClaimGenerator { + pub n_calls: usize, + pub lookup_data: LookupData, +} +impl InteractionClaimGenerator { + pub fn write_interaction_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memoryaddresstoid_lookup_elements: &relations::MemoryAddressToId, + opcodes_lookup_elements: &relations::Opcodes, + verifyinstruction_lookup_elements: &relations::VerifyInstruction, + ) -> InteractionClaim { + let log_size = std::cmp::max(self.n_calls.next_power_of_two().ilog2(), LOG_N_LANES); + let mut logup_gen = LogupTraceGenerator::new(log_size); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.verifyinstruction[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = verifyinstruction_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[1]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.opcodes[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = opcodes_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.opcodes[1]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = opcodes_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, -PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let (trace, total_sum, claimed_sum) = if self.n_calls == 1 << log_size { + let (trace, claimed_sum) = logup_gen.finalize_last(); + (trace, claimed_sum, None) + } else { + let (trace, [total_sum, claimed_sum]) = + logup_gen.finalize_at([(1 << log_size) - 1, self.n_calls - 1]); + (trace, total_sum, Some((claimed_sum, self.n_calls - 1))) + }; + tree_builder.extend_evals(trace); + + InteractionClaim { + logup_sums: (total_sum, claimed_sum), + } + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/component.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/component.rs new file mode 100644 index 00000000..689713c2 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/component.rs @@ -0,0 +1,198 @@ +#![allow(non_camel_case_types)] +#![allow(unused_imports)] +use num_traits::{One, Zero}; +use serde::{Deserialize, Serialize}; +use stwo_prover::constraint_framework::logup::{LogupAtRow, LogupSums, LookupElements}; +use stwo_prover::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, RelationEntry, +}; +use stwo_prover::core::backend::simd::m31::LOG_N_LANES; +use stwo_prover::core::channel::Channel; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use stwo_prover::core::pcs::TreeVec; + +use crate::relations; + +pub struct Eval { + pub claim: Claim, + pub memoryaddresstoid_lookup_elements: relations::MemoryAddressToId, + pub memoryidtobig_lookup_elements: relations::MemoryIdToBig, + pub opcodes_lookup_elements: relations::Opcodes, + pub verifyinstruction_lookup_elements: relations::VerifyInstruction, +} + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct Claim { + pub n_calls: usize, +} +impl Claim { + pub fn log_sizes(&self) -> TreeVec> { + let log_size = std::cmp::max(self.n_calls.next_power_of_two().ilog2(), LOG_N_LANES); + let trace_log_sizes = vec![log_size; 14]; + let interaction_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE * 7]; + let preprocessed_log_sizes = vec![log_size]; + TreeVec::new(vec![ + preprocessed_log_sizes, + trace_log_sizes, + interaction_log_sizes, + ]) + } + + pub fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_u64(self.n_calls as u64); + } +} + +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct InteractionClaim { + pub logup_sums: LogupSums, +} +impl InteractionClaim { + pub fn mix_into(&self, channel: &mut impl Channel) { + let (total_sum, claimed_sum) = self.logup_sums; + channel.mix_felts(&[total_sum]); + if let Some(claimed_sum) = claimed_sum { + channel.mix_felts(&[claimed_sum.0]); + channel.mix_u64(claimed_sum.1 as u64); + } + } +} + +pub type Component = FrameworkComponent; + +impl FrameworkEval for Eval { + fn log_size(&self) -> u32 { + std::cmp::max(self.claim.n_calls.next_power_of_two().ilog2(), LOG_N_LANES) + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + #[allow(unused_parens)] + #[allow(clippy::double_parens)] + #[allow(non_snake_case)] + fn evaluate(&self, mut eval: E) -> E { + let M31_0 = E::F::from(M31::from(0)); + let M31_1 = E::F::from(M31::from(1)); + let M31_262144 = E::F::from(M31::from(262144)); + let M31_32768 = E::F::from(M31::from(32768)); + let M31_512 = E::F::from(M31::from(512)); + let input_pc_col0 = eval.next_trace_mask(); + let input_ap_col1 = eval.next_trace_mask(); + let input_fp_col2 = eval.next_trace_mask(); + let offset0_col3 = eval.next_trace_mask(); + let offset1_col4 = eval.next_trace_mask(); + let offset2_col5 = eval.next_trace_mask(); + let dst_base_fp_col6 = eval.next_trace_mask(); + let op0_base_fp_col7 = eval.next_trace_mask(); + let ap_update_add_1_col8 = eval.next_trace_mask(); + let mem1_base_id_col9 = eval.next_trace_mask(); + let mem1_base_limb_0_col10 = eval.next_trace_mask(); + let mem1_base_limb_1_col11 = eval.next_trace_mask(); + let mem1_base_limb_2_col12 = eval.next_trace_mask(); + let dst_id_col13 = eval.next_trace_mask(); + + // decode_instruction_a2af169c0fec5c47. + + eval.add_to_relation(&[RelationEntry::new( + &self.verifyinstruction_lookup_elements, + E::EF::one(), + &[ + input_pc_col0.clone(), + offset0_col3.clone(), + offset1_col4.clone(), + offset2_col5.clone(), + dst_base_fp_col6.clone(), + op0_base_fp_col7.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + M31_0.clone(), + ap_update_add_1_col8.clone(), + M31_0.clone(), + M31_0.clone(), + M31_1.clone(), + ], + )]); + + // read_positive_num_bits_27. + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[ + (((op0_base_fp_col7.clone() * input_fp_col2.clone()) + + ((M31_1.clone() - op0_base_fp_col7.clone()) * input_ap_col1.clone())) + + (offset1_col4.clone() - M31_32768.clone())), + mem1_base_id_col9.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryidtobig_lookup_elements, + E::EF::one(), + &[ + mem1_base_id_col9.clone(), + mem1_base_limb_0_col10.clone(), + mem1_base_limb_1_col11.clone(), + mem1_base_limb_2_col12.clone(), + ], + )]); + + // mem_verify_equal. + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[ + (((dst_base_fp_col6.clone() * input_fp_col2.clone()) + + ((M31_1.clone() - dst_base_fp_col6.clone()) * input_ap_col1.clone())) + + (offset0_col3.clone() - M31_32768.clone())), + dst_id_col13.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.memoryaddresstoid_lookup_elements, + E::EF::one(), + &[ + (((mem1_base_limb_0_col10.clone() + + (mem1_base_limb_1_col11.clone() * M31_512.clone())) + + (mem1_base_limb_2_col12.clone() * M31_262144.clone())) + + (offset2_col5.clone() - M31_32768.clone())), + dst_id_col13.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.opcodes_lookup_elements, + E::EF::one(), + &[ + input_pc_col0.clone(), + input_ap_col1.clone(), + input_fp_col2.clone(), + ], + )]); + + eval.add_to_relation(&[RelationEntry::new( + &self.opcodes_lookup_elements, + -E::EF::one(), + &[ + (input_pc_col0.clone() + M31_1.clone()), + (input_ap_col1.clone() + ap_update_add_1_col8.clone()), + input_fp_col2.clone(), + ], + )]); + + eval.finalize_logup(); + eval + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/mod.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/mod.rs new file mode 100644 index 00000000..3f7a8d74 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/mod.rs @@ -0,0 +1,5 @@ +pub mod component; +pub mod prover; + +pub use component::{Claim, Component, Eval, InteractionClaim}; +pub use prover::{ClaimGenerator, InputType, InteractionClaimGenerator}; diff --git a/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/prover.rs b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/prover.rs new file mode 100644 index 00000000..bc932ab7 --- /dev/null +++ b/stwo_cairo_prover/crates/prover/src/components/assert_eq_opcode_is_double_deref_t_is_imm_f/prover.rs @@ -0,0 +1,525 @@ +#![allow(unused_parens)] +#![allow(unused_imports)] +use itertools::{chain, zip_eq, Itertools}; +use num_traits::{One, Zero}; +use prover_types::cpu::*; +use prover_types::simd::*; +use stwo_prover::constraint_framework::logup::LogupTraceGenerator; +use stwo_prover::constraint_framework::Relation; +use stwo_prover::core::air::Component; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::conversion::Unpack; +use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; +use stwo_prover::core::backend::simd::qm31::PackedQM31; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::{Col, Column}; +use stwo_prover::core::fields::m31::M31; +use stwo_prover::core::pcs::TreeBuilder; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::utils::bit_reverse_coset_to_circle_domain_order; +use stwo_prover::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; + +use super::component::{Claim, InteractionClaim}; +use crate::components::{memory_address_to_id, memory_id_to_big, pack_values, verify_instruction}; +use crate::relations; + +pub type InputType = CasmState; +pub type PackedInputType = PackedCasmState; +const N_TRACE_COLUMNS: usize = 14; + +#[derive(Default)] +pub struct ClaimGenerator { + pub inputs: Vec, +} +impl ClaimGenerator { + pub fn new(inputs: Vec) -> Self { + Self { inputs } + } + + pub fn write_trace( + mut self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memory_address_to_id_state: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_big_state: &mut memory_id_to_big::ClaimGenerator, + verify_instruction_state: &mut verify_instruction::ClaimGenerator, + ) -> (Claim, InteractionClaimGenerator) { + let n_calls = self.inputs.len(); + assert_ne!(n_calls, 0); + let size = std::cmp::max(n_calls.next_power_of_two(), N_LANES); + let need_padding = n_calls != size; + + if need_padding { + self.inputs.resize(size, *self.inputs.first().unwrap()); + bit_reverse_coset_to_circle_domain_order(&mut self.inputs); + } + + let packed_inputs = pack_values(&self.inputs); + let (trace, mut sub_components_inputs, lookup_data) = write_trace_simd( + packed_inputs, + memory_address_to_id_state, + memory_id_to_big_state, + ); + + if need_padding { + sub_components_inputs.bit_reverse_coset_to_circle_domain_order(); + } + sub_components_inputs + .memory_address_to_id_inputs + .iter() + .for_each(|inputs| { + memory_address_to_id_state.add_inputs(&inputs[..n_calls]); + }); + sub_components_inputs + .memory_id_to_big_inputs + .iter() + .for_each(|inputs| { + memory_id_to_big_state.add_inputs(&inputs[..n_calls]); + }); + sub_components_inputs + .verify_instruction_inputs + .iter() + .for_each(|inputs| { + verify_instruction_state.add_inputs(&inputs[..n_calls]); + }); + + tree_builder.extend_evals( + trace + .into_iter() + .map(|eval| { + let domain = CanonicCoset::new( + eval.len() + .checked_ilog2() + .expect("Input is not a power of 2!"), + ) + .circle_domain(); + CircleEvaluation::::new(domain, eval) + }) + .collect_vec(), + ); + + ( + Claim { n_calls }, + InteractionClaimGenerator { + n_calls, + lookup_data, + }, + ) + } + + pub fn add_inputs(&mut self, inputs: &[InputType]) { + self.inputs.extend(inputs); + } +} + +pub struct SubComponentInputs { + pub memory_address_to_id_inputs: [Vec; 3], + pub memory_id_to_big_inputs: [Vec; 1], + pub verify_instruction_inputs: [Vec; 1], +} +impl SubComponentInputs { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memory_address_to_id_inputs: [ + Vec::with_capacity(capacity), + Vec::with_capacity(capacity), + Vec::with_capacity(capacity), + ], + memory_id_to_big_inputs: [Vec::with_capacity(capacity)], + verify_instruction_inputs: [Vec::with_capacity(capacity)], + } + } + + fn bit_reverse_coset_to_circle_domain_order(&mut self) { + self.memory_address_to_id_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + self.memory_id_to_big_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + self.verify_instruction_inputs + .iter_mut() + .for_each(|vec| bit_reverse_coset_to_circle_domain_order(vec)); + } +} + +#[allow(clippy::useless_conversion)] +#[allow(unused_variables)] +#[allow(clippy::double_parens)] +#[allow(non_snake_case)] +pub fn write_trace_simd( + inputs: Vec, + memory_address_to_id_state: &mut memory_address_to_id::ClaimGenerator, + memory_id_to_big_state: &mut memory_id_to_big::ClaimGenerator, +) -> ( + [BaseColumn; N_TRACE_COLUMNS], + SubComponentInputs, + LookupData, +) { + const N_TRACE_COLUMNS: usize = 14; + let mut trace: [_; N_TRACE_COLUMNS] = + std::array::from_fn(|_| Col::::zeros(inputs.len() * N_LANES)); + + let mut lookup_data = LookupData::with_capacity(inputs.len()); + #[allow(unused_mut)] + let mut sub_components_inputs = SubComponentInputs::with_capacity(inputs.len()); + + let M31_0 = PackedM31::broadcast(M31::from(0)); + let M31_1 = PackedM31::broadcast(M31::from(1)); + let M31_262144 = PackedM31::broadcast(M31::from(262144)); + let M31_32768 = PackedM31::broadcast(M31::from(32768)); + let M31_512 = PackedM31::broadcast(M31::from(512)); + let UInt16_0 = PackedUInt16::broadcast(UInt16::from(0)); + let UInt16_1 = PackedUInt16::broadcast(UInt16::from(1)); + let UInt16_11 = PackedUInt16::broadcast(UInt16::from(11)); + let UInt16_127 = PackedUInt16::broadcast(UInt16::from(127)); + let UInt16_13 = PackedUInt16::broadcast(UInt16::from(13)); + let UInt16_2 = PackedUInt16::broadcast(UInt16::from(2)); + let UInt16_3 = PackedUInt16::broadcast(UInt16::from(3)); + let UInt16_31 = PackedUInt16::broadcast(UInt16::from(31)); + let UInt16_4 = PackedUInt16::broadcast(UInt16::from(4)); + let UInt16_5 = PackedUInt16::broadcast(UInt16::from(5)); + let UInt16_6 = PackedUInt16::broadcast(UInt16::from(6)); + let UInt16_7 = PackedUInt16::broadcast(UInt16::from(7)); + let UInt16_9 = PackedUInt16::broadcast(UInt16::from(9)); + + inputs.into_iter().enumerate().for_each( + |(row_index, assert_eq_opcode_is_double_deref_t_is_imm_f_input)| { + let input_tmp_1131 = assert_eq_opcode_is_double_deref_t_is_imm_f_input; + let input_pc_col0 = input_tmp_1131.pc; + trace[0].data[row_index] = input_pc_col0; + let input_ap_col1 = input_tmp_1131.ap; + trace[1].data[row_index] = input_ap_col1; + let input_fp_col2 = input_tmp_1131.fp; + trace[2].data[row_index] = input_fp_col2; + + // decode_instruction_a2af169c0fec5c47. + + let memory_address_to_id_value_tmp_1141 = + memory_address_to_id_state.deduce_output(input_pc_col0); + let memory_id_to_big_value_tmp_1142 = + memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_1141); + let offset0_tmp_1143 = + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(0))) + + (((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(1))) + & (UInt16_127)) + << (UInt16_9))); + let offset0_col3 = offset0_tmp_1143.as_m31(); + trace[3].data[row_index] = offset0_col3; + let offset1_tmp_1144 = + ((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(1))) + >> (UInt16_7)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(2))) + << (UInt16_2))) + + (((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(3))) + & (UInt16_31)) + << (UInt16_11))); + let offset1_col4 = offset1_tmp_1144.as_m31(); + trace[4].data[row_index] = offset1_col4; + let offset2_tmp_1145 = + ((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(3))) + >> (UInt16_5)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(4))) + << (UInt16_4))) + + (((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(5))) + & (UInt16_7)) + << (UInt16_13))); + let offset2_col5 = offset2_tmp_1145.as_m31(); + trace[5].data[row_index] = offset2_col5; + let dst_base_fp_tmp_1146 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(6))) + << (UInt16_6))) + >> (UInt16_0)) + & (UInt16_1)); + let dst_base_fp_col6 = dst_base_fp_tmp_1146.as_m31(); + trace[6].data[row_index] = dst_base_fp_col6; + let op0_base_fp_tmp_1147 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(6))) + << (UInt16_6))) + >> (UInt16_1)) + & (UInt16_1)); + let op0_base_fp_col7 = op0_base_fp_tmp_1147.as_m31(); + trace[7].data[row_index] = op0_base_fp_col7; + let ap_update_add_1_tmp_1148 = + (((((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(5))) + >> (UInt16_3)) + + ((PackedUInt16::from_m31(memory_id_to_big_value_tmp_1142.get_m31(6))) + << (UInt16_6))) + >> (UInt16_11)) + & (UInt16_1)); + let ap_update_add_1_col8 = ap_update_add_1_tmp_1148.as_m31(); + trace[8].data[row_index] = ap_update_add_1_col8; + + sub_components_inputs.verify_instruction_inputs[0].extend( + ( + input_pc_col0, + [offset0_col3, offset1_col4, offset2_col5], + [ + dst_base_fp_col6, + op0_base_fp_col7, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ap_update_add_1_col8, + M31_0, + M31_0, + M31_1, + ], + ) + .unpack(), + ); + + lookup_data.verifyinstruction[0].push([ + input_pc_col0, + offset0_col3, + offset1_col4, + offset2_col5, + dst_base_fp_col6, + op0_base_fp_col7, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ap_update_add_1_col8, + M31_0, + M31_0, + M31_1, + ]); + + // read_positive_num_bits_27. + + let memory_address_to_id_value_tmp_1152 = memory_address_to_id_state.deduce_output( + ((((op0_base_fp_col7) * (input_fp_col2)) + + (((M31_1) - (op0_base_fp_col7)) * (input_ap_col1))) + + ((offset1_col4) - (M31_32768))), + ); + let memory_id_to_big_value_tmp_1153 = + memory_id_to_big_state.deduce_output(memory_address_to_id_value_tmp_1152); + let mem1_base_id_col9 = memory_address_to_id_value_tmp_1152; + trace[9].data[row_index] = mem1_base_id_col9; + sub_components_inputs.memory_address_to_id_inputs[0].extend( + ((((op0_base_fp_col7) * (input_fp_col2)) + + (((M31_1) - (op0_base_fp_col7)) * (input_ap_col1))) + + ((offset1_col4) - (M31_32768))) + .unpack(), + ); + + lookup_data.memoryaddresstoid[0].push([ + ((((op0_base_fp_col7) * (input_fp_col2)) + + (((M31_1) - (op0_base_fp_col7)) * (input_ap_col1))) + + ((offset1_col4) - (M31_32768))), + mem1_base_id_col9, + ]); + let mem1_base_limb_0_col10 = memory_id_to_big_value_tmp_1153.get_m31(0); + trace[10].data[row_index] = mem1_base_limb_0_col10; + let mem1_base_limb_1_col11 = memory_id_to_big_value_tmp_1153.get_m31(1); + trace[11].data[row_index] = mem1_base_limb_1_col11; + let mem1_base_limb_2_col12 = memory_id_to_big_value_tmp_1153.get_m31(2); + trace[12].data[row_index] = mem1_base_limb_2_col12; + sub_components_inputs.memory_id_to_big_inputs[0].extend(mem1_base_id_col9.unpack()); + + lookup_data.memoryidtobig[0].push([ + mem1_base_id_col9, + mem1_base_limb_0_col10, + mem1_base_limb_1_col11, + mem1_base_limb_2_col12, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + M31_0, + ]); + + // mem_verify_equal. + + let memory_address_to_id_value_tmp_1154 = memory_address_to_id_state.deduce_output( + ((((dst_base_fp_col6) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col6)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))), + ); + let dst_id_col13 = memory_address_to_id_value_tmp_1154; + trace[13].data[row_index] = dst_id_col13; + sub_components_inputs.memory_address_to_id_inputs[1].extend( + ((((dst_base_fp_col6) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col6)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))) + .unpack(), + ); + + lookup_data.memoryaddresstoid[1].push([ + ((((dst_base_fp_col6) * (input_fp_col2)) + + (((M31_1) - (dst_base_fp_col6)) * (input_ap_col1))) + + ((offset0_col3) - (M31_32768))), + dst_id_col13, + ]); + sub_components_inputs.memory_address_to_id_inputs[2].extend( + ((((mem1_base_limb_0_col10) + ((mem1_base_limb_1_col11) * (M31_512))) + + ((mem1_base_limb_2_col12) * (M31_262144))) + + ((offset2_col5) - (M31_32768))) + .unpack(), + ); + + lookup_data.memoryaddresstoid[2].push([ + ((((mem1_base_limb_0_col10) + ((mem1_base_limb_1_col11) * (M31_512))) + + ((mem1_base_limb_2_col12) * (M31_262144))) + + ((offset2_col5) - (M31_32768))), + dst_id_col13, + ]); + + lookup_data.opcodes[0].push([input_pc_col0, input_ap_col1, input_fp_col2]); + lookup_data.opcodes[1].push([ + ((input_pc_col0) + (M31_1)), + ((input_ap_col1) + (ap_update_add_1_col8)), + input_fp_col2, + ]); + }, + ); + + (trace, sub_components_inputs, lookup_data) +} + +pub struct LookupData { + pub memoryaddresstoid: [Vec<[PackedM31; 2]>; 3], + pub memoryidtobig: [Vec<[PackedM31; 29]>; 1], + pub opcodes: [Vec<[PackedM31; 3]>; 2], + pub verifyinstruction: [Vec<[PackedM31; 19]>; 1], +} +impl LookupData { + #[allow(unused_variables)] + fn with_capacity(capacity: usize) -> Self { + Self { + memoryaddresstoid: [ + Vec::with_capacity(capacity), + Vec::with_capacity(capacity), + Vec::with_capacity(capacity), + ], + memoryidtobig: [Vec::with_capacity(capacity)], + opcodes: [Vec::with_capacity(capacity), Vec::with_capacity(capacity)], + verifyinstruction: [Vec::with_capacity(capacity)], + } + } +} + +pub struct InteractionClaimGenerator { + pub n_calls: usize, + pub lookup_data: LookupData, +} +impl InteractionClaimGenerator { + pub fn write_interaction_trace( + self, + tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memoryaddresstoid_lookup_elements: &relations::MemoryAddressToId, + memoryidtobig_lookup_elements: &relations::MemoryIdToBig, + opcodes_lookup_elements: &relations::Opcodes, + verifyinstruction_lookup_elements: &relations::VerifyInstruction, + ) -> InteractionClaim { + let log_size = std::cmp::max(self.n_calls.next_power_of_two().ilog2(), LOG_N_LANES); + let mut logup_gen = LogupTraceGenerator::new(log_size); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.verifyinstruction[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = verifyinstruction_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryidtobig[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryidtobig_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[1]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.memoryaddresstoid[2]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = memoryaddresstoid_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.opcodes[0]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = opcodes_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + let lookup_row = &self.lookup_data.opcodes[1]; + for (i, lookup_values) in lookup_row.iter().enumerate() { + let denom = opcodes_lookup_elements.combine(lookup_values); + col_gen.write_frac(i, -PackedQM31::one(), denom); + } + col_gen.finalize_col(); + + let (trace, total_sum, claimed_sum) = if self.n_calls == 1 << log_size { + let (trace, claimed_sum) = logup_gen.finalize_last(); + (trace, claimed_sum, None) + } else { + let (trace, [total_sum, claimed_sum]) = + logup_gen.finalize_at([(1 << log_size) - 1, self.n_calls - 1]); + (trace, total_sum, Some((claimed_sum, self.n_calls - 1))) + }; + tree_builder.extend_evals(trace); + + InteractionClaim { + logup_sums: (total_sum, claimed_sum), + } + } +} diff --git a/stwo_cairo_prover/crates/prover/src/components/mod.rs b/stwo_cairo_prover/crates/prover/src/components/mod.rs index f419ec38..f90af834 100644 --- a/stwo_cairo_prover/crates/prover/src/components/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/components/mod.rs @@ -4,6 +4,9 @@ use stwo_prover::core::backend::simd::conversion::Pack; pub mod add_ap_opcode_is_imm_f_op1_base_fp_f; pub mod add_ap_opcode_is_imm_f_op1_base_fp_t; pub mod add_ap_opcode_is_imm_t_op1_base_fp_f; +pub mod assert_eq_opcode_is_double_deref_f_is_imm_f; +pub mod assert_eq_opcode_is_double_deref_f_is_imm_t; +pub mod assert_eq_opcode_is_double_deref_t_is_imm_f; pub mod generic_opcode; pub mod jnz_opcode_is_taken_f_dst_base_fp_f; pub mod jnz_opcode_is_taken_f_dst_base_fp_t; diff --git a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs index af038d74..5b3e0c9b 100644 --- a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs +++ b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs @@ -444,7 +444,7 @@ impl StateTransitions { opcode_call: false, opcode_ret: false, opcode_assert_eq: true, - } if !dev_mode => { + } => { if op1_imm { // [ap/fp + offset0] = imm. assert!(