Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Word-addressable memory - Opcodes #65

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 158 additions & 9 deletions src/zkevm_specs/evm_circuit/execution/memory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from zkevm_specs.util.arithmetic import RLC
from ..instruction import Instruction, Transition
from ..opcode import Opcode
from ..table import RW
from ...util import FQ


def memory(instruction: Instruction):

opcode = instruction.opcode_lookup(True)

address = instruction.stack_pop()
offset = address.int_value % 32
offset_bits = to_5_bits(offset)
slot = address.int_value - offset

addr_left = FQ(slot)
addr_right = FQ(slot + 32)

is_mload = instruction.is_equal(opcode, Opcode.MLOAD)
is_mstore8 = instruction.is_equal(opcode, Opcode.MSTORE8)
Expand All @@ -21,19 +29,63 @@ def memory(instruction: Instruction):
memory_offset, address.expr() + FQ(1) + (is_not_mstore8 * FQ(31))
)

# Generate the binary mask that selects the bytes to be read/written.
mask = make_mask(offset, is_mstore8)
constrain_mask(instruction, mask, offset_bits, is_mstore8)
not_mask = not_vec(mask)

# Compute powers of the RLC challenge. These are used to shift bytes in equations below.
X = instruction.randomness
X32 = make_X32(X)
X31_off = make_X31_off(X, offset_bits)
X32_off = X * X31_off

# Read the left slot in all cases.
left, left_prev = instruction.memory_lookup_update(
RW.Write if is_store == FQ(1) else RW.Read, addr_left
)

# Check the consistency of unchanged bytes: L’ & M == L & M
instruction.constrain_equal(
instruction.rlc_encode(mul_vec(left_prev.le_bytes, mask)).rlc_value,
instruction.rlc_encode(mul_vec(left.le_bytes, mask)).rlc_value,
)

# RLC of the B part: the bytes read/written from the left slot.
b = rev_vec(mul_vec(left.le_bytes, not_mask))
b_r = instruction.rlc_encode(b).rlc_value

if is_mstore8 == FQ(1):
instruction.is_equal(
instruction.memory_lookup(RW.Write, address.expr()), FQ(value.le_bytes[0])
# Check the consistency of the one byte to write versus the left slot.
instruction.constrain_equal(
value.le_bytes[0] * X31_off,
b_r,
)

if is_not_mstore8 == FQ(1):
for idx in range(32):
instruction.is_equal(
instruction.memory_lookup(
RW.Write if is_store == FQ(1) else RW.Read, address.expr() + idx
),
FQ(value.le_bytes[31 - idx]),
)

# Read the right slot in the MLOAD/MSTORE case.
right, right_prev = instruction.memory_lookup_update(
RW.Write if is_store == FQ(1) else RW.Read, addr_right
)

# Check the consistency of unchanged bytes: R’ & !M == R & !M
instruction.constrain_equal(
instruction.rlc_encode(mul_vec(right_prev.le_bytes, not_mask)).rlc_value,
instruction.rlc_encode(mul_vec(right.le_bytes, not_mask)).rlc_value,
)

# RLC of the C part: the bytes read/written from the right slot.
c = rev_vec(mul_vec(right.le_bytes, mask))
c_r = instruction.rlc_encode(c).rlc_value

w_r = value.rlc_value # Same value as given from the stack operation.

# Check the consistency of the value with parts from the left and right slots.
instruction.constrain_equal(
w_r * X32_off,
b_r * X32 + c_r,
)

instruction.step_state_transition_in_same_context(
opcode,
Expand All @@ -43,3 +95,100 @@ def memory(instruction: Instruction):
memory_word_size=Transition.to(next_memory_size),
dynamic_gas_cost=memory_expansion_gas_cost,
)


def constrain_mask(instruction, mask, offset_bits, is_mstore8):
# Interpret the mask as a binary number.
mask_value = 0
for (i, m) in enumerate(mask):
m = FQ(m)
# Make sure the mask elements are either 0 or 1.
instruction.constrain_zero(m * (1 - m))
mask_value += 2**i * m

# Compute 2**offset. As a binary number, it looks like this (example offset=4):
# 00001000000000000000000000000000
two_pow_offset = FQ(make_two_pow(offset_bits))

if is_mstore8 == FQ(1):
# If MSTORE8, the mask looks like this (example offset=4):
# 11110111111111111111111111111111
instruction.constrain_equal(
mask_value, 2**32 - 1 - two_pow_offset)

else:
# If MLOAD or  MSTORE, the mask looks like this (example offset=4):
# 11110000000000000000000000000000
instruction.constrain_equal(
mask_value, two_pow_offset - 1)


def make_mask(offset, is_mstore8):
M = [1] * 32

if is_mstore8 == FQ(1):
# If MSTORE8, the mask looks like this (example offset=4):
# 11110111111111111111111111111111
M[offset] = 0
else:
# If MLOAD or  MSTORE, the mask looks like this (example offset=4):
# 11110000000000000000000000000000
for i in range(offset, 32):
M[i] = 0

return bytes(M)


# Witness and constrain the bits of the exponent.
def to_5_bits(offset):
assert offset < 2**5
# Witness, LSB-first.
bits = [(offset >> i) & 1 for i in range(5)]
# Constrain.
assert sum(bit * 2**i for (i, bit) in enumerate(bits)) == offset
for bit in bits:
assert bit * (1 - bit) == 0
return bits


# Compute 2**offset by squaring-and-multiplying.
def make_two_pow(offset_bits):
assert len(offset_bits) == 5
two_pow_offset = 1
for bit in reversed(offset_bits):
two_pow_offset = two_pow_offset * two_pow_offset * (1 + bit)
return two_pow_offset


# Compute `X**(31-offset)` by squaring-and-multiplying.
def make_X31_off(X, offset_bits):
# Express the bits of `31-offset` by flipping the bits of `offset`.
assert len(offset_bits) == 5
not_bits = [1 - b for b in offset_bits]

X_pow = 1
for bit in reversed(not_bits):
X_pow = X_pow * X_pow
X_pow = X_pow * (X if bit else 1)
return X_pow


# Compute X**32 by squaring. This does *not* depend on a witness, only a the challenge X.
def make_X32(X):
X32 = X
for _ in range(5):
X32 = X32 * X32
return X32


def mul_vec(a, b):
assert len(a) == len(b)
return bytes(a[i] * b[i] for i in range(len(a)))


def not_vec(a):
return bytes(1 - v for v in a)


def rev_vec(a):
return bytes(reversed(a))
11 changes: 9 additions & 2 deletions src/zkevm_specs/evm_circuit/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,10 +838,17 @@ def stack_lookup(self, rw: RW, stack_pointer_offset: Expression) -> RLC:

def memory_lookup(
self, rw: RW, memory_address: Expression, call_id: Optional[Expression] = None
) -> FQ:
) -> RLC:
curr, _prev = self.memory_lookup_update(rw, memory_address, call_id)
return curr

