diff --git a/examples/simple/test/SimpleState.t.sol b/examples/simple/test/SimpleState.t.sol index d4da4b99..1fba072e 100644 --- a/examples/simple/test/SimpleState.t.sol +++ b/examples/simple/test/SimpleState.t.sol @@ -33,15 +33,55 @@ contract SimpleStateTest is SymTest, Test { target = new SimpleState(); } - function check_buggy() public { + function check_buggy_excluding_view() public { bool success; // note: a total of 253 feasible paths are generated, of which only 10 unique states exist for (uint i = 0; i < 10; i++) { - (success,) = address(target).call(svm.createCalldata("SimpleState")); + (success,) = address(target).call(svm.createCalldata("SimpleState")); // excluding view functions vm.assume(success); } assertFalse(target.buggy()); } + + function check_buggy_with_storage_snapshot() public { + bool success; + + // take the initial storage snapshot + uint prev = svm.snapshotStorage(address(target)); + + // note: a total of 253 feasible paths are generated, of which only 10 unique states exist + for (uint i = 0; i < 10; i++) { + (success,) = address(target).call(svm.createCalldata("SimpleState", true)); // including view functions + vm.assume(success); + + // ignore if no storage changes + uint curr = svm.snapshotStorage(address(target)); + vm.assume(curr != prev); + prev = curr; + } + + assertFalse(target.buggy()); + } + + function check_buggy_with_state_snapshot() public { + bool success; + + // take the initial state snapshot + uint prev = vm.snapshotState(); + + // note: a total of 253 feasible paths are generated, of which only 10 unique states exist + for (uint i = 0; i < 10; i++) { + (success,) = address(target).call(svm.createCalldata("SimpleState", true)); // including view functions + vm.assume(success); + + // ignore if no state changes + uint curr = vm.snapshotState(); + vm.assume(curr != prev); + prev = curr; + } + + assertFalse(target.buggy()); + } } diff --git a/pyproject.toml b/pyproject.toml index 114462c7..796b5820 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "toml>=0.10.2", "z3-solver==4.12.6.0", "eth_hash[pysha3]>=0.7.0", - "rich>=13.9.4" + "rich>=13.9.4", + "xxhash>=3.5.0" ] dynamic = ["version"] diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index 8edaa7be..505b5a24 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from subprocess import PIPE, Popen +from xxhash import xxh3_64, xxh3_64_digest, xxh3_128 from z3 import ( ULT, And, @@ -230,6 +231,49 @@ def symbolic_storage(ex, arg, sevm, stack, step_id): return ByteVec() # empty return data +def snapshot_storage(ex, arg, sevm, stack, step_id): + account = uint160(arg.get_word(4)) + account_alias = sevm.resolve_address_alias( + ex, account, stack, step_id, allow_branching=False + ) + + if account_alias is None: + error_msg = f"snapshotStorage() is not allowed for a nonexistent account: {hexify(account)}" + raise HalmosException(error_msg) + + zero_pad = b"\x00" * 16 + return ByteVec(zero_pad + ex.storage[account_alias].digest()) + + +def snapshot_state(ex, arg, sevm, stack, step_id): + """ + Generates a snapshot ID by hashing the current state (balance, code, and storage). + + The snapshot ID is constructed by concatenating three hashes of: balance (64 bits), code (64 bits), and storage (128 bits). + This design ensures that the lower 128 bits of both storage and state snapshot IDs correspond to storage hashes. + """ + # balance + balance_hash = xxh3_64_digest(int.to_bytes(ex.balance.get_id(), length=32)) + + # code + m = xxh3_64() + # note: iteration order is guaranteed to be the insertion order + for addr, code in ex.code.items(): + m.update(int.to_bytes(int_of(addr), length=32)) + # simply the object address is used, as code remains unchanged after deployment + m.update(int.to_bytes(id(code), length=32)) + code_hash = m.digest() + + # storage + m = xxh3_128() + for addr, storage in ex.storage.items(): + m.update(int.to_bytes(int_of(addr), length=32)) + m.update(storage.digest()) + storage_hash = m.digest() + + return ByteVec(balance_hash + code_hash + storage_hash) + + def create_calldata_contract(ex, arg, sevm, stack, step_id): contract_name = name_of(extract_string_argument(arg, 0)) return create_calldata_generic(ex, sevm, contract_name) @@ -447,6 +491,8 @@ class halmos_cheat_code: 0x3B0FA01B: create_address, # createAddress(string) 0x6E0BB659: create_bool, # createBool(string) 0xDC00BA4D: symbolic_storage, # enableSymbolicStorage(address) + 0x5DBB8438: snapshot_storage, # snapshotStorage(address) + 0x9CD23835: snapshot_state, # snapshotState() 0xBE92D5A2: create_calldata_contract, # createCalldata(string) 0xDEEF391B: create_calldata_contract_bool, # createCalldata(string,bool) 0x88298B32: create_calldata_file_contract, # createCalldata(string,string) @@ -550,6 +596,9 @@ class hevm_cheat_code: # bytes4(keccak256("getBlockNumber()")) get_block_number_sig: int = 0x42CBB15C + # snapshotState() + snapshot_state_sig: int = 0x9CD23835 + @staticmethod def handle(sevm, ex, arg: ByteVec, stack, step_id) -> ByteVec | None: funsig: int = int_of(arg[:4].unwrap(), "symbolic hevm cheatcode") @@ -852,6 +901,10 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> ByteVec | None: ret.append(uint256(ex.block.number)) return ret + # vm.snapshotState() return (uint256) + elif funsig == hevm_cheat_code.snapshot_state_sig: + return snapshot_state(ex, arg, sevm, stack, step_id) + else: # TODO: support other cheat codes msg = f"Unsupported cheat code: calldata = {hexify(arg)}" diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 82609bfe..e20479ed 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -15,6 +15,7 @@ TypeVar, ) +import xxhash from eth_hash.auto import keccak from rich.status import Status from z3 import ( @@ -829,10 +830,39 @@ def extend_path(self, path): self.extend(path.conditions.keys()) -@dataclass class StorageData: - symbolic: bool = False - mapping: dict = field(default_factory=dict) + def __init__(self): + self.symbolic = False + self._mapping = {} + + def __getitem__(self, key) -> ArrayRef | BitVecRef: + return self._mapping[key] + + def __setitem__(self, key, value) -> None: + self._mapping[key] = value + + def __contains__(self, key) -> bool: + return key in self._mapping + + def digest(self) -> bytes: + """ + Computes the xxh3_128 hash of the storage mapping. + + The hash input is constructed by serializing each key-value pair into a byte sequence. + Keys are encoded as 256-bit integers for GenericStorage, or as arrays of 256-bit integers for SolidityStorage. + Values, being Z3 objects, are encoded using their unique identifiers (get_id()) as 256-bit integers. + For simplicity, all numbers are represented as 256-bit integers, regardless of their actual size. + """ + m = xxhash.xxh3_128() + for key, val in self._mapping.items(): + if isinstance(key, int): # GenericStorage + m.update(int.to_bytes(key, length=32)) + else: # SolidityStorage + for _k in key: + # The first key (slot) is of size 256 bits + m.update(int.to_bytes(_k, length=32)) + m.update(int.to_bytes(val.get_id(), length=32)) + return m.digest() class Exec: # an execution path @@ -1303,7 +1333,7 @@ class Storage: class SolidityStorage(Storage): @classmethod def mk_storagedata(cls) -> StorageData: - return StorageData(mapping=defaultdict(lambda: defaultdict(dict))) + return StorageData() @classmethod def empty(cls, addr: BitVecRef, slot: int, keys: tuple) -> ArrayRef: @@ -1327,26 +1357,25 @@ def init( """ assert_address(addr) - storage_data = ex.storage[addr] - mapping = storage_data.mapping[slot][num_keys] + storage = ex.storage[addr] - if size_keys in mapping: + if (slot, num_keys, size_keys) in storage: return if size_keys > 0: # do not use z3 const array `K(BitVecSort(size_keys), ZERO)` when not ex.symbolic # instead use normal smt array, and generate emptyness axiom; see load() - mapping[size_keys] = cls.empty(addr, slot, keys) + storage[slot, num_keys, size_keys] = cls.empty(addr, slot, keys) return # size_keys == 0 - mapping[size_keys] = ( + storage[slot, num_keys, size_keys] = ( BitVec( # note: uuid is excluded to be deterministic f"storage_{id_str(addr)}_{slot}_{num_keys}_{size_keys}_00", BitVecSort256, ) - if storage_data.symbolic + if storage.symbolic else ZERO ) @@ -1356,13 +1385,13 @@ def load(cls, ex: Exec, addr: Any, loc: Word) -> Word: cls.init(ex, addr, slot, keys, num_keys, size_keys) - storage_data = ex.storage[addr] - mapping = storage_data.mapping[slot][num_keys] + storage = ex.storage[addr] + storage_chunk = storage[slot, num_keys, size_keys] if num_keys == 0: - return mapping[size_keys] + return storage_chunk - symbolic = storage_data.symbolic + symbolic = storage.symbolic concat_keys = concat(keys) if not symbolic: @@ -1370,7 +1399,7 @@ def load(cls, ex: Exec, addr: Any, loc: Word) -> Word: default_value = Select(cls.empty(addr, slot, keys), concat_keys) ex.path.append(default_value == ZERO) - return ex.select(mapping[size_keys], concat_keys, ex.storages, symbolic) + return ex.select(storage_chunk, concat_keys, ex.storages, symbolic) @classmethod def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: @@ -1378,11 +1407,10 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: cls.init(ex, addr, slot, keys, num_keys, size_keys) - storage_data = ex.storage[addr] - mapping = storage_data.mapping[slot][num_keys] + storage = ex.storage[addr] if num_keys == 0: - mapping[size_keys] = val + storage[slot, num_keys, size_keys] = val return new_storage_var = Array( @@ -1390,10 +1418,10 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: BitVecSorts[size_keys], BitVecSort256, ) - new_storage = Store(mapping[size_keys], concat(keys), val) + new_storage = Store(storage[slot, num_keys, size_keys], concat(keys), val) ex.path.append(new_storage_var == new_storage) - mapping[size_keys] = new_storage_var + storage[slot, num_keys, size_keys] = new_storage_var ex.storages[new_storage_var] = new_storage @classmethod @@ -1497,10 +1525,10 @@ def init(cls, ex: Exec, addr: Any, loc: BitVecRef, size_keys: int) -> None: """ assert_address(addr) - mapping = ex.storage[addr].mapping + storage = ex.storage[addr] - if size_keys not in mapping: - mapping[size_keys] = cls.empty(addr, loc) + if size_keys not in storage: + storage[size_keys] = cls.empty(addr, loc) @classmethod def load(cls, ex: Exec, addr: Any, loc: Word) -> Word: @@ -1509,16 +1537,15 @@ def load(cls, ex: Exec, addr: Any, loc: Word) -> Word: cls.init(ex, addr, loc, size_keys) - storage_data = ex.storage[addr] - mapping = storage_data.mapping - symbolic = storage_data.symbolic + storage = ex.storage[addr] + symbolic = storage.symbolic if not symbolic: # generate emptyness axiom for each array index, instead of using quantified formula; see init() default_value = Select(cls.empty(addr, loc), loc) ex.path.append(default_value == ZERO) - return ex.select(mapping[size_keys], loc, ex.storages, symbolic) + return ex.select(storage[size_keys], loc, ex.storages, symbolic) @classmethod def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: @@ -1527,17 +1554,17 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: cls.init(ex, addr, loc, size_keys) - mapping = ex.storage[addr].mapping + storage = ex.storage[addr] new_storage_var = Array( f"storage_{id_str(addr)}_{size_keys}_{uid()}_{1+len(ex.storages):>02}", BitVecSorts[size_keys], BitVecSort256, ) - new_storage = Store(mapping[size_keys], loc, val) + new_storage = Store(storage[size_keys], loc, val) ex.path.append(new_storage_var == new_storage) - mapping[size_keys] = new_storage_var + storage[size_keys] = new_storage_var ex.storages[new_storage_var] = new_storage @classmethod diff --git a/tests/expected/all.json b/tests/expected/all.json index 33261dd5..e84858f3 100644 --- a/tests/expected/all.json +++ b/tests/expected/all.json @@ -2709,6 +2709,53 @@ "num_bounded_loops": null } ], + "test/Snapshot.t.sol:SnapshotTest": [ + { + "name": "check_balance_snapshot()", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_new_account_snapshot()", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_snapshot()", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_this_balance_snapshot()", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_this_storage_snapshot()", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + } + ], "test/Solver.t.sol:SolverTest": [ { "name": "check_dynamic_array_overflow()", diff --git a/tests/expected/simple.json b/tests/expected/simple.json index 9285225f..d88116fb 100644 --- a/tests/expected/simple.json +++ b/tests/expected/simple.json @@ -227,7 +227,25 @@ ], "test/SimpleState.t.sol:SimpleStateTest": [ { - "name": "check_buggy()", + "name": "check_buggy_excluding_view()", + "exitcode": 1, + "num_models": 1, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_buggy_with_state_snapshot()", + "exitcode": 1, + "num_models": 1, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_buggy_with_storage_snapshot()", "exitcode": 1, "num_models": 1, "models": null, diff --git a/tests/lib/forge-std b/tests/lib/forge-std index 978ac6fa..1eea5bae 160000 --- a/tests/lib/forge-std +++ b/tests/lib/forge-std @@ -1 +1 @@ -Subproject commit 978ac6fadb62f5f0b723c996f64be52eddba6801 +Subproject commit 1eea5bae12ae557d589f9f0f0edae2faa47cb262 diff --git a/tests/lib/halmos-cheatcodes b/tests/lib/halmos-cheatcodes index a02072cd..7328abe1 160000 --- a/tests/lib/halmos-cheatcodes +++ b/tests/lib/halmos-cheatcodes @@ -1 +1 @@ -Subproject commit a02072cd5eb8560d00c3f4a73b27831ec6e3137e +Subproject commit 7328abe100445fc53885c21d0e713b95293cf14c diff --git a/tests/regression/test/Snapshot.t.sol b/tests/regression/test/Snapshot.t.sol new file mode 100644 index 00000000..a203f244 --- /dev/null +++ b/tests/regression/test/Snapshot.t.sol @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: AGPL-3.0 +pragma solidity >=0.8.0 <0.9.0; + +import "forge-std/Test.sol"; +import {SymTest} from "halmos-cheatcodes/SymTest.sol"; + +contract C { + uint num; + + receive() external payable {} + + function set(uint val) public { + num = val; + } +} + +/// @custom:halmos --storage-layout solidity +contract SnapshotTest is SymTest, Test { + C c; + + function setUp() public { + c = new C(); + } + + // NOTE: In halmos, the state snapshot ID is constructed by concatenating three hashes of: balance (64 bits), code (64 bits), and storage (128 bits). + + function check_snapshot() public { + uint storage0 = svm.snapshotStorage(address(c)); + uint state0 = vm.snapshotState(); + console.log(storage0); + console.log(state0); + + c.set(0); + + uint storage1 = svm.snapshotStorage(address(c)); + uint state1 = vm.snapshotState(); + console.log(storage1); + console.log(state1); + + // NOTE: two storages are semantically equal, but not structually equal + // assertEq(storage0, storage1); + // assertEq(state0, state1); + assertEq(bytes16(bytes32(state0)), bytes16(bytes32(state1))); // no changes to balance & code + + c.set(0); + + uint storage2 = svm.snapshotStorage(address(c)); + uint state2 = vm.snapshotState(); + console.log(storage2); + console.log(state2); + + // NOTE: failed with the generic storage layout, as the whole storage is an smt array + assertEq(storage1, storage2); + assertEq(state1, state2); + + c.set(1); + + uint storage3 = svm.snapshotStorage(address(c)); + uint state3 = vm.snapshotState(); + console.log(storage3); + console.log(state3); + + assertNotEq(storage2, storage3); + assertNotEq(state2, state3); + assertNotEq(uint128(state2), uint128(state3)); // storage + assertEq(bytes16(bytes32(state2)), bytes16(bytes32(state3))); // no changes to balance & code + + c.set(0); + + uint storage4 = svm.snapshotStorage(address(c)); + uint state4 = vm.snapshotState(); + console.log(storage4); + console.log(state4); + + // NOTE: failed with the generic storage layout, as the whole storage is an smt array + assertEq(storage2, storage4); + assertEq(state2, state4); + } + + function check_this_balance_snapshot() public { + vm.deal(address(this), 10); + + uint state0 = vm.snapshotState(); + console.log(state0); + + payable(c).transfer(1); + + uint state1 = vm.snapshotState(); + console.log(state1); + + assertNotEq(state0, state1); + assertNotEq(bytes8(bytes32(state0)), bytes8(bytes32(state1))); // balance + assertEq(uint192(state0), uint192(state1)); // no changes to code & storage + + payable(c).transfer(0); + + uint state2 = vm.snapshotState(); + console.log(state2); + + assertEq(state1, state2); + } + + function check_this_storage_snapshot() public { + uint state0 = vm.snapshotState(); + uint storage0 = svm.snapshotStorage(address(this)); + uint storage0_c = svm.snapshotStorage(address(c)); + console.log(state0); + console.log(storage0); + console.log(storage0_c); + + address old_c = address(c); + + c = C(payable(0)); + + uint state1 = vm.snapshotState(); + uint storage1 = svm.snapshotStorage(address(this)); + uint storage1_c = svm.snapshotStorage(old_c); + console.log(state1); + console.log(storage1); + console.log(storage1_c); + + assertNotEq(state0, state1); + assertNotEq(uint128(state0), uint128(state1)); // storage + assertEq(bytes16(bytes32(state0)), bytes16(bytes32(state1))); // no changes to balance & code + + assertNotEq(storage0, storage1); // global variable updated + assertEq(storage0_c, storage1_c); // existing account preserved + } + + function check_new_account_snapshot() public { + uint state0 = vm.snapshotState(); + console.log(state0); + + /* C tmp = */ new C(); + + uint state1 = vm.snapshotState(); + console.log(state1); + + assertNotEq(state0, state1); // new account in state1 + assertNotEq(uint192(state0), uint192(state1)); // code & storage + assertEq(bytes8(bytes32(state0)), bytes8(bytes32(state1))); // no changes to balance + } + + function check_balance_snapshot() public { + vm.deal(address(c), 10); + + uint state0 = vm.snapshotState(); + console.log(state0); + + vm.deal(address(c), 10); + + uint state1 = vm.snapshotState(); + console.log(state1); + + // NOTE: symbolic balance mappings are not structurally equal + // assertEq(state0, state1); + assertEq(uint192(state0), uint192(state1)); // no changes to code & storage + } +}