Skip to content

Commit

Permalink
Fix segments order in cairo runner (kkrt-labs#893)
Browse files Browse the repository at this point in the history
I didn't realize first that calling a test entrypoint is like calling
any function in cairo, meaning that one has
- args up to `[fp - 3]`
- return `fp`
- return `pc`

and that it's precisely what's done in the runner initialization. By
just updating the order of the segment, the `output` segment is just
accessed as the last (and only currently as we yet don't parse input
args) argument.

<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/893)
<!-- Reviewable:end -->
  • Loading branch information
ClementWalter authored Jan 19, 2024
1 parent 954e009 commit d6c26cb
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 104 deletions.
49 changes: 21 additions & 28 deletions tests/fixtures/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from starkware.starknet.business_logic.state.state_api_objects import BlockInfo
from starkware.starknet.compiler.starknet_pass_manager import starknet_pass_manager
from starkware.starknet.core.os import os_utils
from starkware.starknet.definitions.general_config import StarknetGeneralConfig
from starkware.starknet.testing.starknet import Starknet

Expand Down Expand Up @@ -130,28 +131,24 @@ def starknet_snapshot(starknet):
starknet.state.state = initial_cache_state


@pytest.fixture(scope="session")
def cairo_compile(request):
def _factory(path) -> list:
module_reader = get_module_reader(cairo_path=["src"])
def cairo_compile(path):
module_reader = get_module_reader(cairo_path=["src"])

pass_manager = starknet_pass_manager(
prime=DEFAULT_PRIME,
read_module=module_reader.read,
disable_hint_validation=True,
)

return compile_cairo(
Path(path).read_text(),
pass_manager=pass_manager,
debug_info=request.config.getoption("profile_cairo"),
)
pass_manager = starknet_pass_manager(
prime=DEFAULT_PRIME,
read_module=module_reader.read,
disable_hint_validation=True,
)

return _factory
return compile_cairo(
Path(path).read_text(),
pass_manager=pass_manager,
debug_info=True,
)