def memory_lookup_update(
self, rw: RW, memory_address: Expression, call_id: Optional[Expression] = None
) -> Tuple[RLC, RLC]:
if call_id is None:
call_id = self.curr.call_id
return cast_expr(self.rw_lookup(rw, RWTableTag.Memory, call_id, memory_address).value, FQ)
res = self.rw_lookup(rw, RWTableTag.Memory, call_id, memory_address)
return cast_expr(res.value, RLC), cast_expr(res.value_prev, RLC)

def tx_refund_read(self, tx_id: Expression) -> FQ:
return cast_expr(self.rw_lookup(RW.Read, RWTableTag.TxRefund, tx_id).value, FQ)
Expand Down
21 changes: 17 additions & 4 deletions src/zkevm_specs/evm_circuit/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,18 +428,31 @@ def stack_write(self, call_id: IntOrFQ, stack_pointer: IntOrFQ, value: RLC) -> R
RW.Write, RWTableTag.Stack, key1=FQ(call_id), key2=FQ(stack_pointer), value=value
)

def memory_read(self, call_id: IntOrFQ, memory_address: IntOrFQ, byte: IntOrFQ) -> RWDictionary:
def memory_read(self, call_id: IntOrFQ, memory_address: IntOrFQ, value: RLC) -> RWDictionary:
return self._append(
RW.Read, RWTableTag.Memory, key1=FQ(call_id), key2=FQ(memory_address), value=FQ(byte)
RW.Read, RWTableTag.Memory, key1=FQ(call_id), key2=FQ(memory_address), value=value, value_prev=value
)

def memory_write(
self, call_id: IntOrFQ, memory_address: IntOrFQ, byte: IntOrFQ
self, call_id: IntOrFQ, memory_address: IntOrFQ, value: RLC
) -> RWDictionary:
prev = self._memory_find_prev(call_id, memory_address)
return self._append(
RW.Write, RWTableTag.Memory, key1=FQ(call_id), key2=FQ(memory_address), value=FQ(byte)
RW.Write, RWTableTag.Memory, key1=FQ(call_id), key2=FQ(memory_address), value=value, value_prev=prev
)

