Skip to content

Commit

Permalink
mul
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 4, 2024
1 parent 91e9a7c commit ab85e9e
Show file tree
Hide file tree
Showing 12 changed files with 8,091 additions and 31 deletions.
180 changes: 178 additions & 2 deletions stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::components::{
assert_eq_opcode_is_double_deref_f_is_imm_t, assert_eq_opcode_is_double_deref_t_is_imm_f,
generic_opcode, jnz_opcode_is_taken_f_dst_base_fp_f, jnz_opcode_is_taken_f_dst_base_fp_t,
jnz_opcode_is_taken_t_dst_base_fp_f, jnz_opcode_is_taken_t_dst_base_fp_t, memory_address_to_id,
memory_id_to_big, range_check_19, range_check_9_9, ret_opcode, verify_instruction,
memory_id_to_big, mul_opcode_is_small_f_is_imm_f, mul_opcode_is_small_f_is_imm_t,
range_check_19, range_check_9_9, ret_opcode, verify_instruction,
};
use crate::input::state_transitions::StateTransitions;

Expand All @@ -33,6 +34,8 @@ pub struct OpcodeClaim {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::Claim>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::Claim>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::Claim>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::Claim>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::Claim>,
ret: Vec<ret_opcode::Claim>,
}
impl OpcodeClaim {
Expand All @@ -48,6 +51,8 @@ impl OpcodeClaim {
self.jnz_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_f.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_t.iter().for_each(|c| c.mix_into(channel));
self.mul_f_f.iter().for_each(|c| c.mix_into(channel));
self.mul_f_t.iter().for_each(|c| c.mix_into(channel));
self.ret.iter().for_each(|c| c.mix_into(channel));
}

Expand All @@ -64,6 +69,8 @@ impl OpcodeClaim {
self.jnz_f_t.iter().map(|c| c.log_sizes()),
self.jnz_t_f.iter().map(|c| c.log_sizes()),
self.jnz_t_t.iter().map(|c| c.log_sizes()),
self.mul_f_f.iter().map(|c| c.log_sizes()),
self.mul_f_t.iter().map(|c| c.log_sizes()),
self.ret.iter().map(|c| c.log_sizes()),
))
}
Expand All @@ -81,6 +88,8 @@ pub struct OpcodesClaimGenerator {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::ClaimGenerator>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::ClaimGenerator>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::ClaimGenerator>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::ClaimGenerator>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::ClaimGenerator>,
ret: Vec<ret_opcode::ClaimGenerator>,
}
impl OpcodesClaimGenerator {
Expand All @@ -97,6 +106,8 @@ impl OpcodesClaimGenerator {
let mut jnz_f_t = vec![];
let mut jnz_t_f = vec![];
let mut jnz_t_t = vec![];
let mut mul_f_f = vec![];
let mut mul_f_t = vec![];
let mut ret = vec![];
if !input
.casm_states_by_opcode
Expand Down Expand Up @@ -219,6 +230,26 @@ impl OpcodesClaimGenerator {
.jnz_opcode_is_taken_t_dst_base_fp_t,
));
}
// Handle small mul in big mul component. Temporary until airs are written with Rc_3_6_6.
// TODO(Ohad): mul small.
if !input
.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_f
.is_empty()
{
mul_f_f.push(mul_opcode_is_small_f_is_imm_f::ClaimGenerator::new(
input.casm_states_by_opcode.mul_opcode_is_small_f_is_imm_f,
));
}
if !input
.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_t
.is_empty()
{
mul_f_t.push(mul_opcode_is_small_f_is_imm_t::ClaimGenerator::new(
input.casm_states_by_opcode.mul_opcode_is_small_f_is_imm_t,
));
}
if !input.casm_states_by_opcode.ret_opcode.is_empty() {
ret.push(ret_opcode::ClaimGenerator::new(
input.casm_states_by_opcode.ret_opcode,
Expand All @@ -236,6 +267,8 @@ impl OpcodesClaimGenerator {
jnz_f_t,
jnz_t_f,
jnz_t_t,
mul_f_f,
mul_f_t,
ret,
}
}
Expand Down Expand Up @@ -383,7 +416,32 @@ impl OpcodesClaimGenerator {
)
})
.unzip();

