Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mul #219

Merged
merged 1 commit into from
Dec 5, 2024
Merged

mul #219

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 178 additions & 2 deletions stwo_cairo_prover/crates/prover/src/cairo_air/opcodes_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::components::{
assert_eq_opcode_is_double_deref_f_is_imm_t, assert_eq_opcode_is_double_deref_t_is_imm_f,
generic_opcode, jnz_opcode_is_taken_f_dst_base_fp_f, jnz_opcode_is_taken_f_dst_base_fp_t,
jnz_opcode_is_taken_t_dst_base_fp_f, jnz_opcode_is_taken_t_dst_base_fp_t, memory_address_to_id,
memory_id_to_big, range_check_19, range_check_9_9, ret_opcode, verify_instruction,
memory_id_to_big, mul_opcode_is_small_f_is_imm_f, mul_opcode_is_small_f_is_imm_t,
range_check_19, range_check_9_9, ret_opcode, verify_instruction,
};
use crate::input::state_transitions::StateTransitions;

Expand All @@ -33,6 +34,8 @@ pub struct OpcodeClaim {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::Claim>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::Claim>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::Claim>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::Claim>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::Claim>,
ret: Vec<ret_opcode::Claim>,
}
impl OpcodeClaim {
Expand All @@ -48,6 +51,8 @@ impl OpcodeClaim {
self.jnz_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_f.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_t.iter().for_each(|c| c.mix_into(channel));
self.mul_f_f.iter().for_each(|c| c.mix_into(channel));
self.mul_f_t.iter().for_each(|c| c.mix_into(channel));
self.ret.iter().for_each(|c| c.mix_into(channel));
}

Expand All @@ -64,6 +69,8 @@ impl OpcodeClaim {
self.jnz_f_t.iter().map(|c| c.log_sizes()),
self.jnz_t_f.iter().map(|c| c.log_sizes()),
self.jnz_t_t.iter().map(|c| c.log_sizes()),
self.mul_f_f.iter().map(|c| c.log_sizes()),
self.mul_f_t.iter().map(|c| c.log_sizes()),
self.ret.iter().map(|c| c.log_sizes()),
))
}
Expand All @@ -81,6 +88,8 @@ pub struct OpcodesClaimGenerator {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::ClaimGenerator>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::ClaimGenerator>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::ClaimGenerator>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::ClaimGenerator>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::ClaimGenerator>,
ret: Vec<ret_opcode::ClaimGenerator>,
}
impl OpcodesClaimGenerator {
Expand All @@ -97,6 +106,8 @@ impl OpcodesClaimGenerator {
let mut jnz_f_t = vec![];
let mut jnz_t_f = vec![];
let mut jnz_t_t = vec![];
let mut mul_f_f = vec![];
let mut mul_f_t = vec![];
let mut ret = vec![];
if !input
.casm_states_by_opcode
Expand Down Expand Up @@ -219,6 +230,26 @@ impl OpcodesClaimGenerator {
.jnz_opcode_is_taken_t_dst_base_fp_t,
));
}
// Handle small mul in big mul component. Temporary until airs are written with Rc_3_6_6.
// TODO(Ohad): mul small.
if !input
.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_f
.is_empty()
{
mul_f_f.push(mul_opcode_is_small_f_is_imm_f::ClaimGenerator::new(
input.casm_states_by_opcode.mul_opcode_is_small_f_is_imm_f,
));
}
if !input
.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_t
.is_empty()
{
mul_f_t.push(mul_opcode_is_small_f_is_imm_t::ClaimGenerator::new(
input.casm_states_by_opcode.mul_opcode_is_small_f_is_imm_t,
));
}
if !input.casm_states_by_opcode.ret_opcode.is_empty() {
ret.push(ret_opcode::ClaimGenerator::new(
input.casm_states_by_opcode.ret_opcode,
Expand All @@ -236,6 +267,8 @@ impl OpcodesClaimGenerator {
jnz_f_t,
jnz_t_f,
jnz_t_t,
mul_f_f,
mul_f_t,
ret,
}
}
Expand Down Expand Up @@ -383,7 +416,32 @@ impl OpcodesClaimGenerator {
)
})
.unzip();