def _memory_find_prev(
self, call_id: IntOrFQ, memory_address: IntOrFQ
) -> Optional[RLC]:
for rw in reversed(self.rws):
if (
rw.key0 == RWTableTag.Memory
and rw.key1 == FQ(call_id)
and rw.key2 == FQ(memory_address)
):
return rw.value
return RLC(0)

def call_context_read(
self, call_id: IntOrFQ, field_tag: CallContextFieldTag, value: Union[int, FQ, RLC]
) -> RWDictionary:
Expand Down
55 changes: 33 additions & 22 deletions tests/evm/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,58 +41,68 @@
(
Opcode.MSTORE8,
0,
bytes.fromhex("FFFF"),
bytes.fromhex("FF"),
bytes.fromhex("1122"),
bytes.fromhex("11"),
),
(
Opcode.MSTORE8,
1,
bytes.fromhex("FF"),
bytes.fromhex("FFFF"),
bytes.fromhex("1122"),
bytes.fromhex("0011"),
),
)


@pytest.mark.parametrize("opcode, offset, value, memory", TESTING_DATA)
def test_memory(opcode: Opcode, offset: int, value: bytes, memory: bytes):
@pytest.mark.parametrize("opcode, address, value, memory", TESTING_DATA)
def test_memory(opcode: Opcode, address: int, value: bytes, memory: bytes):

# pad memory with 0s to the right up to 64 bytes
memory = memory + bytes(64 - len(memory))

randomness = rand_fq()

offset_rlc = RLC(offset, randomness)
address_rlc = RLC(address, randomness)
value_rlc = RLC(value, randomness)
call_id = 1
curr_memory_word_size = 0
length = offset

is_mload = opcode == Opcode.MLOAD
is_mstore8 = opcode == Opcode.MSTORE8
is_store = 1 - is_mload
is_not_mstore8 = 1 - is_mstore8

bytecode = (
Bytecode().mload(offset_rlc).stop()
Bytecode().mload(address_rlc).stop()
if is_mload
else Bytecode().mstore8(offset_rlc, value_rlc).stop()
else Bytecode().mstore8(address_rlc, value_rlc).stop()
if is_mstore8
else Bytecode().mstore(offset_rlc, value_rlc).stop()
else Bytecode().mstore(address_rlc, value_rlc).stop()
)
rw_dictionary = (
RWDictionary(1).stack_read(call_id, 1022, offset_rlc).stack_write(call_id, 1022, value_rlc)
RWDictionary(1).stack_read(call_id, 1022, address_rlc).stack_write(call_id, 1022, value_rlc)
if is_mload
else RWDictionary(1)
.stack_read(call_id, 1022, offset_rlc)
.stack_read(call_id, 1022, address_rlc)
.stack_read(call_id, 1023, value_rlc)
)

bytecode_hash = RLC(bytecode.hash(), randomness)
shift = address % 32
addr_left = address - shift
addr_right = addr_left + 32

value_left = RLC(memory[:32], randomness)
value_right = RLC(memory[32:], randomness)

if is_mstore8:
rw_dictionary.memory_write(call_id, length, value[0])
rw_dictionary.memory_write(call_id, addr_left, value_left)

if is_not_mstore8:
for idx in range(32):
if is_mload:
rw_dictionary.memory_read(call_id, offset + idx, memory[idx])
else:
rw_dictionary.memory_write(call_id, offset + idx, memory[idx])
if is_mload:
rw_dictionary.memory_read(call_id, addr_left, value_left)
rw_dictionary.memory_read(call_id, addr_right, value_right)
else:
rw_dictionary.memory_write(call_id, addr_left, value_left)
rw_dictionary.memory_write(call_id, addr_right, value_right)

tables = Tables(
block_table=set(Block().table_assignments(randomness)),
Expand All @@ -101,10 +111,11 @@ def test_memory(opcode: Opcode, offset: int, value: bytes, memory: bytes):
rw_table=rw_dictionary.rws,
)

address = offset + 1 + (is_not_mstore8 * 31)
next_mem_size, memory_gas_cost = memory_expansion(curr_memory_word_size, address)
mem_byte_size = address + 1 + (is_not_mstore8 * 31)
next_mem_size, memory_gas_cost = memory_expansion(curr_memory_word_size, mem_byte_size)
gas = Opcode.MLOAD.constant_gas_cost() + memory_gas_cost

bytecode_hash = RLC(bytecode.hash(), randomness)
rw_counter = 35 - (is_mstore8 * 31)
program_counter = 66 - (is_mload * 33)
stack_pointer = 1022 + (is_store * 2)
Expand Down