Skip to content

Commit

Permalink
feat: support for snapshot cheatcodes (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Dec 17, 2024
1 parent f4f3185 commit 0ef9341
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 36 deletions.
44 changes: 42 additions & 2 deletions examples/simple/test/SimpleState.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,55 @@ contract SimpleStateTest is SymTest, Test {
target = new SimpleState();
}

function check_buggy() public {
function check_buggy_excluding_view() public {
bool success;

// note: a total of 253 feasible paths are generated, of which only 10 unique states exist
for (uint i = 0; i < 10; i++) {
(success,) = address(target).call(svm.createCalldata("SimpleState"));
(success,) = address(target).call(svm.createCalldata("SimpleState")); // excluding view functions
vm.assume(success);
}

assertFalse(target.buggy());
}

/// Makes up to 10 symbolic calls to `target`, keeping only the paths in which
/// every call succeeds and changes the target's storage snapshot, then
/// asserts that `buggy()` is false.
function check_buggy_with_storage_snapshot() public {
    // storage snapshot of the target before any call is made
    uint lastSnapshot = svm.snapshotStorage(address(target));

    // note: a total of 253 feasible paths are generated, of which only 10 unique states exist
    for (uint step = 0; step < 10; step++) {
        // the `true` flag makes the generated calldata include view functions
        (bool ok,) = address(target).call(svm.createCalldata("SimpleState", true));
        vm.assume(ok);

        // discard paths where the call left the target's storage snapshot unchanged
        uint snapshot = svm.snapshotStorage(address(target));
        vm.assume(snapshot != lastSnapshot);
        lastSnapshot = snapshot;
    }

    assertFalse(target.buggy());
}

/// Makes up to 10 symbolic calls to `target`, keeping only the paths in which
/// every call succeeds and changes the state snapshot returned by
/// `vm.snapshotState()`, then asserts that `buggy()` is false.
function check_buggy_with_state_snapshot() public {
    // state snapshot before any call is made
    uint lastSnapshot = vm.snapshotState();

    // note: a total of 253 feasible paths are generated, of which only 10 unique states exist
    for (uint step = 0; step < 10; step++) {
        // the `true` flag makes the generated calldata include view functions
        (bool ok,) = address(target).call(svm.createCalldata("SimpleState", true));
        vm.assume(ok);

        // discard paths where the call left the state snapshot unchanged
        uint snapshot = vm.snapshotState();
        vm.assume(snapshot != lastSnapshot);
        lastSnapshot = snapshot;
    }

    assertFalse(target.buggy());
}
}
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [
"toml>=0.10.2",
"z3-solver==4.12.6.0",
"eth_hash[pysha3]>=0.7.0",
"rich>=13.9.4"
"rich>=13.9.4",
"xxhash>=3.5.0"
]
dynamic = ["version"]

Expand Down
53 changes: 53 additions & 0 deletions src/halmos/cheatcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from subprocess import PIPE, Popen

from xxhash import xxh3_64, xxh3_64_digest, xxh3_128
from z3 import (
ULT,
And,
Expand Down Expand Up @@ -230,6 +231,49 @@ def symbolic_storage(ex, arg, sevm, stack, step_id):
return ByteVec() # empty return data


def snapshot_storage(ex, arg, sevm, stack, step_id):
    """
    Handle the snapshotStorage(address) cheatcode.

    The account is read from the calldata word at offset 4 (right after the
    selector) and resolved to its alias without branching. The returned
    32-byte value is 16 zero bytes followed by the digest of that account's
    storage.

    Raises:
        HalmosException: if the address does not resolve to a known account.
    """
    target = uint160(arg.get_word(4))
    resolved = sevm.resolve_address_alias(
        ex, target, stack, step_id, allow_branching=False
    )

    if resolved is not None:
        # left-pad the storage digest with 16 zero bytes
        padding = bytes(16)
        return ByteVec(padding + ex.storage[resolved].digest())

    raise HalmosException(
        f"snapshotStorage() is not allowed for a nonexistent account: {hexify(target)}"
    )


def snapshot_state(ex, arg, sevm, stack, step_id):
    """
    Build a snapshot ID by hashing the current state: balance, code, and storage.

    The ID is the concatenation of three hashes, in this order:
    balance (64 bits), code (64 bits), storage (128 bits).
    With this layout the lower 128 bits of both storage and state snapshot IDs
    are storage hashes.

    Only `ex` is consulted; the remaining parameters are not used here.
    """

    def _word(value):
        # every integer fed to the hashers is serialized as a 32-byte word
        return int.to_bytes(value, length=32)

    # balance: hash of the z3 id of the balance term
    balance_part = xxh3_64_digest(_word(ex.balance.get_id()))

    # code: (address, object identity) pairs
    # note: dict iteration follows insertion order
    code_hasher = xxh3_64()
    for addr, code in ex.code.items():
        # the object address stands in for the code, as code remains unchanged after deployment
        code_hasher.update(_word(int_of(addr)) + _word(id(code)))

    # storage: (address, per-account storage digest) pairs
    storage_hasher = xxh3_128()
    for addr, storage in ex.storage.items():
        storage_hasher.update(_word(int_of(addr)) + storage.digest())

    return ByteVec(balance_part + code_hasher.digest() + storage_hasher.digest())


def create_calldata_contract(ex, arg, sevm, stack, step_id):
contract_name = name_of(extract_string_argument(arg, 0))
return create_calldata_generic(ex, sevm, contract_name)
Expand Down Expand Up @@ -447,6 +491,8 @@ class halmos_cheat_code:
0x3B0FA01B: create_address, # createAddress(string)
0x6E0BB659: create_bool, # createBool(string)
0xDC00BA4D: symbolic_storage, # enableSymbolicStorage(address)
0x5DBB8438: snapshot_storage, # snapshotStorage(address)
0x9CD23835: snapshot_state, # snapshotState()
0xBE92D5A2: create_calldata_contract, # createCalldata(string)
0xDEEF391B: create_calldata_contract_bool, # createCalldata(string,bool)
0x88298B32: create_calldata_file_contract, # createCalldata(string,string)
Expand Down Expand Up @@ -550,6 +596,9 @@ class hevm_cheat_code:
# bytes4(keccak256("getBlockNumber()"))
get_block_number_sig: int = 0x42CBB15C

# snapshotState()
snapshot_state_sig: int = 0x9CD23835

@staticmethod
def handle(sevm, ex, arg: ByteVec, stack, step_id) -> ByteVec | None:
funsig: int = int_of(arg[:4].unwrap(), "symbolic hevm cheatcode")
Expand Down Expand Up @@ -852,6 +901,10 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> ByteVec | None:
ret.append(uint256(ex.block.number))
return ret

# vm.snapshotState() return (uint256)
elif funsig == hevm_cheat_code.snapshot_state_sig:
return snapshot_state(ex, arg, sevm, stack, step_id)

else:
# TODO: support other cheat codes
msg = f"Unsupported cheat code: calldata = {hexify(arg)}"
Expand Down
87 changes: 57 additions & 30 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TypeVar,
)

import xxhash
from eth_hash.auto import keccak
from rich.status import Status
from z3 import (
Expand Down Expand Up @@ -829,10 +830,39 @@ def extend_path(self, path):
self.extend(path.conditions.keys())


@dataclass
class StorageData:
symbolic: bool = False
mapping: dict = field(default_factory=dict)
def __init__(self):
self.symbolic = False
self._mapping = {}

def __getitem__(self, key) -> ArrayRef | BitVecRef:
return self._mapping[key]

def __setitem__(self, key, value) -> None:
self._mapping[key] = value

def __contains__(self, key) -> bool:
return key in self._mapping

def digest(self) -> bytes:
    """
    Return the xxh3_128 hash of the storage mapping.

    Each key-value pair is serialized and fed to the hasher in mapping order.
    A key is either a single integer (GenericStorage) or a sequence of
    integers (SolidityStorage). A value is a Z3 object and is represented by
    its unique identifier (get_id()).
    For simplicity every number is written as a 256-bit integer, regardless
    of its actual size.
    """
    hasher = xxhash.xxh3_128()
    for key, val in self._mapping.items():
        # normalize both key shapes to a sequence of ints:
        # GenericStorage -> (key,), SolidityStorage -> key as-is
        # (for SolidityStorage the first component, the slot, is 256 bits wide)
        components = (key,) if isinstance(key, int) else key
        for component in components:
            hasher.update(int.to_bytes(component, length=32))
        hasher.update(int.to_bytes(val.get_id(), length=32))
    return hasher.digest()


class Exec: # an execution path
Expand Down Expand Up @@ -1303,7 +1333,7 @@ class Storage:
class SolidityStorage(Storage):
@classmethod
def mk_storagedata(cls) -> StorageData:
return StorageData(mapping=defaultdict(lambda: defaultdict(dict)))
return StorageData()

@classmethod
def empty(cls, addr: BitVecRef, slot: int, keys: tuple) -> ArrayRef:
Expand All @@ -1327,26 +1357,25 @@ def init(
"""
assert_address(addr)

storage_data = ex.storage[addr]
mapping = storage_data.mapping[slot][num_keys]
storage = ex.storage[addr]

if size_keys in mapping:
if (slot, num_keys, size_keys) in storage:
return

if size_keys > 0:
# do not use z3 const array `K(BitVecSort(size_keys), ZERO)` when not ex.symbolic
# instead use normal smt array, and generate emptiness axiom; see load()
mapping[size_keys] = cls.empty(addr, slot, keys)
storage[slot, num_keys, size_keys] = cls.empty(addr, slot, keys)
return

# size_keys == 0
mapping[size_keys] = (
storage[slot, num_keys, size_keys] = (
BitVec(
# note: uuid is excluded to be deterministic
f"storage_{id_str(addr)}_{slot}_{num_keys}_{size_keys}_00",
BitVecSort256,
)
if storage_data.symbolic
if storage.symbolic
else ZERO
)

Expand All @@ -1356,44 +1385,43 @@ def load(cls, ex: Exec, addr: Any, loc: Word) -> Word:

cls.init(ex, addr, slot, keys, num_keys, size_keys)

storage_data = ex.storage[addr]
mapping = storage_data.mapping[slot][num_keys]
storage = ex.storage[addr]
storage_chunk = storage[slot, num_keys, size_keys]

if num_keys == 0:
return mapping[size_keys]
return storage_chunk

symbolic = storage_data.symbolic
symbolic = storage.symbolic
concat_keys = concat(keys)

if not symbolic:
# generate emptiness axiom for each array index, instead of using quantified formula; see init()
default_value = Select(cls.empty(addr, slot, keys), concat_keys)
ex.path.append(default_value == ZERO)

return ex.select(mapping[size_keys], concat_keys, ex.storages, symbolic)
return ex.select(storage_chunk, concat_keys, ex.storages, symbolic)

@classmethod
def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
(slot, keys, num_keys, size_keys) = cls.get_key_structure(ex, loc)

cls.init(ex, addr, slot, keys, num_keys, size_keys)

storage_data = ex.storage[addr]
mapping = storage_data.mapping[slot][num_keys]
storage = ex.storage[addr]

if num_keys == 0:
mapping[size_keys] = val
storage[slot, num_keys, size_keys] = val
return

new_storage_var = Array(
f"storage_{id_str(addr)}_{slot}_{num_keys}_{size_keys}_{uid()}_{1+len(ex.storages):>02}",
BitVecSorts[size_keys],
BitVecSort256,
)
new_storage = Store(mapping[size_keys], concat(keys), val)
new_storage = Store(storage[slot, num_keys, size_keys], concat(keys), val)
ex.path.append(new_storage_var == new_storage)

mapping[size_keys] = new_storage_var
storage[slot, num_keys, size_keys] = new_storage_var
ex.storages[new_storage_var] = new_storage

@classmethod
Expand Down Expand Up @@ -1497,10 +1525,10 @@ def init(cls, ex: Exec, addr: Any, loc: BitVecRef, size_keys: int) -> None:
"""
assert_address(addr)

mapping = ex.storage[addr].mapping
storage = ex.storage[addr]

if size_keys not in mapping:
mapping[size_keys] = cls.empty(addr, loc)
if size_keys not in storage:
storage[size_keys] = cls.empty(addr, loc)

@classmethod
def load(cls, ex: Exec, addr: Any, loc: Word) -> Word:
Expand All @@ -1509,16 +1537,15 @@ def load(cls, ex: Exec, addr: Any, loc: Word) -> Word:

cls.init(ex, addr, loc, size_keys)

storage_data = ex.storage[addr]
mapping = storage_data.mapping
symbolic = storage_data.symbolic
storage = ex.storage[addr]
symbolic = storage.symbolic

if not symbolic:
# generate emptiness axiom for each array index, instead of using quantified formula; see init()
default_value = Select(cls.empty(addr, loc), loc)
ex.path.append(default_value == ZERO)

return ex.select(mapping[size_keys], loc, ex.storages, symbolic)
return ex.select(storage[size_keys], loc, ex.storages, symbolic)

@classmethod
def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
Expand All @@ -1527,17 +1554,17 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:

cls.init(ex, addr, loc, size_keys)

mapping = ex.storage[addr].mapping
storage = ex.storage[addr]

new_storage_var = Array(
f"storage_{id_str(addr)}_{size_keys}_{uid()}_{1+len(ex.storages):>02}",
BitVecSorts[size_keys],
BitVecSort256,
)
new_storage = Store(mapping[size_keys], loc, val)
new_storage = Store(storage[size_keys], loc, val)
ex.path.append(new_storage_var == new_storage)

mapping[size_keys] = new_storage_var
storage[size_keys] = new_storage_var
ex.storages[new_storage_var] = new_storage

@classmethod
Expand Down
Loading

0 comments on commit 0ef9341

Please sign in to comment.