Skip to content

Commit

Permalink
fix bug in the adapter for small/big add and mul
Browse files Browse the repository at this point in the history
  • Loading branch information
Stavbe authored and ohad-starkware committed Dec 5, 2024
1 parent 91e9a7c commit 4a3b790
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 39 deletions.
74 changes: 45 additions & 29 deletions stwo_cairo_prover/crates/prover/src/input/state_transitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ use super::decode::Instruction;
use super::mem::{MemoryBuilder, MemoryValue};
use super::vm_import::TraceEntry;

// The range of small values is 27 bits.
const SMALL_MAX_VALUE: i32 = 2_i32.pow(27) - 1;
const SMALL_MIN_VALUE: i32 = -(2_i32.pow(27));

// TODO (Stav): Ensure it stays synced with that opcdode AIR's list.
/// This struct holds the components used to prove the opcodes in a Cairo program,
/// and should match the opcode's air used by `stwo-cairo-air`.
Expand Down Expand Up @@ -474,9 +478,9 @@ impl StateTransitions {
// mul.
Instruction {
offset0,
offset1: _,
offset1,
offset2,
dst_base_fp: _,
dst_base_fp,
op0_base_fp,
op1_imm,
op1_base_fp,
Expand All @@ -492,12 +496,20 @@ impl StateTransitions {
opcode_ret: false,
opcode_assert_eq: true,
} if !dev_mode => {
let op1_addr = if op0_base_fp { fp } else { ap };
let op1 = mem.get(op1_addr.0.checked_add_signed(offset0 as i32).unwrap());
let (dst_addr, op0_addr, op1_addr) = (
if dst_base_fp { fp } else { ap },
if op0_base_fp { fp } else { ap },
if op1_base_fp { fp } else { ap },
);
let (dst, op0, op1) = (
mem.get(dst_addr.0.checked_add_signed(offset0 as i32).unwrap()),
mem.get(op0_addr.0.checked_add_signed(offset1 as i32).unwrap()),
mem.get(op1_addr.0.checked_add_signed(offset2 as i32).unwrap()),
);
if op1_imm {
// [ap/fp + offset0] = [ap/fp + offset1] * Imm.
assert!(!op1_base_fp && !op1_base_ap && offset2 == 1);
if let MemoryValue::Small(_) = op1 {
if are_small_operands(dst, op0, op1) {
self.casm_states_by_opcode
.mul_opcode_is_small_t_is_imm_t
.push(state);
Expand All @@ -509,19 +521,13 @@ impl StateTransitions {
} else {
// [ap/fp + offset0] = [ap/fp + offset1] * [ap/fp + offset2].
assert!((op1_base_fp || op1_base_ap));
let op0_addr = if op0_base_fp { fp } else { ap };
let op0 = mem.get(op0_addr.0.checked_add_signed(offset0 as i32).unwrap());
if let MemoryValue::F252(_) = op1 {
self.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_f
.push(state);
} else if let MemoryValue::F252(_) = op0 {
if are_small_operands(dst, op0, op1) {
self.casm_states_by_opcode
.mul_opcode_is_small_f_is_imm_f
.mul_opcode_is_small_t_is_imm_f
.push(state);
} else {
self.casm_states_by_opcode
.mul_opcode_is_small_t_is_imm_f
.mul_opcode_is_small_f_is_imm_f
.push(state);
}
}
Expand All @@ -530,9 +536,9 @@ impl StateTransitions {
// add.
Instruction {
offset0,
offset1: _,
offset1,
offset2,
dst_base_fp: _,
dst_base_fp,
op0_base_fp,
op1_imm,
op1_base_fp,
Expand All @@ -548,14 +554,20 @@ impl StateTransitions {
opcode_ret: false,
opcode_assert_eq: true,
} if !dev_mode => {
let op1_addr = if op0_base_fp { fp } else { ap };
let op1 = mem.get(op1_addr.0.checked_add_signed(offset0 as i32).unwrap());
let (dst_addr, op0_addr, op1_addr) = (
if dst_base_fp { fp } else { ap },
if op0_base_fp { fp } else { ap },
if op1_base_fp { fp } else { ap },
);
let (dst, op0, op1) = (
mem.get(dst_addr.0.checked_add_signed(offset0 as i32).unwrap()),
mem.get(op0_addr.0.checked_add_signed(offset1 as i32).unwrap()),
mem.get(op1_addr.0.checked_add_signed(offset2 as i32).unwrap()),
);
if op1_imm {
// [ap/fp + offset0] = [ap/fp + offset1] + Imm.
assert!(!op1_base_fp && !op1_base_ap && offset2 == 1);
let op1_addr = if op0_base_fp { fp } else { ap };
let op1 = mem.get(op1_addr.0.checked_add_signed(offset0 as i32).unwrap());
if let MemoryValue::Small(_) = op1 {
if are_small_operands(dst, op0, op1) {
self.casm_states_by_opcode
.add_opcode_is_small_t_is_imm_t
.push(state);
Expand All @@ -567,19 +579,13 @@ impl StateTransitions {
} else {
// [ap/fp + offset0] = [ap/fp + offset1] + [ap/fp + offset2].
assert!((op1_base_fp || op1_base_ap));
let op0_addr = if op0_base_fp { fp } else { ap };
let op0 = mem.get(op0_addr.0.checked_add_signed(offset0 as i32).unwrap());
if let MemoryValue::F252(_) = op1 {
self.casm_states_by_opcode
.add_opcode_is_small_f_is_imm_f
.push(state);
} else if let MemoryValue::F252(_) = op0 {
if are_small_operands(dst, op0, op1) {
self.casm_states_by_opcode
.add_opcode_is_small_t_is_imm_f
.push(state);
} else {
self.casm_states_by_opcode
.add_opcode_is_small_t_is_imm_f
.add_opcode_is_small_f_is_imm_f
.push(state);
}
}
Expand All @@ -592,3 +598,13 @@ impl StateTransitions {
}
}
}

// Returns 'true' if all the operands are within the range of [-2^27, 2^27 - 1].
fn are_small_operands(dst: MemoryValue, op0: MemoryValue, op1: MemoryValue) -> bool {
is_small(dst) && is_small(op0) && is_small(op1)
}

// Returns 'true' if the memory value is within the range of [-2^27, 2^27 - 1].
fn is_small(val: MemoryValue) -> bool {
matches!(val, MemoryValue::Small(val) if (val as i128>= SMALL_MIN_VALUE as i128) && (val as i128 <= SMALL_MAX_VALUE as i128))
}
23 changes: 13 additions & 10 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ pub mod tests {
}

// TODO (Stav): Once all the components are in, verify the proof to ensure the sort was correct.
// TODO (Ohad): remove the following doc after deleting dev_mod.
/// When not ignored, the test passes only with dev_mod = false.
#[ignore]
#[test]
fn test_read_from_large_files() {
Expand All @@ -158,10 +160,10 @@ pub mod tests {
assert_eq!(components.add_ap_opcode_is_imm_f_op1_base_fp_f.len(), 0);
assert_eq!(components.add_ap_opcode_is_imm_t_op1_base_fp_f.len(), 36895);
assert_eq!(components.add_ap_opcode_is_imm_f_op1_base_fp_t.len(), 33);
assert_eq!(components.add_opcode_is_small_t_is_imm_t.len(), 94680);
assert_eq!(components.add_opcode_is_small_f_is_imm_f.len(), 181481);
assert_eq!(components.add_opcode_is_small_t_is_imm_f.len(), 44567);
assert_eq!(components.add_opcode_is_small_f_is_imm_t.len(), 12141);
assert_eq!(components.add_opcode_is_small_t_is_imm_t.len(), 83399);
assert_eq!(components.add_opcode_is_small_f_is_imm_f.len(), 189425);
assert_eq!(components.add_opcode_is_small_t_is_imm_f.len(), 36623);
assert_eq!(components.add_opcode_is_small_f_is_imm_t.len(), 23422);
assert_eq!(
components.assert_eq_opcode_is_double_deref_f_is_imm_f.len(),
233432
Expand Down Expand Up @@ -205,13 +207,14 @@ pub mod tests {
.len(),
0
);
assert_eq!(components.mul_opcode_is_small_t_is_imm_t.len(), 14653);
assert_eq!(components.mul_opcode_is_small_t_is_imm_f.len(), 8574);
assert_eq!(components.mul_opcode_is_small_f_is_imm_f.len(), 2572);
assert_eq!(components.mul_opcode_is_small_f_is_imm_t.len(), 3390);
assert_eq!(components.mul_opcode_is_small_t_is_imm_t.len(), 11955);
assert_eq!(components.mul_opcode_is_small_t_is_imm_f.len(), 6895);
assert_eq!(components.mul_opcode_is_small_f_is_imm_f.len(), 4251);
assert_eq!(components.mul_opcode_is_small_f_is_imm_t.len(), 6088);
assert_eq!(components.ret_opcode.len(), 49472);
}

// When not ignored, the test passes only with dev_mod = false.
#[ignore]
#[test]
fn test_read_from_small_files() {
Expand All @@ -221,10 +224,10 @@ pub mod tests {
assert_eq!(components.add_ap_opcode_is_imm_f_op1_base_fp_f.len(), 0);
assert_eq!(components.add_ap_opcode_is_imm_t_op1_base_fp_f.len(), 2);
assert_eq!(components.add_ap_opcode_is_imm_f_op1_base_fp_t.len(), 1);
assert_eq!(components.add_opcode_is_small_t_is_imm_t.len(), 750);
assert_eq!(components.add_opcode_is_small_t_is_imm_t.len(), 950);
assert_eq!(components.add_opcode_is_small_f_is_imm_f.len(), 0);
assert_eq!(components.add_opcode_is_small_t_is_imm_f.len(), 0);
assert_eq!(components.add_opcode_is_small_f_is_imm_t.len(), 200);
assert_eq!(components.add_opcode_is_small_f_is_imm_t.len(), 0);
assert_eq!(
components.assert_eq_opcode_is_double_deref_f_is_imm_f.len(),
55
Expand Down

0 comments on commit 4a3b790

Please sign in to comment.