@pytest.fixture(scope="module")
def cairo_run(request, cairo_compile) -> list:
def cairo_run(request) -> list:
"""
Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs.
Returns the output of the cairo program put in the output memory segment.
Expand All @@ -174,21 +171,17 @@ def _factory(entrypoint, **kwargs) -> list:
proof_mode=False,
allow_missing_builtins=False,
)

runner.initialize_segments()
stack = []
for builtin_name in runner.program.builtins:
builtin_runner = runner.builtin_runners.get(f"{builtin_name}_builtin")
if builtin_runner is None:
assert runner.allow_missing_builtins, "Missing builtin."
stack += [0]
else:
stack += builtin_runner.initial_stack()

# Prepare implicit arguments.
implicit_args = os_utils.prepare_os_implicit_args_for_version0_class(
runner=runner
)

output = runner.segments.add()
return_fp = runner.segments.add()
end = runner.segments.add()
output = runner.segments.add()
stack = stack + [return_fp, end, output]
stack = implicit_args + [output, return_fp, end]

runner.initialize_state(
entrypoint=program.identifiers.get_by_full_name(
Expand Down
17 changes: 9 additions & 8 deletions tests/src/kakarot/test_gas.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ from kakarot.gas import Gas
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc

func test__memory_cost{range_check_ptr}() {
func test__memory_cost{range_check_ptr}(output_ptr: felt*) {
tempvar words_len: felt;
%{ ids.words_len = program_input["words_len"]; %}
let cost = Gas.memory_cost(words_len);

%{ segments.write_arg(output, [ids.cost]); %}
assert [output_ptr] = cost;

return ();
}

func test__memory_expansion_cost{range_check_ptr}() {
func test__memory_expansion_cost{range_check_ptr}(output_ptr: felt*) {
tempvar words_len: felt;
tempvar max_offset: felt;
%{
Expand All @@ -22,11 +23,11 @@ func test__memory_expansion_cost{range_check_ptr}() {
%}
let cost = Gas.calculate_gas_extend_memory(words_len, max_offset);

%{ segments.write_arg(output, [ids.cost]); %}
assert [output_ptr] = cost;
return ();
}

func test__max_memory_expansion_cost{range_check_ptr}() {
func test__max_memory_expansion_cost{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
let fp_and_pc = get_fp_and_pc();
local __fp__: felt* = fp_and_pc.fp_val;
Expand All @@ -48,11 +49,11 @@ func test__max_memory_expansion_cost{range_check_ptr}() {
%}
let cost = Gas.max_memory_expansion_cost(words_len, &offset_1, &size_1, &offset_2, &size_2);

%{ segments.write_arg(output, [ids.cost]); %}
assert [output_ptr] = cost;
return ();
}

func test__compute_message_call_gas{range_check_ptr}() {
func test__compute_message_call_gas{range_check_ptr}(output_ptr: felt*) {
tempvar gas_param: Uint256;
tempvar gas_left: felt;
%{
Expand All @@ -62,6 +63,6 @@ func test__compute_message_call_gas{range_check_ptr}() {
%}
let gas = Gas.compute_message_call_gas(gas_param, gas_left);

%{ segments.write_arg(output, [ids.gas]); %}
assert [output_ptr] = gas;
return ();
}
21 changes: 8 additions & 13 deletions tests/src/utils/test_array.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from starkware.cairo.common.alloc import alloc

from utils.array import reverse, count_not_zero, slice, contains

func test__reverse() {
func test__reverse(output_ptr: felt*) {
alloc_locals;
tempvar arr_len: felt;
let (arr) = alloc();
Expand All @@ -15,14 +15,11 @@ func test__reverse() {
segments.write_arg(ids.arr, program_input["arr"])
%}

tempvar rev: felt*;
%{ ids.rev = output %}

reverse(rev, arr_len, arr);
reverse(output_ptr, arr_len, arr);
return ();
}

func test__count_not_zero() {
func test__count_not_zero(output_ptr: felt*) {
tempvar arr_len: felt;
let (arr) = alloc();
%{
Expand All @@ -31,11 +28,11 @@ func test__count_not_zero() {
%}

let count = count_not_zero(arr_len, arr);
%{ segments.write_arg(output, [ids.count]) %}
assert [output_ptr] = count;
return ();
}

func test__slice{range_check_ptr}() {
func test__slice{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar arr_len: felt;
let (arr) = alloc();
Expand All @@ -48,13 +45,11 @@ func test__slice{range_check_ptr}() {
ids.size = program_input["size"]
%}

tempvar sliced: felt*;
%{ ids.sliced = output %}
slice(sliced, arr_len, arr, offset, size);
slice(output_ptr, arr_len, arr, offset, size);
return ();
}

func test_contains{range_check_ptr}() {
func test_contains{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar arr_len: felt;
let (arr) = alloc();
Expand All @@ -66,6 +61,6 @@ func test_contains{range_check_ptr}() {
%}

let result = contains(arr_len, arr, value);
%{ segments.write_arg(output, [ids.result]) %}
assert [output_ptr] = result;
return ();
}
54 changes: 16 additions & 38 deletions tests/src/utils/test_bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -14,99 +14,79 @@ from utils.bytes import (
bytes_to_bytes8_little_endian,
)

func test__felt_to_ascii{range_check_ptr}() {
func test__felt_to_ascii{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar n: felt;
%{ ids.n = program_input["n"] %}

tempvar ascii: felt*;
%{ ids.ascii = output %}
felt_to_ascii(ascii, n);
felt_to_ascii(output_ptr, n);
return ();
}

func test__felt_to_bytes_little() {
func test__felt_to_bytes_little(output_ptr: felt*) {
alloc_locals;
tempvar n: felt;
%{ ids.n = program_input["n"] %}

tempvar bytes: felt*;
%{ ids.bytes = output %}

felt_to_bytes_little(bytes, n);
felt_to_bytes_little(output_ptr, n);
return ();
}

func test__felt_to_bytes() {
func test__felt_to_bytes(output_ptr: felt*) {
alloc_locals;
tempvar n: felt;
%{ ids.n = program_input["n"] %}

tempvar bytes: felt*;
%{ ids.bytes = output %}

felt_to_bytes(bytes, n);
felt_to_bytes(output_ptr, n);
return ();
}

func test__felt_to_bytes20{range_check_ptr}() {
func test__felt_to_bytes20{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar n: felt;
%{ ids.n = program_input["n"] %}

tempvar bytes20: felt*;
%{ ids.bytes20 = output %}

felt_to_bytes20(bytes20, n);
felt_to_bytes20(output_ptr, n);
return ();
}

func test__uint256_to_bytes_little{range_check_ptr}() {
func test__uint256_to_bytes_little{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar n: Uint256;
%{
ids.n.low = program_input["n"][0]
ids.n.high = program_input["n"][1]
%}

tempvar bytes: felt*;
%{ ids.bytes = output %}

uint256_to_bytes_little(bytes, n);
uint256_to_bytes_little(output_ptr, n);
return ();
}

func test__uint256_to_bytes{range_check_ptr}() {
func test__uint256_to_bytes{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar n: Uint256;
%{
ids.n.low = program_input["n"][0]
ids.n.high = program_input["n"][1]
%}

tempvar bytes: felt*;
%{ ids.bytes = output %}

uint256_to_bytes(bytes, n);
uint256_to_bytes(output_ptr, n);
return ();
}

func test__uint256_to_bytes32{range_check_ptr}() {
func test__uint256_to_bytes32{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
tempvar n: Uint256;
%{
ids.n.low = program_input["n"][0]
ids.n.high = program_input["n"][1]
%}

tempvar bytes: felt*;
%{ ids.bytes = output %}

uint256_to_bytes32(bytes, n);
uint256_to_bytes32(output_ptr, n);
return ();
}

func test__bytes_to_bytes8_little_endian() {
func test__bytes_to_bytes8_little_endian(output_ptr: felt*) {
alloc_locals;
tempvar bytes_len: felt;
let (bytes) = alloc();
Expand All @@ -115,9 +95,7 @@ func test__bytes_to_bytes8_little_endian() {
segments.write_arg(ids.bytes, program_input["bytes"])
%}

tempvar res: felt*;
%{ ids.res = output %}
bytes_to_bytes8_little_endian(res, bytes_len, bytes);
bytes_to_bytes8_little_endian(output_ptr, bytes_len, bytes);

return ();
}
24 changes: 11 additions & 13 deletions tests/src/utils/test_eth_transaction.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy

func test__decode{bitwise_ptr: BitwiseBuiltin*, range_check_ptr}() {
func test__decode{bitwise_ptr: BitwiseBuiltin*, range_check_ptr}(output_ptr: felt*) {
alloc_locals;
// Given
tempvar data_len: felt;
Expand All @@ -28,18 +28,16 @@ func test__decode{bitwise_ptr: BitwiseBuiltin*, range_check_ptr}() {
payload: felt*,
) = EthTransaction.decode(data_len, data);

tempvar output: felt*;
%{ ids.output = output %}
assert [output] = msg_hash.low;
assert [output + 1] = msg_hash.high;
assert [output + 2] = nonce;
assert [output + 3] = gas_price;
assert [output + 4] = gas_limit;
assert [output + 5] = destination;
assert [output + 6] = amount;
assert [output + 7] = chain_id;
assert [output + 8] = payload_len;
memcpy(output + 9, payload, payload_len);
assert [output_ptr] = msg_hash.low;
assert [output_ptr + 1] = msg_hash.high;
assert [output_ptr + 2] = nonce;
assert [output_ptr + 3] = gas_price;
assert [output_ptr + 4] = gas_limit;
assert [output_ptr + 5] = destination;
assert [output_ptr + 6] = amount;
assert [output_ptr + 7] = chain_id;
assert [output_ptr + 8] = payload_len;
memcpy(output_ptr + 9, payload, payload_len);

return ();
}
Expand Down
6 changes: 2 additions & 4 deletions tests/src/utils/test_rlp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy

func test__decode{range_check_ptr}() {
func test__decode{range_check_ptr}(output_ptr: felt*) {
alloc_locals;
// Given
tempvar data_len: felt;
Expand All @@ -25,9 +25,7 @@ func test__decode{range_check_ptr}() {
%{ ids.is_list = program_input["is_list"] %}
assert item.is_list = is_list;

tempvar output: felt*;
%{ ids.output = output %}
memcpy(output, item.data, item.data_len);
memcpy(output_ptr, item.data, item.data_len);

return ();
}

0 comments on commit d6c26cb

Please sign in to comment.