let (mul_f_f_claims, mul_f_f_interaction_gens) = self
.mul_f_f
.into_iter()
.map(|gen| {
gen.write_trace(
tree_builder,
memory_address_to_id_trace_generator,
memory_id_to_value_trace_generator,
range_check_19_trace_generator,
verify_instruction_trace_generator,
)
})
.unzip();
let (mul_f_t_claims, mul_f_t_interaction_gens) = self
.mul_f_t
.into_iter()
.map(|gen| {
gen.write_trace(
tree_builder,
memory_address_to_id_trace_generator,
memory_id_to_value_trace_generator,
range_check_19_trace_generator,
verify_instruction_trace_generator,
)
})
.unzip();
let (ret_claims, ret_interaction_gens) = self
.ret
.into_iter()
Expand All @@ -409,6 +467,8 @@ impl OpcodesClaimGenerator {
jnz_f_t: jnz_f_t_claims,
jnz_t_f: jnz_t_f_claims,
jnz_t_t: jnz_t_t_claims,
mul_f_f: mul_f_f_claims,
mul_f_t: mul_f_t_claims,
ret: ret_claims,
},
OpcodesInteractionClaimGenerator {
Expand All @@ -423,6 +483,8 @@ impl OpcodesClaimGenerator {
jnz_f_t: jnz_f_t_interaction_gens,
jnz_t_f: jnz_t_f_interaction_gens,
jnz_t_t: jnz_t_t_interaction_gens,
mul_f_f: mul_f_f_interaction_gens,
mul_f_t: mul_f_t_interaction_gens,
ret_interaction_gens,
},
)
Expand All @@ -442,6 +504,8 @@ pub struct OpcodeInteractionClaim {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::InteractionClaim>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::InteractionClaim>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::InteractionClaim>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::InteractionClaim>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::InteractionClaim>,
ret: Vec<ret_opcode::InteractionClaim>,
}
impl OpcodeInteractionClaim {
Expand All @@ -457,6 +521,8 @@ impl OpcodeInteractionClaim {
self.jnz_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_f.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_t.iter().for_each(|c| c.mix_into(channel));
self.mul_f_f.iter().for_each(|c| c.mix_into(channel));
self.mul_f_t.iter().for_each(|c| c.mix_into(channel));
self.ret.iter().for_each(|c| c.mix_into(channel));
}

Expand Down Expand Up @@ -539,6 +605,20 @@ impl OpcodeInteractionClaim {
None => total_sum,
};
}
for interaction_claim in &self.mul_f_f {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Some((claimed_sum, ..)) => claimed_sum,
None => total_sum,
};
}
for interaction_claim in &self.mul_f_t {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Some((claimed_sum, ..)) => claimed_sum,
None => total_sum,
};
}
for interaction_claim in &self.ret {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Expand All @@ -562,6 +642,8 @@ pub struct OpcodesInteractionClaimGenerator {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::InteractionClaimGenerator>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::InteractionClaimGenerator>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::InteractionClaimGenerator>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::InteractionClaimGenerator>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::InteractionClaimGenerator>,
ret_interaction_gens: Vec<ret_opcode::InteractionClaimGenerator>,
}
impl OpcodesInteractionClaimGenerator {
Expand Down Expand Up @@ -713,6 +795,34 @@ impl OpcodesInteractionClaimGenerator {
)
})
.collect();
let mul_f_f_interaction_claims = self
.mul_f_f
.into_iter()
.map(|gen| {
gen.write_interaction_trace(
tree_builder,
&interaction_elements.memory_address_to_id,
&interaction_elements.memory_id_to_value,
&interaction_elements.opcodes,
&interaction_elements.range_check_19,
&interaction_elements.verify_instruction,
)
})
.collect();
let mul_f_t_interaction_claims = self
.mul_f_t
.into_iter()
.map(|gen| {
gen.write_interaction_trace(
tree_builder,
&interaction_elements.memory_address_to_id,
&interaction_elements.memory_id_to_value,
&interaction_elements.opcodes,
&interaction_elements.range_check_19,
&interaction_elements.verify_instruction,
)
})
.collect();
let ret_interaction_claims = self
.ret_interaction_gens
.into_iter()
Expand All @@ -738,6 +848,8 @@ impl OpcodesInteractionClaimGenerator {
jnz_f_t: jnz_f_t_interaction_claims,
jnz_t_f: jnz_t_f_interaction_claims,
jnz_t_t: jnz_t_t_interaction_claims,
mul_f_f: mul_f_f_interaction_claims,
mul_f_t: mul_f_t_interaction_claims,
ret: ret_interaction_claims,
}
}
Expand All @@ -755,6 +867,8 @@ pub struct OpcodeComponents {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::Component>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::Component>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::Component>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::Component>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::Component>,
ret: Vec<ret_opcode::Component>,
}
impl OpcodeComponents {
Expand Down Expand Up @@ -1026,6 +1140,56 @@ impl OpcodeComponents {
)
})
.collect_vec();
let mul_f_f_components = claim
.mul_f_f
.iter()
.zip(interaction_claim.mul_f_f.iter())
.map(|(&claim, &interaction_claim)| {
mul_opcode_is_small_f_is_imm_f::Component::new(
tree_span_provider,
mul_opcode_is_small_f_is_imm_f::Eval {
claim,
memoryaddresstoid_lookup_elements: interaction_elements
.memory_address_to_id
.clone(),
memoryidtobig_lookup_elements: interaction_elements
.memory_id_to_value
.clone(),
opcodes_lookup_elements: interaction_elements.opcodes.clone(),
rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(),
verifyinstruction_lookup_elements: interaction_elements
.verify_instruction
.clone(),
},
interaction_claim.logup_sums,
)
})
.collect_vec();
let mul_f_t_components = claim
.mul_f_t
.iter()
.zip(interaction_claim.mul_f_t.iter())
.map(|(&claim, &interaction_claim)| {
mul_opcode_is_small_f_is_imm_t::Component::new(
tree_span_provider,
mul_opcode_is_small_f_is_imm_t::Eval {
claim,
memoryaddresstoid_lookup_elements: interaction_elements
.memory_address_to_id
.clone(),
memoryidtobig_lookup_elements: interaction_elements
.memory_id_to_value
.clone(),
opcodes_lookup_elements: interaction_elements.opcodes.clone(),
rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(),
verifyinstruction_lookup_elements: interaction_elements
.verify_instruction
.clone(),
},
interaction_claim.logup_sums,
)
})
.collect_vec();
let ret_components = claim
.ret
.iter()
Expand Down Expand Up @@ -1062,6 +1226,8 @@ impl OpcodeComponents {
jnz_f_t: jnz_f_t_components,
jnz_t_f: jnz_t_f_components,
jnz_t_t: jnz_t_t_components,
mul_f_f: mul_f_f_components,
mul_f_t: mul_f_t_components,
ret: ret_components,
}
}
Expand Down Expand Up @@ -1123,6 +1289,16 @@ impl OpcodeComponents {
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.mul_f_f
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.mul_f_t
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.ret
.iter()
Expand Down
9 changes: 8 additions & 1 deletion stwo_cairo_prover/crates/prover/src/components/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ pub mod range_check_vector;
pub mod ret_opcode;
pub mod verify_instruction;

// TODO(Ohad): mul small.
pub mod mul_opcode_is_small_f_is_imm_f;
pub mod mul_opcode_is_small_f_is_imm_t;

pub use memory::{memory_address_to_id, memory_id_to_big};
pub use range_check_vector::{range_check_19, range_check_4_3, range_check_7_2_5, range_check_9_9};
pub use range_check_vector::{
range_check_19, range_check_3, range_check_4_3, range_check_6, range_check_7_2_5,
range_check_9_9,
};

// When padding is needed, the inputs must be arranged in the order defined by the neighbor
// function. This order allows using the partial sum mechanism to sum only the first n_call inputs.
Expand Down
Loading

0 comments on commit ab85e9e

Please sign in to comment.