let (mul_f_f_claims, mul_f_f_interaction_gens) = self
.mul_f_f
.into_iter()
.map(|gen| {
gen.write_trace(
tree_builder,
memory_address_to_id_trace_generator,
memory_id_to_value_trace_generator,
range_check_19_trace_generator,
verify_instruction_trace_generator,
)
})
.unzip();
let (mul_f_t_claims, mul_f_t_interaction_gens) = self
.mul_f_t
.into_iter()
.map(|gen| {
gen.write_trace(
tree_builder,
memory_address_to_id_trace_generator,
memory_id_to_value_trace_generator,
range_check_19_trace_generator,
verify_instruction_trace_generator,
)
})
.unzip();
let (ret_claims, ret_interaction_gens) = self
.ret
.into_iter()
Expand All @@ -409,6 +467,8 @@ impl OpcodesClaimGenerator {
jnz_f_t: jnz_f_t_claims,
jnz_t_f: jnz_t_f_claims,
jnz_t_t: jnz_t_t_claims,
mul_f_f: mul_f_f_claims,
mul_f_t: mul_f_t_claims,
ret: ret_claims,
},
OpcodesInteractionClaimGenerator {
Expand All @@ -423,6 +483,8 @@ impl OpcodesClaimGenerator {
jnz_f_t: jnz_f_t_interaction_gens,
jnz_t_f: jnz_t_f_interaction_gens,
jnz_t_t: jnz_t_t_interaction_gens,
mul_f_f: mul_f_f_interaction_gens,
mul_f_t: mul_f_t_interaction_gens,
ret_interaction_gens,
},
)
Expand All @@ -442,6 +504,8 @@ pub struct OpcodeInteractionClaim {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::InteractionClaim>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::InteractionClaim>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::InteractionClaim>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::InteractionClaim>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::InteractionClaim>,
ret: Vec<ret_opcode::InteractionClaim>,
}
impl OpcodeInteractionClaim {
Expand All @@ -457,6 +521,8 @@ impl OpcodeInteractionClaim {
self.jnz_f_t.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_f.iter().for_each(|c| c.mix_into(channel));
self.jnz_t_t.iter().for_each(|c| c.mix_into(channel));
self.mul_f_f.iter().for_each(|c| c.mix_into(channel));
self.mul_f_t.iter().for_each(|c| c.mix_into(channel));
self.ret.iter().for_each(|c| c.mix_into(channel));
}

Expand Down Expand Up @@ -539,6 +605,20 @@ impl OpcodeInteractionClaim {
None => total_sum,
};
}
for interaction_claim in &self.mul_f_f {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Some((claimed_sum, ..)) => claimed_sum,
None => total_sum,
};
}
for interaction_claim in &self.mul_f_t {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Some((claimed_sum, ..)) => claimed_sum,
None => total_sum,
};
}
for interaction_claim in &self.ret {
let (total_sum, claimed_sum) = interaction_claim.logup_sums;
sum += match claimed_sum {
Expand All @@ -562,6 +642,8 @@ pub struct OpcodesInteractionClaimGenerator {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::InteractionClaimGenerator>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::InteractionClaimGenerator>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::InteractionClaimGenerator>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::InteractionClaimGenerator>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::InteractionClaimGenerator>,
ret_interaction_gens: Vec<ret_opcode::InteractionClaimGenerator>,
}
impl OpcodesInteractionClaimGenerator {
Expand Down Expand Up @@ -713,6 +795,34 @@ impl OpcodesInteractionClaimGenerator {
)
})
.collect();
let mul_f_f_interaction_claims = self
.mul_f_f
.into_iter()
.map(|gen| {
gen.write_interaction_trace(
tree_builder,
&interaction_elements.memory_address_to_id,
&interaction_elements.memory_id_to_value,
&interaction_elements.opcodes,
&interaction_elements.range_check_19,
&interaction_elements.verify_instruction,
)
})
.collect();
let mul_f_t_interaction_claims = self
.mul_f_t
.into_iter()
.map(|gen| {
gen.write_interaction_trace(
tree_builder,
&interaction_elements.memory_address_to_id,
&interaction_elements.memory_id_to_value,
&interaction_elements.opcodes,
&interaction_elements.range_check_19,
&interaction_elements.verify_instruction,
)
})
.collect();
let ret_interaction_claims = self
.ret_interaction_gens
.into_iter()
Expand All @@ -738,6 +848,8 @@ impl OpcodesInteractionClaimGenerator {
jnz_f_t: jnz_f_t_interaction_claims,
jnz_t_f: jnz_t_f_interaction_claims,
jnz_t_t: jnz_t_t_interaction_claims,
mul_f_f: mul_f_f_interaction_claims,
mul_f_t: mul_f_t_interaction_claims,
ret: ret_interaction_claims,
}
}
Expand All @@ -755,6 +867,8 @@ pub struct OpcodeComponents {
jnz_f_t: Vec<jnz_opcode_is_taken_f_dst_base_fp_t::Component>,
jnz_t_f: Vec<jnz_opcode_is_taken_t_dst_base_fp_f::Component>,
jnz_t_t: Vec<jnz_opcode_is_taken_t_dst_base_fp_t::Component>,
mul_f_f: Vec<mul_opcode_is_small_f_is_imm_f::Component>,
mul_f_t: Vec<mul_opcode_is_small_f_is_imm_t::Component>,
ret: Vec<ret_opcode::Component>,
}
impl OpcodeComponents {
Expand Down Expand Up @@ -1026,6 +1140,56 @@ impl OpcodeComponents {
)
})
.collect_vec();
let mul_f_f_components = claim
.mul_f_f
.iter()
.zip(interaction_claim.mul_f_f.iter())
.map(|(&claim, &interaction_claim)| {
mul_opcode_is_small_f_is_imm_f::Component::new(
tree_span_provider,
mul_opcode_is_small_f_is_imm_f::Eval {
claim,
memoryaddresstoid_lookup_elements: interaction_elements
.memory_address_to_id
.clone(),
memoryidtobig_lookup_elements: interaction_elements
.memory_id_to_value
.clone(),
opcodes_lookup_elements: interaction_elements.opcodes.clone(),
rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(),
verifyinstruction_lookup_elements: interaction_elements
.verify_instruction
.clone(),
},
interaction_claim.logup_sums,
)
})
.collect_vec();
let mul_f_t_components = claim
.mul_f_t
.iter()
.zip(interaction_claim.mul_f_t.iter())
.map(|(&claim, &interaction_claim)| {
mul_opcode_is_small_f_is_imm_t::Component::new(
tree_span_provider,
mul_opcode_is_small_f_is_imm_t::Eval {
claim,
memoryaddresstoid_lookup_elements: interaction_elements
.memory_address_to_id
.clone(),
memoryidtobig_lookup_elements: interaction_elements
.memory_id_to_value
.clone(),
opcodes_lookup_elements: interaction_elements.opcodes.clone(),
rangecheck_19_lookup_elements: interaction_elements.range_check_19.clone(),
verifyinstruction_lookup_elements: interaction_elements
.verify_instruction
.clone(),
},
interaction_claim.logup_sums,
)
})
.collect_vec();
let ret_components = claim
.ret
.iter()
Expand Down Expand Up @@ -1062,6 +1226,8 @@ impl OpcodeComponents {
jnz_f_t: jnz_f_t_components,
jnz_t_f: jnz_t_f_components,
jnz_t_t: jnz_t_t_components,
mul_f_f: mul_f_f_components,
mul_f_t: mul_f_t_components,
ret: ret_components,
}
}
Expand Down Expand Up @@ -1123,6 +1289,16 @@ impl OpcodeComponents {
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.mul_f_f
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.mul_f_t
.iter()
.map(|component| component as &dyn ComponentProver<SimdBackend>),
);
vec.extend(
self.ret
.iter()
Expand Down
4 changes: 4 additions & 0 deletions stwo_cairo_prover/crates/prover/src/components/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ pub mod range_check_vector;
pub mod ret_opcode;
pub mod verify_instruction;

// TODO(Ohad): mul small.
pub mod mul_opcode_is_small_f_is_imm_f;
pub mod mul_opcode_is_small_f_is_imm_t;

pub use memory::{memory_address_to_id, memory_id_to_big};
pub use range_check_vector::{range_check_19, range_check_4_3, range_check_7_2_5, range_check_9_9};

Expand Down
Loading
Loading