Skip to content

Commit

Permalink
mul
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 2, 2024
1 parent 0577f30 commit cf8b202
Show file tree
Hide file tree
Showing 11 changed files with 8,051 additions and 2 deletions.
177 changes: 176 additions & 1 deletion stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::components::{
assert_eq_opcode_is_double_deref_f_is_imm_t, assert_eq_opcode_is_double_deref_t_is_imm_f,
generic_opcode, jnz_opcode_is_taken_f_dst_base_fp_f, jnz_opcode_is_taken_f_dst_base_fp_t,
jnz_opcode_is_taken_t_dst_base_fp_f, jnz_opcode_is_taken_t_dst_base_fp_t, memory_address_to_id,
memory_id_to_big, range_check_19, range_check_9_9, ret_opcode, verify_instruction,
memory_id_to_big, mul_opcode_is_small_f_is_imm_f, mul_opcode_is_small_f_is_imm_t,
range_check_19, range_check_9_9, ret_opcode, verify_instruction,
};
use crate::input::state_transitions::StateTransitions;

Expand All @@ -29,6 +30,8 @@ pub struct OpcodeClaim {
assert_eq_f_t: Vec<assert_eq_opcode_is_double_deref_f_is_imm_t::Claim>,
assert_eq_t_f: Vec<assert_eq_opcode_is_double_deref_t_is_imm_f::Claim>,
generic: Vec<generic_opcode::Claim>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::Claim>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::Claim>,
jnz_f_f: Vec<jnz_opcode_is_taken_f_dst_base_fp_f::Claim>,
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::Claim>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::Claim>,
Expand All @@ -44,6 +47,8 @@ impl OpcodeClaim {
self.assert_eq_f_t.iter().for_each(|c| c.mix_into(channel));
self.assert_eq_t_f.iter().for_each(|c| c.mix_into(channel));
self.generic.iter().for_each(|c| c.mix_into(channel));
self.mul_f_f.iter().for_each(|c| c.mix_into(channel));
self.mul_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_f_f.iter().for_each(|c| c.mix_into(channel));
self.jnz_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_f.iter().for_each(|c| c.mix_into(channel));
Expand All @@ -60,6 +65,8 @@ impl OpcodeClaim {
self.assert_eq_f_t.iter().map(|c| c.log_sizes()),
self.assert_eq_t_f.iter().map(|c| c.log_sizes()),
self.generic.iter().map(|c| c.log_sizes()),
self.mul_f_f.iter().map(|c| c.log_sizes()),
self.mul_f_t.iter().map(|c| c.log_sizes()),
self.jnz_f_f.iter().map(|c| c.log_sizes()),
self.jnz_f_t.iter().map(|c| c.log_sizes()),
self.jnz_t_f.iter().map(|c| c.log_sizes()),
Expand All @@ -76,6 +83,8 @@ pub struct OpcodesClaimGenerator {
assert_eq_f_f: Vec<assert_eq_opcode_is_double_deref_f_is_imm_f::ClaimGenerator>,
assert_eq_f_t: Vec<assert_eq_opcode_is_double_deref_f_is_imm_t::ClaimGenerator>,
assert_eq_t_f: Vec<assert_eq_opcode_is_double_deref_t_is_imm_f::ClaimGenerator>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::ClaimGenerator>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::ClaimGenerator>,
generic: Vec<generic_opcode::ClaimGenerator>,
jnz_f_f: Vec<jnz_opcode_is_taken_f_dst_base_fp_f::ClaimGenerator>,
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::ClaimGenerator>,
Expand All @@ -93,6 +102,8 @@ impl OpcodesClaimGenerator {
let mut assert_eq_f_t = vec![];
let mut assert_eq_t_f = vec![];
let mut generic = vec![];
let mut mul_f_f = vec![];
let mut mul_f_t = vec![];
let mut jnz_f_f = vec![];
let mut jnz_f_t = vec![];
let mut jnz_t_f = vec![];
Expand Down Expand Up @@ -219,6 +230,24 @@ impl OpcodesClaimGenerator {
input.casm_states_by_opcode.generic_opcode,
));
}
if !input
.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_f
.is_empty()
{
mul_f_f.push(mul_opcode_is_small_f_is_imm_f::ClaimGenerator::new(
input.casm_states_by_opcode.mul_opcode_is_small_f_is_imm_f,
));
}
if !input
.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_t
.is_empty()
{
mul_f_t.push(mul_opcode_is_small_f_is_imm_t::ClaimGenerator::new(
input.casm_states_by_opcode.mul_opcode_is_small_f_is_imm_t,
));
}
if !input.casm_states_by_opcode.ret_opcode.is_empty() {
ret.push(ret_opcode::ClaimGenerator::new(
input.casm_states_by_opcode.ret_opcode,
Expand All @@ -232,6 +261,8 @@ impl OpcodesClaimGenerator {
assert_eq_f_t,
assert_eq_t_f,
generic,
mul_f_f,
mul_f_t,
jnz_f_f,
jnz_f_t,
jnz_t_f,
Expand Down Expand Up @@ -335,6 +366,32 @@ impl OpcodesClaimGenerator {
)
})
.unzip();
let (mul_f_f_claims, mul_f_f_interaction_gens) = self
.mul_f_f
.into_iter()
.map(|gen| {
gen.write_trace(
tree_builder,
memory_address_to_id_trace_generator,
memory_id_to_value_trace_generator,
range_check_19_trace_generator,
verify_instruction_trace_generator,
)
})
.unzip();
let (mul_f_t_claims, mul_f_t_interaction_gens) = self
.mul_f_t
.into_iter()
.map(|gen| {
gen.write_trace(
tree_builder,
memory_address_to_id_trace_generator,
memory_id_to_value_trace_generator,
range_check_19_trace_generator,
verify_instruction_trace_generator,
)
})
.unzip();
let (jnz_f_f_claims, jnz_f_f_interaction_gens) = self
.jnz_f_f
.into_iter()
Expand Down Expand Up @@ -405,6 +462,8 @@ impl OpcodesClaimGenerator {
assert_eq_f_t: assert_eq_f_t_claims,
assert_eq_t_f: assert_eq_t_f_claims,
generic: generic_opcode_claims,
mul_f_f: mul_f_f_claims,
mul_f_t: mul_f_t_claims,
jnz_f_f: jnz_f_f_claims,
jnz_f_t: jnz_f_t_claims,
jnz_t_f: jnz_t_f_claims,
Expand All @@ -419,6 +478,8 @@ impl OpcodesClaimGenerator {
assert_eq_f_t: assert_eq_f_t_interaction_gens,
assert_eq_t_f: assert_eq_t_f_interaction_gens,
generic_opcode_interaction_gens,
mul_f_f: mul_f_f_interaction_gens,
mul_f_t: mul_f_t_interaction_gens,
jnz_f_f: jnz_f_f_interaction_gens,
jnz_f_t: jnz_f_t_interaction_gens,
jnz_t_f: jnz_t_f_interaction_gens,
Expand All @@ -438,6 +499,8 @@ pub struct OpcodeInteractionClaim {
assert_eq_f_t: Vec<assert_eq_opcode_is_double_deref_f_is_imm_t::InteractionClaim>,
assert_eq_t_f: Vec<assert_eq_opcode_is_double_deref_t_is_imm_f::InteractionClaim>,
generic: Vec<generic_opcode::InteractionClaim>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::InteractionClaim>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::InteractionClaim>,
jnz_f_f: Vec<jnz_opcode_is_taken_f_dst_base_fp_f::InteractionClaim>,
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::InteractionClaim>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::InteractionClaim>,
Expand All @@ -453,6 +516,8 @@ impl OpcodeInteractionClaim {
self.assert_eq_f_t.iter().for_each(|c| c.mix_into(channel));
self.assert_eq_t_f.iter().for_each(|c| c.mix_into(channel));
self.generic.iter().for_each(|c| c.mix_into(channel));
self.mul_f_f.iter().for_each(|c| c.mix_into(channel));
self.mul_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_f_f.iter().for_each(|c| c.mix_into(channel));
self.jnz_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_f.iter().for_each(|c| c.mix_into(channel));
Expand Down Expand Up @@ -511,6 +576,20 @@ impl OpcodeInteractionClaim {
None => total_sum,
};
}
for interaction_claim in &self.mul_f_f {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Some((claimed_sum, ..)) => claimed_sum,
None => total_sum,
};
}
for interaction_claim in &self.mul_f_t {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Some((claimed_sum, ..)) => claimed_sum,
None => total_sum,
};
}
for interaction_claim in &self.jnz_f_f {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Expand Down Expand Up @@ -558,6 +637,8 @@ pub struct OpcodesInteractionClaimGenerator {
assert_eq_f_t: Vec<assert_eq_opcode_is_double_deref_f_is_imm_t::InteractionClaimGenerator>,
assert_eq_t_f: Vec<assert_eq_opcode_is_double_deref_t_is_imm_f::InteractionClaimGenerator>,
generic_opcode_interaction_gens: Vec<generic_opcode::InteractionClaimGenerator>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::InteractionClaimGenerator>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::InteractionClaimGenerator>,
jnz_f_f: Vec<jnz_opcode_is_taken_f_dst_base_fp_f::InteractionClaimGenerator>,
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::InteractionClaimGenerator>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::InteractionClaimGenerator>,
Expand Down Expand Up @@ -661,6 +742,34 @@ impl OpcodesInteractionClaimGenerator {
)
})
.collect();
let mul_f_f_interaction_claims = self
.mul_f_f
.into_iter()
.map(|gen| {
gen.write_interaction_trace(
tree_builder,
&interaction_elements.memory_address_to_id,
&interaction_elements.memory_id_to_value,
&interaction_elements.opcodes,
&interaction_elements.range_check_19,
&interaction_elements.verify_instruction,
)
})
.collect();
let mul_f_t_interaction_claims = self
.mul_f_t
.into_iter()
.map(|gen| {
gen.write_interaction_trace(
tree_builder,
&interaction_elements.memory_address_to_id,
&interaction_elements.memory_id_to_value,
&interaction_elements.opcodes,
&interaction_elements.range_check_19,
&interaction_elements.verify_instruction,
)
})
.collect();
let jnz_f_f_interaction_claims = self
.jnz_f_f
.into_iter()
Expand Down Expand Up @@ -734,6 +843,8 @@ impl OpcodesInteractionClaimGenerator {
assert_eq_f_t: assert_eq_f_t_interaction_claims,
assert_eq_t_f: assert_eq_t_f_interaction_claims,
generic: generic_opcode_interaction_claims,
mul_f_f: mul_f_f_interaction_claims,
mul_f_t: mul_f_t_interaction_claims,
jnz_f_f: jnz_f_f_interaction_claims,
jnz_f_t: jnz_f_t_interaction_claims,
jnz_t_f: jnz_t_f_interaction_claims,
Expand All @@ -751,6 +862,8 @@ pub struct OpcodeComponents {
assert_eq_f_t: Vec<assert_eq_opcode_is_double_deref_f_is_imm_t::Component>,
assert_eq_t_f: Vec<assert_eq_opcode_is_double_deref_t_is_imm_f::Component>,
generic: Vec<generic_opcode::Component>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::Component>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::Component>,
jnz_f_f: Vec<jnz_opcode_is_taken_f_dst_base_fp_f::Component>,
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::Component>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::Component>,
Expand Down Expand Up @@ -930,6 +1043,56 @@ impl OpcodeComponents {
)
})
.collect_vec();
let mul_f_f_components = claim
.mul_f_f
.iter()
.zip(interaction_claim.mul_f_f.iter())
.map(|(&claim, &interaction_claim)| {
mul_opcode_is_small_f_is_imm_f::Component::new(
tree_span_provider,
mul_opcode_is_small_f_is_imm_f::Eval {
claim,
memoryaddresstoid_lookup_elements: interaction_elements
.memory_address_to_id
.clone(),
memoryidtobig_lookup_elements: interaction_elements
.memory_id_to_value
.clone(),
opcodes_lookup_elements: interaction_elements.opcodes.clone(),
rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(),
verifyinstruction_lookup_elements: interaction_elements
.verify_instruction
.clone(),
},
interaction_claim.logup_sums,
)
})
.collect_vec();
let mul_f_t_components = claim
.mul_f_t
.iter()
.zip(interaction_claim.mul_f_t.iter())
.map(|(&claim, &interaction_claim)| {
mul_opcode_is_small_f_is_imm_t::Component::new(
tree_span_provider,
mul_opcode_is_small_f_is_imm_t::Eval {
claim,
memoryaddresstoid_lookup_elements: interaction_elements
.memory_address_to_id
.clone(),
memoryidtobig_lookup_elements: interaction_elements
.memory_id_to_value
.clone(),
opcodes_lookup_elements: interaction_elements.opcodes.clone(),
rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(),
verifyinstruction_lookup_elements: interaction_elements
.verify_instruction
.clone(),
},
interaction_claim.logup_sums,
)
})
.collect_vec();
let jnz_f_f_components = claim
.jnz_f_f
.iter()
Expand Down Expand Up @@ -1058,6 +1221,8 @@ impl OpcodeComponents {
assert_eq_f_t: assert_eq_f_t_components,
assert_eq_t_f: assert_eq_t_f_components,
generic: generic_components,
mul_f_f: mul_f_f_components,
mul_f_t: mul_f_t_components,
jnz_f_f: jnz_f_f_components,
jnz_f_t: jnz_f_t_components,
jnz_t_f: jnz_t_f_components,
Expand Down Expand Up @@ -1103,6 +1268,16 @@ impl OpcodeComponents {
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.mul_f_f
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.mul_f_t
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.jnz_f_f
.iter()
Expand Down
9 changes: 8 additions & 1 deletion stwo_cairo_prover/crates/prover/src/components/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ pub mod range_check_vector;
pub mod ret_opcode;
pub mod verify_instruction;

// TODO(Ohad): mul small.
pub mod mul_opcode_is_small_f_is_imm_f;
pub mod mul_opcode_is_small_f_is_imm_t;

pub use memory::{memory_address_to_id, memory_id_to_big};
pub use range_check_vector::{range_check_19, range_check_4_3, range_check_7_2_5, range_check_9_9};
pub use range_check_vector::{
range_check_19, range_check_3, range_check_4_3, range_check_6, range_check_7_2_5,
range_check_9_9,
};

pub fn pack_values<T: Pack>(values: &[T]) -> Vec<T::SimdType> {
values
Expand Down
Loading

0 comments on commit cf8b202

Please sign in to comment.