diff --git a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs index d38fc50c..431aca48 100644 --- a/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs +++ b/stwo_cairo_prover/crates/prover/src/input/state_transitions.rs @@ -637,6 +637,11 @@ fn is_small_mul(op0: MemoryValue, op1: MemoryValue) -> bool { }) } + +// Ensures that instructions are correctly mapped with the adapter. +// All components were checked except: +// - `jmp rel [ap/fp + offset]` +// - `jmp abs [[ap/fp + offset1] + offset2]` #[cfg(test)] mod tests { use cairo_lang_casm::casm; @@ -675,6 +680,24 @@ mod tests { ); } + #[test] + fn test_jmp_rel() { + let instructions = casm! { + jmp rel 2; + [ap] = [ap-1] + 3, ap++; + } + .instructions; + + let input = input_from_plain_casm(instructions, false); + let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode; + assert_eq!( + casm_states_by_opcode + .jump_opcode_is_rel_t_is_imm_t_is_double_deref_f + .len(), + 1 + ); + } + #[test] fn test_add_ap() { let instructions = casm! { @@ -682,6 +705,7 @@ mod tests { [ap] = 12, ap++; ap += [ap -2]; ap += [fp + 1]; + ap += 1; [ap] = 1, ap++; } .instructions; @@ -696,7 +720,13 @@ mod tests { ); assert_eq!( casm_states_by_opcode - .add_ap_opcode_is_imm_f_op1_base_fp_f + .add_ap_opcode_is_imm_f_op1_base_fp_t + .len(), + 1 + ); + assert_eq!( + casm_states_by_opcode + .add_ap_opcode_is_imm_t_op1_base_fp_f .len(), 1 ); @@ -747,7 +777,7 @@ mod tests { } #[test] - fn test_jnz_taken() { + fn test_jnz_not_taken_ap() { let instructions = casm! { [ap] = 0, ap++; jmp rel 2 if [ap-1] != 0; @@ -766,7 +796,27 @@ mod tests { } #[test] - fn test_jnz_not_taken() { + fn test_jnz_not_taken_fp() { + let instructions = casm! { + call rel 2; + [ap] = 0, ap++; + jmp rel 2 if [fp] != 0; + [ap] = 1, ap++; + } + .instructions; + + let input = input_from_plain_casm(instructions, false); + let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode; + assert_eq!( + casm_states_by_opcode + .jnz_opcode_is_taken_f_dst_base_fp_t + .len(), + 1 + ); + } + + #[test] + fn test_jnz_taken_fp() { let instructions = casm! { call rel 2; jmp rel 2 if [fp-1] != 0; @@ -784,6 +834,25 @@ mod tests { ); } + #[test] + fn test_jnz_taken_ap() { + let instructions = casm! { + [ap] = 5, ap++; + jmp rel 2 if [ap-1] != 0; + [ap] = 1, ap++; + } + .instructions; + + let input = input_from_plain_casm(instructions, false); + let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode; + assert_eq!( + casm_states_by_opcode + .jnz_opcode_is_taken_t_dst_base_fp_f + .len(), + 1 + ); + } + #[test] fn test_assert_equal() { let instructions = casm! { @@ -824,6 +893,12 @@ mod tests { casm_states_by_opcode.add_opcode_is_small_t_is_imm_f.len(), 1 ); + assert_eq!( + casm_states_by_opcode + .assert_eq_opcode_is_double_deref_f_is_imm_t + .len(), + 2 + ); assert_eq!( casm_states_by_opcode.add_opcode_is_small_t_is_imm_t.len(), 1 @@ -905,4 +980,60 @@ mod tests { 1 ); } + + #[test] + fn test_generic() { + let instructions = casm! { + [ap]=1, ap++; + [ap]=2, ap++; + jmp rel [ap-2] if [ap-1] != 0; + [ap]=1, ap++; + } + .instructions; + + let input = input_from_plain_casm(instructions, false); + let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode; + assert_eq!(casm_states_by_opcode.generic_opcode.len(), 1); + } + + #[test] + fn test_ret() { + let instructions = casm! { + [ap] = 10, ap++; + call rel 4; + jmp rel 11; + + jmp rel 4 if [fp-3] != 0; + jmp rel 6; + [ap] = [fp-3] + (-1), ap++; + call rel (-6); + ret; + } + .instructions; + + let input = input_from_plain_casm(instructions, false); + let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode; + assert_eq!(casm_states_by_opcode.ret_opcode.len(), 11); + } + + #[test] + fn test_assert_eq_double_deref() { + let instructions = casm! { + call rel 2; + [ap] = 100, ap++; + [ap] = [[fp - 2] + 2], ap++; // [fp - 2] is the old fp. + [ap] = 5; + } + .instructions; + + let input = input_from_plain_casm(instructions, false); + let casm_states_by_opcode = input.state_transitions.casm_states_by_opcode; + println!("{:?}", casm_states_by_opcode.counts()); + assert_eq!( + casm_states_by_opcode + .assert_eq_opcode_is_double_deref_t_is_imm_f + .len(), + 1 + ); + } }