diff --git a/.env b/.env index 1fc1ae611..93681e353 100644 --- a/.env +++ b/.env @@ -1,8 +1,10 @@ AMQP_URI=amqp://localhost:5672 -ARITHMETIC_CIRCUIT_SIZE=16..23 -BYTE_PACKING_CIRCUIT_SIZE=9..21 -CPU_CIRCUIT_SIZE=12..25 -KECCAK_CIRCUIT_SIZE=14..20 -KECCAK_SPONGE_CIRCUIT_SIZE=9..15 -LOGIC_CIRCUIT_SIZE=12..18 -MEMORY_CIRCUIT_SIZE=17..28 +ARITHMETIC_CIRCUIT_SIZE=16..21 +BYTE_PACKING_CIRCUIT_SIZE=8..21 +CPU_CIRCUIT_SIZE=8..21 +KECCAK_CIRCUIT_SIZE=4..20 +KECCAK_SPONGE_CIRCUIT_SIZE=8..17 +LOGIC_CIRCUIT_SIZE=4..21 +MEMORY_CIRCUIT_SIZE=17..24 +MEMORY_BEFORE_CIRCUIT_SIZE=16..23 +MEMORY_AFTER_CIRCUIT_SIZE=7..23 diff --git a/Cargo.lock b/Cargo.lock index a75cf45ad..18e708b0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2016,7 +2016,7 @@ dependencies = [ "pest", "pest_derive", "plonky2", - "plonky2_maybe_rayon", + "plonky2_maybe_rayon 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "plonky2_util", "rand", "rand_chacha", @@ -2024,10 +2024,12 @@ dependencies = [ "rlp", "rlp-derive", "serde", + "serde-big-array", "serde_json", "sha2", "starky", "static_assertions", + "thiserror", "tiny-keccak", "zk_evm_proc_macro", ] @@ -3303,6 +3305,7 @@ dependencies = [ "paladin-core", "proof_gen", "serde", + "trace_decoder", "tracing", "zero_bin_common", ] @@ -3619,8 +3622,7 @@ checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "plonky2" version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85f26b090b989aebdeaf6a4eed748c1fbcabf67e7273a22e4e0c877b63846d0f" +source = "git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed#dc77c77f2b06500e16ad4d7f1c2b057903602eed" dependencies = [ "ahash", "anyhow", @@ -3631,7 +3633,7 @@ dependencies = [ "log", "num", "plonky2_field", - "plonky2_maybe_rayon", + "plonky2_maybe_rayon 0.2.0 (git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed)", "plonky2_util", "rand", "rand_chacha", @@ -3644,8 +3646,7 @@ dependencies = [ [[package]] name = "plonky2_field" version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1dca60ad900d81b1fe2df3d0b88d43345988e2935e6709176e96573f4bcf5d" +source = "git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed#dc77c77f2b06500e16ad4d7f1c2b057903602eed" dependencies = [ "anyhow", "itertools 0.11.0", @@ -3666,11 +3667,18 @@ dependencies = [ "rayon", ] +[[package]] +name = "plonky2_maybe_rayon" +version = "0.2.0" +source = "git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed#dc77c77f2b06500e16ad4d7f1c2b057903602eed" +dependencies = [ + "rayon", +] + [[package]] name = "plonky2_util" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16136f5f3019c1e83035af76cccddd56d789a5e2933306270185c3f99f12259" +source = "git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed#dc77c77f2b06500e16ad4d7f1c2b057903602eed" [[package]] name = "plotters" @@ -3845,6 +3853,7 @@ name = "proof_gen" version = "0.4.0" dependencies = [ "evm_arithmetization", + "hashbrown", "log", "paste", "plonky2", @@ -3877,10 +3886,14 @@ version = "0.1.0" dependencies = [ "alloy", "anyhow", + "clap", + "evm_arithmetization", "futures", "num-traits", "ops", "paladin-core", + "plonky2", + "plonky2_maybe_rayon 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "proof_gen", "ruint", "serde", @@ -4525,6 +4538,15 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-big-array" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" version = "1.0.204" @@ -4724,8 +4746,7 @@ checksum = "8acdd7dbfcfb5dd6e46c63512508bf71c2043f70b8f143813ad75cb5e8a589f2" [[package]] name = "starky" version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a8480ca5b8eedf83ad070a780783b4e21a56c6ef66b4c0d1b7520b72bdfda1b" +source = "git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed#dc77c77f2b06500e16ad4d7f1c2b057903602eed" dependencies = [ "ahash", "anyhow", @@ -4734,7 +4755,7 @@ dependencies = [ "log", "num-bigint", "plonky2", - "plonky2_maybe_rayon", + "plonky2_maybe_rayon 0.2.0 (git+https://github.com/0xPolygonZero/plonky2.git?rev=dc77c77f2b06500e16ad4d7f1c2b057903602eed)", "plonky2_util", ] @@ -5168,7 +5189,7 @@ dependencies = [ "mpt_trie", "nunny", "plonky2", - "plonky2_maybe_rayon", + "plonky2_maybe_rayon 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "pretty_env_logger", "prover", "rlp", @@ -5842,6 +5863,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "trace_decoder", "tracing", "vergen", ] diff --git a/Cargo.toml b/Cargo.toml index 8e1279fdb..111ec8569 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,6 +91,7 @@ rlp = "0.5.2" rlp-derive = "0.1.0" ruint = "1.12.3" serde = "1.0.203" +serde-big-array = "0.5.1" serde_json = "1.0.118" serde_path_to_error = "0.1.16" serde_with = "3.8.1" @@ -125,10 +126,10 @@ rpc = { path = "zero_bin/rpc" } zero_bin_common = { path = "zero_bin/common" } # plonky2-related dependencies -plonky2 = "0.2.2" +plonky2 = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" } plonky2_maybe_rayon = "0.2.0" -plonky2_util = "0.2.0" -starky = "0.4.0" +plonky2_util = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" } +starky = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" } # proc macro related dependencies proc-macro2 = "1.0" diff --git a/docs/arithmetization/cpulogic.tex b/docs/arithmetization/cpulogic.tex index 4f68faa7d..b10f01169 100644 --- a/docs/arithmetization/cpulogic.tex +++ b/docs/arithmetization/cpulogic.tex @@ -2,7 +2,7 @@ \section{CPU logic} \label{cpulogic} The CPU is in charge of coordinating the different STARKs, proving the correct execution of the instructions it reads and guaranteeing -that the final state of the EVM corresponds to the starting state after executing the input transaction. All design choices were made +that the final state of the EVM corresponds to the starting state after executing the input transactions. All design choices were made to make sure these properties can be adequately translated into constraints of degree at most 3 while minimizing the size of the different table traces (number of columns and number of rows). @@ -11,29 +11,49 @@ \section{CPU logic} \subsection{Kernel} The kernel is in charge of the proving logic. This section aims at providing a high level overview of this logic. For details about any specific part of the logic, one can consult the various ``asm'' files in the \href{https://github.com/0xPolygonZero/plonky2/tree/main/evm/src/cpu/kernel}{``kernel'' folder}. -We prove one transaction at a time. These proofs can later be aggregated recursively to prove a block. Proof aggregation is however not in the scope of this section. Here, we assume that we have an initial state of the EVM, and we wish to prove that a single transaction was correctly executed, leading to a correct update of the state. +We prove a batch of transactions, split into segments. These proofs can later be aggregated recursively to prove a block. Proof aggregation is however not in the scope of this section. Here, we assume that we have an initial state of the EVM, and we wish to prove that a batch of contiguous transactions was correctly executed, leading to a correct update of the state. -Since we process one transaction at a time, a few intermediary values need to be provided by the prover. Indeed, to prove that the registers in the EVM state are correctly updated, we need to have access to their initial values. When aggregating proofs, we can also constrain those values to match from one transaction to the next. Let us consider the example of the transaction number. Let $n$ be the number of transactions executed so far in the current block. If the current proof is not a dummy one (we are indeed executing a transaction), then the transaction number should be updated: $n := n+1$. Otherwise, the number remains unchanged. We can easily constrain this update. When aggregating the previous transaction proof ($lhs$) with the current one ($rhs$), we also need to check that the output transaction number of $lhs$ is the same as the input transaction number of $rhs$. +Since we process transactions and not entire blocks, a few intermediary values need to be provided by the prover. Indeed, to prove that the registers in the EVM state are correctly updated, we need to have access to their initial values. When aggregating proofs, we can also constrain those values to match from one batch to the next. Let us consider the example of the transaction number. Let $n$ be the number of transactions executed so far in the current block. If the current proof is not a dummy one (we are indeed executing a batch of transactions), then the transaction number should be updated: $n := n+k$ with $k$ the number of transactions in the batch. Otherwise, the number remains unchanged. We can easily constrain this update. When aggregating the previous transaction batch proof ($lhs$) with the current one ($rhs$), we also need to check that the output transaction number of $lhs$ is the same as the input transaction number of $rhs$. Those prover provided values are stored in memory prior to entering the kernel, and are used in the kernel to assert correct updates. The list of prover provided values necessary to the kernel is the following: \begin{enumerate} - \item the previous transaction number: $t_n$, - \item the gas used before executing the current transaction: $g\_u_0$, - \item the gas used after executing the current transaction: $g\_u_1$, - \item the state, transaction and receipts MPTs before executing the current transaction: $\texttt{tries}_0$, - \item the hash of all MPTs before executing the current transaction: $\texttt{digests}_0$, - \item the hash of all MPTs after executing the current transaction: $\texttt{digests}_1$, - \item the RLP encoding of the transaction. + \item the number of the last transaction executed: $t_n$, + \item the gas used before executing the current transactions: $g\_u_0$, + \item the gas used after executing the current transactions: $g\_u_1$, + \item the state, transaction and receipts MPTs before executing the current transactions: $\texttt{tries}_0$, + \item the hash of all MPTs before executing the current transactions: $\texttt{digests}_0$, + \item the hash of all MPTs after executing the current transactions: $\texttt{digests}_1$, + \item the RLP encoding of the transactions. \end{enumerate} -\paragraph*{Initialization:} The first step consists in initializing: +\paragraph*{Segment handling:} +An execution run is split into one or more segments. To ensure continuity, the first cycles of a segment are used to "load" segment data from the previous segment, and the last cycles to +"save" segment data for the next segment. The number of CPU cycles of a segment is bounded by \texttt{MAX\_CPU\_CYCLES}, which can be tweaked for best performance. The segment data values are: \begin{itemize} - \item The shift table: it maps the number of bit shifts $s$ with its shifted value $1 << s$. Note that $0 \leq s \leq 255$. - \item The initial MPTs: the initial state, transaction and receipt tries $\texttt{tries}_0$ are loaded from memory and hashed. The hashes are then compared to $\texttt{digests}\_0$. + \item the stack length, + \item the stack top, + \item the context, + \item the \texttt{is\_kernel} flag, + \item the gas used, + \item the program counter. +\end{itemize} +These values are stored as global metadata, and are loaded from (resp. written to) memory at the beginning (resp. at the end) of a segment. They are propagated +between proofs as public values. + +The initial memory of the first segment is fixed and contains: +\begin{itemize} + \item the kernel code, + \item the shift table. +\end{itemize} + +\paragraph*{Initialization:} The first step of a run consists in initializing: +\begin{itemize} + \item The initial transaction and receipt tries $\texttt{tries}_0$ are loaded from memory. The transaction and the receipt tries are hashed and the hashes are then compared to $\texttt{digests}\_0$. +For efficiency, the initial state trie will be hashed for verification at the end of the run. \item We load the transaction number $t\_n$ and the current gas used $g\_u_0$ from memory. \end{itemize} -If no transaction is provided, we can halt after this initialization. Otherwise, we start processing the transaction. The transaction is provided as its RLP encoding. We can deduce the various transaction fields (such as its type or the transfer value) from its encoding. Based on this, the kernel updates the state trie by executing the transaction. Processing the transaction also includes updating the transactions MPT with the transaction at hand. +We start processing the transactions (if any) sequentially, provided in RLP encoded format. The processing of the transaction returns a boolean ``success'' that indicates whether the transaction was executed successfully, along with the leftover gas. @@ -44,8 +64,9 @@ \subsection{Kernel} Finally, once the three MPTs have been updated, we need to carry out final checks: \begin{itemize} \item the gas used after the execution is equal to $g\_u_1$, - \item the new transaction number is $n+1$ if there was a transaction, - \item the three MPTs are hashed and checked against $\texttt{digests}_1$. + \item the new transaction number is $n + k$ with $k$ the number of processed transactions, + \item the initial state MPT is hashed and checked against $\texttt{digests}_0$. + \item the initial state MPT is updated to reflect the processed transactions, then the three final MPTs are hashed and checked against $\texttt{digests}_1$. \end{itemize} Once those final checks are performed, the program halts. @@ -283,3 +304,57 @@ \subsection{Exceptions} push once at most), and that the faulty instruction is pushing. If the exception is not raised, stack constraints ensure that a stack length of 1025 in user mode will fail the proof. \end{enumerate} + +\subsection{Linked lists} + +Individual account information are contained in the state and the storage MPTs. However, accessing and modifying MPT data requires heavy trie +traversal, insertion and deletion functions. To alleviate these costs, during an execution run, we store all account information in linked list structures +and only modify the state trie at the end of the run. + +Our linked list construction guarantees these properties: +\begin{itemize} + \item A linked list is cyclic. The last element's successor is the first element. + \item A linked list is always sorted by a certain index, which can be one or more fields of an element. + \item The last element of a linked list is MAX, whose index is always higher than any possible index value. + \item An index cannot appear twice in the linked list. +\end{itemize} + +These properties allows us to efficiently modify the list. + +\paragraph*{Search} +To search a node given its index, we provide via \texttt{PROVER\_INPUT} a pointer to its predecessor $p$. We first check that $p$'s index is strictly lower than +the node index, if not, the provided pointer is invalid. Then, we check $s$, $p$'s successor. If $s$'s index is equal to the node index, we found the node. +If $s$'s index is lower than the node index, then the provided $p$ was invalid. If $s$'s index is greater than the node index, then the node doesn't exist. + +\paragraph*{Insertion} +To insert a node given its index, we provide via \texttt{PROVER\_INPUT} a pointer to its predecessor $p$. We first check that $p$'s index is strictly lower than +the node index, if not, the provided pointer is invalid. Then, we check $s$, $p$'s successor, and make sure that $s$ is strictly greater than the node index. +We create a new node, and make it $p$'s successor; then we make $s$ the new node's successor. + +\paragraph*{Deletion} +To delete a node given its index, we provide via \texttt{PROVER\_INPUT} a pointer to its predecessor $p$. We check that $p$'s successor is equal to the node index; if not +either $p$ is invalid or the node doesn't exist. Then we set $p$'s successor to the node's successor. To indicate that the node is now deleted and to make sure that it's +never accessed again, we set its next pointer to MAX. + +We maintain two linked lists: one for the state accounts and one for the storage slots. + +\subsubsection*{Account linked list} + +An account node is made of four memory cells: +\begin{itemize} + \item The account key (the hash of the account address). This is the index of the node. + \item A pointer to the account payload, in segment \texttt{@TrieData}. + \item A pointer to the initial account payload, in segment \texttt{@TrieData}. This is the value of the account at the beginning of the execution, before processing any transaction. This payload never changes. + \item A pointer to the next node (which points to the next node's account key). +\end{itemize} + +\subsubsection*{Storage linked list} + +A storage node is made of five memory cells: +\begin{itemize} + \item The account key (the hash of the account address). + \item The slot key (the hash of the slot). Nodes are indexed by \texttt{(account\_key, slot\_key)}. + \item The slot value. + \item The initial slot value. This is the value of the account at the beginning of the execution, before processing any transaction. It never changes. + \item A pointer to the next node (which points to the next node's account key). +\end{itemize} \ No newline at end of file diff --git a/docs/arithmetization/tables.tex b/docs/arithmetization/tables.tex index 43b45eb58..3094c462a 100644 --- a/docs/arithmetization/tables.tex +++ b/docs/arithmetization/tables.tex @@ -6,5 +6,6 @@ \section{Tables} \input{tables/byte-packing} \input{tables/logic} \input{tables/memory} +\input{tables/mem-continuations} \input{tables/keccak-f} \input{tables/keccak-sponge} diff --git a/docs/arithmetization/tables/cpu.tex b/docs/arithmetization/tables/cpu.tex index 590b95668..ecbf68267 100644 --- a/docs/arithmetization/tables/cpu.tex +++ b/docs/arithmetization/tables/cpu.tex @@ -21,6 +21,16 @@ \subsubsection{CPU flow} executing user code (transaction or contract code). In a non-zero user context, syscalls may be executed, which are specific instructions written in the kernel. They don't change the context but change the code context, which is where the instructions are read from. +\paragraph*{Continuations} + +A full run of the zkEVM consists in initializing the zkEVM with the input state, executing a certain number of transactions, and then validating the output state. +However, for performance reasons, a run is split in multiple segments of at most \texttt{MAX\_CPU\_CYCLES} cycles, which can be proven individually. Continuations ensure that the segments are part of the +same run and guarantees that the state at the end of a segment is equal to the state at the beginning of the next. + +The state to propagate from one segment to another contains some of the zkEVM registers plus the current memory. These registers +are stored in memory as dedicated global metadata, and the memory to propagate is stored in two STARK tables: \texttt{MemBefore} and \texttt{MemAfter}. To check the +consistency of the memory, the Merkle cap of the previous \texttt{MemAfter} is compared to the Merkle cap of the next \texttt{MemBefore}. + \subsubsection{CPU columns} \paragraph*{Registers:} \begin{itemize} @@ -72,4 +82,6 @@ \subsubsection{CPU columns} See \ref{stackhandling} for more details. \label{push_general_view} \item \texttt{Push}: \texttt{is\_not\_kernel} is used to skip range-checking the output of a PUSH operation when we are in privileged mode, as the kernel code is known and trusted. + \item \texttt{Context pruning}: When \texttt{SET\_CONTEXT} is called to return to a parent context, this makes the current context stale. The kernel indicates it +by setting one general column to 1. For more details about context pruning, see \ref{context-pruning}. \end{itemize} diff --git a/docs/arithmetization/tables/mem-continuations.tex b/docs/arithmetization/tables/mem-continuations.tex new file mode 100644 index 000000000..9aee737a6 --- /dev/null +++ b/docs/arithmetization/tables/mem-continuations.tex @@ -0,0 +1,15 @@ +\subsection{Memory continuations} +\label{mem-continuations} + +The MemBefore (resp. MemAfter) table holds the content of the memory before (resp. after) the execution of the current segment. +For consistency, the MemAfter trace of a segment must be identical to the MemAfter trace of the next segment. +Each row of these tables contains: + +\begin{enumerate} + \item $a$, the memory cell address, + \item $v$, the initial value of the cell. +\end{enumerate} +The tables should be ordered by $(a, \tau)$. Since they only hold values, there are no constraints between the rows. + +A CTL copies all of the MemBefore values in the memory trace as reads, at timestamp $\tau = 0$. +Another CTL copies the final values from memory to MemAfter. For more details on which values are propagated, consult \ref{final-memory}. \ No newline at end of file diff --git a/docs/arithmetization/tables/memory.tex b/docs/arithmetization/tables/memory.tex index d6f3267c3..6eef3bab1 100644 --- a/docs/arithmetization/tables/memory.tex +++ b/docs/arithmetization/tables/memory.tex @@ -68,7 +68,7 @@ \subsubsection{Timestamps} Note that it doesn't mean that all memory operations have unique timestamps. There are two exceptions: \begin{itemize} - \item Before the CPU cycles, we write some global metadata in memory. These extra operations are done at timestamp $\tau = 0$. + \item Before the CPU cycles, we preinitialize the memory with the flashed state stored in the MemBefore table and we write some global metadata. These operations are done at timestamp $\tau = 0$. \item Some tables other than CPU can generate memory operations, like KeccakSponge. When this happens, these operations all have the timestamp of the CPU row of the instruction which invoked the table (for KeccakSponge, KECCAK\_GENERAL). \end{itemize} @@ -77,11 +77,22 @@ \subsubsection{Memory initialization} By default, all memory is zero-initialized. However, to save numerous writes, we allow some specific segments to be initialized with arbitrary values. \begin{itemize} - \item The read-only kernel code (in segment 0, context 0) is initialized with its correct values. It's checked by hashing the segment and verifying -that the hash value matches a verifier-provided one. - \item The code segment (segment 0) in other contexts is initialized with externally-provided account code, then checked against the account code hash. -If the code is meant to be executed, there is a soundness concern: if the code is malformed and ends with an incomplete PUSH, then the missing bytes must + \item The code segment (segment 0) is either part of the initial memory for the kernel (context 0), or is initialized with externally-provided account code, then checked against the account code hash. +In non-zero contexts, if the code is meant to be executed, there is a soundness concern: if the code is malformed and ends with an incomplete PUSH, then the missing bytes must be 0 accordingly to the Ethereum specs. To prevent the issue, we manually write 33 zeros (at most 32 bytes for the PUSH argument, and an extra one for the post-PUSH PC value). \item The ``TrieData'' segment is initialized with the input tries. The stored tries are hashed and checked against the provided initial hash. Note that the length of the segment and the pointers -- within the ``TrieData'' segment -- for the three tries are provided as prover inputs. The length is then checked against a value computed when hashing the tries. \end{itemize} + +\subsubsection{Final memory} +\label{final-memory} + +The final value of each cell of the memory must be propagated to the MemAfter table. Since memory operations are ordered by address and by timestamps, this is +easy to do: the last value of an address is the value of the last row touching this address. In other words, we propagate values of rows before the address changes. + +\paragraph*{Context pruning} +\label{context-pruning} + +We can observe that whenever we return from a context (e.g. with a RETURN opcode, from an exception...), we will never access it again and all its memory is now stale. +We make use of this fact to prune stale contexts and exclude them from MemAfter. + diff --git a/docs/arithmetization/zkevm.pdf b/docs/arithmetization/zkevm.pdf index d1986f70d..ee815d59a 100644 Binary files a/docs/arithmetization/zkevm.pdf and b/docs/arithmetization/zkevm.pdf differ diff --git a/evm_arithmetization/Cargo.toml b/evm_arithmetization/Cargo.toml index 7398aee08..d69f2a140 100644 --- a/evm_arithmetization/Cargo.toml +++ b/evm_arithmetization/Cargo.toml @@ -41,8 +41,10 @@ serde = { workspace = true, features = ["derive"] } sha2 = { workspace = true } static_assertions = { workspace = true } hashbrown = { workspace = true } +thiserror = { workspace = true } tiny-keccak = { workspace = true } serde_json = { workspace = true } +serde-big-array = { workspace = true } # Local dependencies mpt_trie = { workspace = true } diff --git a/evm_arithmetization/benches/fibonacci_25m_gas.rs b/evm_arithmetization/benches/fibonacci_25m_gas.rs index 9959acbc1..ca2b74e04 100644 --- a/evm_arithmetization/benches/fibonacci_25m_gas.rs +++ b/evm_arithmetization/benches/fibonacci_25m_gas.rs @@ -177,7 +177,7 @@ fn prepare_setup() -> anyhow::Result { }; Ok(GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], tries: tries_before, trie_roots_after, diff --git a/evm_arithmetization/src/all_stark.rs b/evm_arithmetization/src/all_stark.rs index f8422b300..c3b5733c1 100644 --- a/evm_arithmetization/src/all_stark.rs +++ b/evm_arithmetization/src/all_stark.rs @@ -1,4 +1,5 @@ use core::ops::Deref; +use std::iter; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; @@ -11,8 +12,8 @@ use starky::stark::Stark; use crate::arithmetic::arithmetic_stark; use crate::arithmetic::arithmetic_stark::ArithmeticStark; use crate::byte_packing::byte_packing_stark::{self, BytePackingStark}; -use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; +use crate::cpu::cpu_stark::{self, ctl_context_pruning_looked}; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::keccak::keccak_stark; use crate::keccak::keccak_stark::KeccakStark; @@ -21,8 +22,9 @@ use crate::keccak_sponge::keccak_sponge_stark; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic; use crate::logic::LogicStark; -use crate::memory::memory_stark; use crate::memory::memory_stark::MemoryStark; +use crate::memory::memory_stark::{self, ctl_context_pruning_looking}; +use crate::memory_continuation::memory_continuation_stark::{self, MemoryContinuationStark}; /// Structure containing all STARKs and the cross-table lookups. #[derive(Clone)] @@ -34,6 +36,8 @@ pub struct AllStark, const D: usize> { pub(crate) keccak_sponge_stark: KeccakSpongeStark, pub(crate) logic_stark: LogicStark, pub(crate) memory_stark: MemoryStark, + pub(crate) mem_before_stark: MemoryContinuationStark, + pub(crate) mem_after_stark: MemoryContinuationStark, pub(crate) cross_table_lookups: Vec>, } @@ -49,6 +53,8 @@ impl, const D: usize> Default for AllStark { keccak_sponge_stark: KeccakSpongeStark::default(), logic_stark: LogicStark::default(), memory_stark: MemoryStark::default(), + mem_before_stark: MemoryContinuationStark::default(), + mem_after_stark: MemoryContinuationStark::default(), cross_table_lookups: all_cross_table_lookups(), } } @@ -64,6 +70,8 @@ impl, const D: usize> AllStark { self.keccak_sponge_stark.num_lookup_helper_columns(config), self.logic_stark.num_lookup_helper_columns(config), self.memory_stark.num_lookup_helper_columns(config), + self.mem_before_stark.num_lookup_helper_columns(config), + self.mem_after_stark.num_lookup_helper_columns(config), ] } } @@ -80,6 +88,8 @@ pub enum Table { KeccakSponge = 4, Logic = 5, Memory = 6, + MemBefore = 7, + MemAfter = 8, } impl Deref for Table { @@ -88,12 +98,12 @@ impl Deref for Table { fn deref(&self) -> &Self::Target { // Hacky way to implement `Deref` for `Table` so that we don't have to // call `Table::Foo as usize`, but perhaps too ugly to be worth it. - [&0, &1, &2, &3, &4, &5, &6][*self as TableIdx] + [&0, &1, &2, &3, &4, &5, &6, &7, &8][*self as TableIdx] } } /// Number of STARK tables. -pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; +pub(crate) const NUM_TABLES: usize = Table::MemAfter as usize + 1; impl Table { /// Returns all STARK table indices. @@ -106,6 +116,8 @@ impl Table { Self::KeccakSponge, Self::Logic, Self::Memory, + Self::MemBefore, + Self::MemAfter, ] } } @@ -120,6 +132,9 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { ctl_keccak_outputs(), ctl_logic(), ctl_memory(), + ctl_mem_before(), + ctl_mem_after(), + ctl_context_pruning(), ] } @@ -287,6 +302,11 @@ fn ctl_memory() -> CrossTableLookup { byte_packing_stark::ctl_looking_memory_filter(i), ) }); + let mem_before_ops = TableWithColumns::new( + *Table::MemBefore, + memory_continuation_stark::ctl_data_memory(), + memory_continuation_stark::ctl_filter(), + ); let all_lookers = vec![ cpu_memory_code_read, cpu_push_write_ops, @@ -297,6 +317,7 @@ fn ctl_memory() -> CrossTableLookup { .chain(cpu_memory_gp_ops) .chain(keccak_sponge_reads) .chain(byte_packing_ops) + .chain(iter::once(mem_before_ops)) .collect(); let memory_looked = TableWithColumns::new( *Table::Memory, @@ -305,3 +326,45 @@ fn ctl_memory() -> CrossTableLookup { ); CrossTableLookup::new(all_lookers, memory_looked) } + +/// `CrossTableLookup` for `Cpu` to propagate stale contexts to `Memory`. +fn ctl_context_pruning() -> CrossTableLookup { + CrossTableLookup::new( + vec![ctl_context_pruning_looking()], + ctl_context_pruning_looked(), + ) +} + +/// `CrossTableLookup` for `MemBefore` table to connect it with the `Memory` +/// module. +fn ctl_mem_before() -> CrossTableLookup { + let memory_looking = TableWithColumns::new( + *Table::Memory, + memory_stark::ctl_looking_mem(), + memory_stark::ctl_filter_mem_before(), + ); + let all_lookers = vec![memory_looking]; + let mem_before_looked = TableWithColumns::new( + *Table::MemBefore, + memory_continuation_stark::ctl_data(), + memory_continuation_stark::ctl_filter(), + ); + CrossTableLookup::new(all_lookers, mem_before_looked) +} + +/// `CrossTableLookup` for `MemAfter` table to connect it with the `Memory` +/// module. +fn ctl_mem_after() -> CrossTableLookup { + let memory_looking = TableWithColumns::new( + *Table::Memory, + memory_stark::ctl_looking_mem(), + memory_stark::ctl_filter_mem_after(), + ); + let all_lookers = vec![memory_looking]; + let mem_after_looked = TableWithColumns::new( + *Table::MemAfter, + memory_continuation_stark::ctl_data(), + memory_continuation_stark::ctl_filter(), + ); + CrossTableLookup::new(all_lookers, mem_after_looked) +} diff --git a/evm_arithmetization/src/cpu/clock.rs b/evm_arithmetization/src/cpu/clock.rs index 4fa917a21..cf23f21df 100644 --- a/evm_arithmetization/src/cpu/clock.rs +++ b/evm_arithmetization/src/cpu/clock.rs @@ -1,3 +1,8 @@ +// In the context of continuations, we subdivide proofs into segments. To pass +// the necessary memory values from one segment to the next, we write those +// initial values at timestamp 0. For this reason, the clock has to be +// initialized to 1 at the start of a segment execution. + use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::hash::hash_types::RichField; @@ -12,8 +17,8 @@ pub(crate) fn eval_packed( nv: &CpuColumnsView

, yield_constr: &mut ConstraintConsumer

, ) { - // The clock is 0 at the beginning. - yield_constr.constraint_first_row(lv.clock); + // The clock is 1 at the beginning. + yield_constr.constraint_first_row(lv.clock - P::ONES); // The clock is incremented by 1 at each row. yield_constr.constraint_transition(nv.clock - lv.clock - P::ONES); } @@ -26,8 +31,9 @@ pub(crate) fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - // The clock is 0 at the beginning. - yield_constr.constraint_first_row(builder, lv.clock); + let first_clock = builder.add_const_extension(lv.clock, F::NEG_ONE); + // The clock is 1 at the beginning. + yield_constr.constraint_first_row(builder, first_clock); // The clock is incremented by 1 at each row. { let new_clock = builder.add_const_extension(lv.clock, F::ONE); diff --git a/evm_arithmetization/src/cpu/columns/general.rs b/evm_arithmetization/src/cpu/columns/general.rs index 9ce620078..5ab033b82 100644 --- a/evm_arithmetization/src/cpu/columns/general.rs +++ b/evm_arithmetization/src/cpu/columns/general.rs @@ -15,6 +15,7 @@ pub(crate) union CpuGeneralColumnsView { shift: CpuShiftView, stack: CpuStackView, push: CpuPushView, + context_pruning: CpuContextPruningView, } impl CpuGeneralColumnsView { @@ -91,6 +92,18 @@ impl CpuGeneralColumnsView { pub(crate) fn push_mut(&mut self) -> &mut CpuPushView { unsafe { &mut self.push } } + + /// View of the column for context pruning. + /// SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn context_pruning(&self) -> &CpuContextPruningView { + unsafe { &self.context_pruning } + } + + /// Mutable view of the column for context pruning. + /// SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn context_pruning_mut(&mut self) -> &mut CpuContextPruningView { + unsafe { &mut self.context_pruning } + } } impl PartialEq for CpuGeneralColumnsView { @@ -203,6 +216,17 @@ pub(crate) struct CpuPushView { /// Reserve the unused columns. _padding_columns: [T; NUM_SHARED_COLUMNS - 1], } + +/// View of the first `CpuGeneralColumn` storing a flag for context pruning. +#[derive(Copy, Clone)] +pub(crate) struct CpuContextPruningView { + /// The flag is 1 if the OP flag `context_op` is set, the operation is + /// `SET_CONTEXT` and `new_ctx < old_ctx`, and 0 otherwise. + pub(crate) pruning_flag: T, + /// Reserve the unused columns. + _padding_columns: [T; NUM_SHARED_COLUMNS - 1], +} + /// The number of columns shared by all views of [`CpuGeneralColumnsView`]. /// This is defined in terms of the largest view in order to determine the /// number of padding columns to add to each field without creating a cycle @@ -219,3 +243,4 @@ const_assert!(size_of::>() == NUM_SHARED_COLUMNS); const_assert!(size_of::>() == NUM_SHARED_COLUMNS); const_assert!(size_of::>() == NUM_SHARED_COLUMNS); const_assert!(size_of::>() == NUM_SHARED_COLUMNS); +const_assert!(size_of::>() == NUM_SHARED_COLUMNS); diff --git a/evm_arithmetization/src/cpu/contextops.rs b/evm_arithmetization/src/cpu/contextops.rs index 6a7abed89..9fdd92f29 100644 --- a/evm_arithmetization/src/cpu/contextops.rs +++ b/evm_arithmetization/src/cpu/contextops.rs @@ -89,6 +89,9 @@ fn eval_packed_get( yield_constr.constraint(filter * limb); } + // We cannot prune a context in GET_CONTEXT. + yield_constr.constraint(filter * lv.general.context_pruning().pruning_flag); + // Constrain new stack length. yield_constr.constraint(filter * (nv.stack_len - (lv.stack_len + P::ONES))); @@ -121,6 +124,10 @@ fn eval_ext_circuit_get, const D: usize>( yield_constr.constraint(builder, constr); } + // We cannot prune a context in GET_CONTEXT. + let constr = builder.mul_extension(filter, lv.general.context_pruning().pruning_flag); + yield_constr.constraint(builder, constr); + // Constrain new stack length. { let new_len = builder.add_const_extension(lv.stack_len, F::ONE); @@ -148,12 +155,24 @@ fn eval_packed_set( // The next row's context is read from stack_top. yield_constr.constraint(filter * (stack_top[2] - nv.context)); - for (_, &limb) in stack_top.iter().enumerate().filter(|(i, _)| *i != 2) { + // The stack top contains the new context in the third limb, and a flag + // indicating whether the old context should be pruned in the first limb. The + // other limbs should be 0. + for (_, &limb) in stack_top[1..].iter().enumerate().filter(|(i, _)| *i != 1) { yield_constr.constraint(filter * limb); } - // The old SP is decremented (since the new context was popped) and stored in - // memory. The new SP is loaded from memory. + // Check that the pruning flag is binary. + yield_constr.constraint( + lv.op.context_op + * lv.general.context_pruning().pruning_flag + * (lv.general.context_pruning().pruning_flag - P::Scalar::ONES), + ); + // stack_top[0] contains a flag indicating whether the context should be pruned. + yield_constr.constraint(filter * (lv.general.context_pruning().pruning_flag - stack_top[0])); + + // The old SP is decremented (since the new context was popped) + // and stored in memory. The new SP is loaded from memory. // This is all done with CTLs: nothing is constrained here. // Constrain stack_inv_aux_2. @@ -197,11 +216,25 @@ fn eval_ext_circuit_set, const D: usize>( let constr = builder.mul_extension(filter, diff); yield_constr.constraint(builder, constr); } - for (_, &limb) in stack_top.iter().enumerate().filter(|(i, _)| *i != 2) { + for (_, &limb) in stack_top[1..].iter().enumerate().filter(|(i, _)| *i != 1) { let constr = builder.mul_extension(filter, limb); yield_constr.constraint(builder, constr); } + // Check that the pruning flag is binary. + let diff = builder.mul_sub_extension( + lv.general.context_pruning().pruning_flag, + lv.general.context_pruning().pruning_flag, + lv.general.context_pruning().pruning_flag, + ); + let constr = builder.mul_extension(lv.op.context_op, diff); + yield_constr.constraint(builder, constr); + + // stack_top[0] contains a flag indicating whether the context should be pruned. + let diff = builder.sub_extension(lv.general.context_pruning().pruning_flag, stack_top[0]); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + // The old SP is decremented (since the new context was popped) and stored in // memory. The new SP is loaded from memory. // This is all done with CTLs: nothing is constrained here. diff --git a/evm_arithmetization/src/cpu/control_flow.rs b/evm_arithmetization/src/cpu/control_flow.rs index ba4e71890..b32ee0ae7 100644 --- a/evm_arithmetization/src/cpu/control_flow.rs +++ b/evm_arithmetization/src/cpu/control_flow.rs @@ -29,15 +29,15 @@ const NATIVE_INSTRUCTIONS: [usize; 12] = [ // not exceptions (also jump) ]; -/// Returns `halt`'s program counter. +/// Returns `halt_final`'s program counter. pub(crate) fn get_halt_pc() -> F { - let halt_pc = KERNEL.global_labels["halt"]; + let halt_pc = KERNEL.global_labels["halt_final"]; F::from_canonical_usize(halt_pc) } -/// Returns `main`'s program counter. +/// Returns `init`'s program counter. All segments should start at that PC. pub(crate) fn get_start_pc() -> F { - let start_pc = KERNEL.global_labels["main"]; + let start_pc = KERNEL.global_labels["init"]; F::from_canonical_usize(start_pc) } diff --git a/evm_arithmetization/src/cpu/cpu_stark.rs b/evm_arithmetization/src/cpu/cpu_stark.rs index 21c61f1e6..daafa164d 100644 --- a/evm_arithmetization/src/cpu/cpu_stark.rs +++ b/evm_arithmetization/src/cpu/cpu_stark.rs @@ -41,9 +41,13 @@ pub(crate) fn ctl_data_keccak_sponge() -> Vec> { let virt = Column::single(virt); let len = Column::single(COL_MAP.mem_channels[1].value[0]); + // Since we start the clock at 1, we have that: + // timestamp = (clock - 1) * num_channels + 1. let num_channels = F::from_canonical_usize(NUM_CHANNELS); - let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); - + let timestamp = Column::linear_combination_with_constant( + [(COL_MAP.clock, num_channels)], + F::ONE - num_channels, + ); let mut cols = vec![context, segment, virt, len, timestamp]; cols.extend(Column::singles_next_row(COL_MAP.mem_channels[0].value)); cols @@ -127,6 +131,21 @@ pub(crate) fn ctl_arithmetic_base_rows() -> TableWithColumns { ) } +/// Returns a column containing stale contexts. +pub(crate) fn ctl_context_pruning_looked() -> TableWithColumns { + TableWithColumns::new( + *Table::Cpu, + vec![Column::single(COL_MAP.context)], + Filter::new( + vec![( + Column::single(COL_MAP.op.context_op), + Column::single(COL_MAP.general.context_pruning().pruning_flag), + )], + vec![], + ), + ) +} + /// Creates the vector of `Columns` corresponding to the contents of General /// Purpose channels when calling byte packing. We use `ctl_data_keccak_sponge` /// because the `Columns` are the same as the ones computed for @@ -175,8 +194,13 @@ pub(crate) fn ctl_data_byte_unpacking() -> Vec> { ); res.push(len); + // Since we start the clock at 1, we have that: + // timestamp = (clock - 1) * num_channels + 1. let num_channels = F::from_canonical_usize(NUM_CHANNELS); - let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + let timestamp = Column::linear_combination_with_constant( + [(COL_MAP.clock, num_channels)], + F::ONE - num_channels, + ); res.push(timestamp); let val = Column::singles(COL_MAP.mem_channels[1].value); @@ -219,8 +243,13 @@ pub(crate) fn ctl_data_jumptable_read() -> Vec> { let len = Column::constant(F::from_canonical_usize(3)); res.push(len); + // Since we start the clock at 1, we have that: + // timestamp = (clock - 1) * num_channels + 1. let num_channels = F::from_canonical_usize(NUM_CHANNELS); - let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + let timestamp = Column::linear_combination_with_constant( + [(COL_MAP.clock, num_channels)], + F::ONE - num_channels, + ); res.push(timestamp); res.extend(val); @@ -249,8 +278,13 @@ pub(crate) fn ctl_data_byte_packing_push() -> Vec> { // 1`. let len = Column::le_bits_with_constant(&COL_MAP.opcode_bits[0..5], F::ONE); + // Since we start the clock at 1, we have that: + // timestamp = (clock - 1) * num_channels + 1. let num_channels = F::from_canonical_usize(NUM_CHANNELS); - let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + let timestamp = Column::linear_combination_with_constant( + [(COL_MAP.clock, num_channels)], + F::ONE - num_channels, + ); let mut res = vec![is_read, context, segment, virt, len, timestamp]; res.extend(val); @@ -291,7 +325,7 @@ pub(crate) const fn get_addr(lv: &CpuColumnsView, mem_channel: usize /// Make the time/channel column for memory lookups. fn mem_time_and_channel(channel: usize) -> Column { let scalar = F::from_canonical_usize(NUM_CHANNELS); - let addend = F::from_canonical_usize(channel); + let addend = F::from_canonical_usize(channel) - scalar + F::ONE; Column::linear_combination_with_constant([(COL_MAP.clock, scalar)], addend) } diff --git a/evm_arithmetization/src/cpu/kernel/aggregator.rs b/evm_arithmetization/src/cpu/kernel/aggregator.rs index 211d1d69f..7069c14f6 100644 --- a/evm_arithmetization/src/cpu/kernel/aggregator.rs +++ b/evm_arithmetization/src/cpu/kernel/aggregator.rs @@ -9,7 +9,7 @@ use super::assembler::{assemble, Kernel}; use crate::cpu::kernel::constants::evm_constants; use crate::cpu::kernel::parser::parse; -pub const NUMBER_KERNEL_FILES: usize = 156; +pub const NUMBER_KERNEL_FILES: usize = 159; pub static KERNEL_FILES: [&str; NUMBER_KERNEL_FILES] = [ "global jumped_to_0: PANIC", @@ -129,6 +129,9 @@ pub static KERNEL_FILES: [&str; NUMBER_KERNEL_FILES] = [ include_str!("asm/mpt/insert/insert_extension.asm"), include_str!("asm/mpt/insert/insert_leaf.asm"), include_str!("asm/mpt/insert/insert_trie_specific.asm"), + include_str!("asm/mpt/linked_list/linked_list.asm"), + include_str!("asm/mpt/linked_list/initial_tries.asm"), + include_str!("asm/mpt/linked_list/final_tries.asm"), include_str!("asm/mpt/read.asm"), include_str!("asm/mpt/storage/storage_read.asm"), include_str!("asm/mpt/storage/storage_write.asm"), diff --git a/evm_arithmetization/src/cpu/kernel/asm/account_code.asm b/evm_arithmetization/src/cpu/kernel/asm/account_code.asm index 2654bedc7..62b5b968b 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/account_code.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/account_code.asm @@ -70,16 +70,20 @@ global sys_extcodesize: SWAP1 // stack: address, kexit_info %extcodesize + // stack: code_size, codesize_ctx, kexit_info + SWAP1 + // stack: codesize_ctx, code_size, kexit_info + %prune_context // stack: code_size, kexit_info SWAP1 EXIT_KERNEL +// Pre stack: address, retdest +// Post stack: code_size, codesize_ctx global extcodesize: // stack: address, retdest %next_context_id - // stack: codesize_ctx, address, retdest - SWAP1 - // stack: address, codesize_ctx, retdest + %stack(codesize_ctx, address, retdest) -> (address, codesize_ctx, retdest, codesize_ctx) %jump(load_code) // Loads the code at `address` into memory, in the code segment of the given context, starting at offset 0. diff --git a/evm_arithmetization/src/cpu/kernel/asm/beacon_roots.asm b/evm_arithmetization/src/cpu/kernel/asm/beacon_roots.asm index 125c9d58b..6fba36fec 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/beacon_roots.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/beacon_roots.asm @@ -12,84 +12,48 @@ global set_beacon_root: // stack: timestamp, 8191, timestamp, retdest MOD // stack: timestamp_idx, timestamp, retdest - PUSH write_beacon_roots_to_storage + %slot_to_storage_key + // stack: timestamp_slot_key, timestamp, retdest + PUSH @BEACON_ROOTS_CONTRACT_STATE_KEY + %addr_to_state_key %parent_beacon_block_root - // stack: calldata, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - DUP3 + // stack: calldata, state_key, timestamp_slot_key, timestamp, retdest + PUSH @HISTORY_BUFFER_LENGTH + DUP5 + MOD + // stack: timestamp_idx, calldata, state_key, timestamp_slot_key, timestamp, retdest %add_const(@HISTORY_BUFFER_LENGTH) - // stack: root_idx, calldata, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - - // If the calldata is zero, delete the slot from the storage trie. - DUP2 ISZERO %jumpi(delete_root_idx_slot) - -write_beacon_roots_to_storage: - // stack: slot, value, retdest - // First we write the value to MPT data, and get a pointer to it. - %get_trie_data_size - // stack: value_ptr, slot, value, retdest - SWAP2 - // stack: value, slot, value_ptr, retdest - %append_to_trie_data - // stack: slot, value_ptr, retdest - - // Next, call mpt_insert on the current account's storage root. - %stack (slot, value_ptr) -> (slot, value_ptr, after_beacon_roots_storage_insert) + // stack: root_idx, calldata, state_key, timestamp_slot_key, timestamp, retdest %slot_to_storage_key - // stack: storage_key, value_ptr, after_beacon_roots_storage_insert, retdest - PUSH 64 // storage_key has 64 nibbles - %get_storage_trie(@BEACON_ROOTS_CONTRACT_STATE_KEY) - // stack: storage_root_ptr, 64, storage_key, value_ptr, after_beacon_roots_storage_insert, retdest - %jump(mpt_insert) - -after_beacon_roots_storage_insert: - // stack: new_storage_root_ptr, retdest - %get_account_data(@BEACON_ROOTS_CONTRACT_STATE_KEY) - // stack: account_ptr, new_storage_root_ptr, retdest - - // Update the copied account with our new storage root pointer. - %add_const(2) - // stack: account_storage_root_ptr_ptr, new_storage_root_ptr, retdest - %mstore_trie_data + // stack: root_slot_key, calldata, state_key, timestamp_slot_key, timestamp, retdest + DUP3 + // stack: state_key, root_slot_key, calldata, state_key, timestamp_slot_key, timestamp, retdest + DUP3 ISZERO %jumpi(delete_root_idx_slot) + // stack: state_key, root_slot_key, calldata, state_key, timestamp_slot_key, timestamp, retdest + %insert_slot_with_value_from_keys + // stack: state_key, timestamp_slot_key, timestamp, retdest + %insert_slot_with_value_from_keys + // stack: retdest JUMP delete_root_idx_slot: - // stack: root_idx, 0, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - PUSH after_root_idx_slot_delete - SWAP2 POP - // stack: root_idx, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - %slot_to_storage_key - // stack: storage_key, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - PUSH 64 // storage_key has 64 nibbles - %get_storage_trie(@BEACON_ROOTS_CONTRACT_STATE_KEY) - // stack: storage_root_ptr, 64, storage_key, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - - // If the slot is empty (i.e. ptr defaulting to 0), skip the deletion. - DUP1 ISZERO %jumpi(skip_empty_slot) - - // stack: storage_root_ptr, 64, storage_key, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - %stack (storage_root_ptr, nibbles, storage_key) -> (storage_root_ptr, nibbles, storage_key, checkpoint_delete_root_idx, storage_root_ptr, nibbles, storage_key) - %jump(mpt_read) -checkpoint_delete_root_idx: - // stack: value_ptr, storage_root_ptr, 64, storage_key, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - // If the the storage key is not found (i.e. ptr defaulting to 0), skip the deletion. - ISZERO %jumpi(skip_empty_slot) - - // stack: storage_root_ptr, 64, storage_key, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - %jump(mpt_delete) - -after_root_idx_slot_delete: - // stack: new_storage_root_ptr, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - %get_account_data(@BEACON_ROOTS_CONTRACT_STATE_KEY) - // stack: account_ptr, new_storage_root_ptr, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - - // Update the copied account with our new storage root pointer. - %add_const(2) - // stack: account_storage_root_ptr_ptr, new_storage_root_ptr, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - %mstore_trie_data - // stack: write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest + // stack: state_key, root_slot_key, 0, state_key, timestamp_slot_key, timestamp, retdest + DUP3 DUP3 DUP3 + %search_slot + // stack: slot_exists, state_key, root_slot_key, 0, state_key, timestamp_slot_key, timestamp, retdest + %jumpi(remove_root_idx_slot) + // stack: state_key, root_slot_key, 0, state_key, timestamp_slot_key, timestamp, retdest + %pop3 + // stack: state_key, timestamp_slot_key, timestamp, retdest + %insert_slot_with_value_from_keys + // stack: retdest JUMP -skip_empty_slot: - // stack: 0, 64, storage_key, after_root_idx_slot_delete, write_beacon_roots_to_storage, timestamp_idx, timestamp, retdest - %pop4 +remove_root_idx_slot: + // stack: state_key, root_slot_key, 0, state_key, timestamp_slot_key, timestamp, retdest + %stack(state_key, storage_key, zero) -> (storage_key, state_key) + %remove_slot + // stack: state_key, timestamp_slot_key, timestamp, retdest + %insert_slot_with_value_from_keys + // stack: retdest JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm index 73b9401ca..8a487fed0 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm @@ -12,11 +12,16 @@ // Initialize SEGMENT_ACCESSED_ADDRESSES global init_access_lists: // stack: (empty) + + // Reset access lists data. + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_ACCESS_LIST_DATA_COST) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_ACCESS_LIST_RLP_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_ACCESS_LIST_RLP_START) + // Store @U256_MAX at the beginning of the segment PUSH @SEGMENT_ACCESSED_ADDRESSES // ctx == virt == 0 DUP1 - PUSH @U256_MAX - MSTORE_GENERAL + %mstore_u256_max // Store @SEGMENT_ACCESSED_ADDRESSES at address 1 %increment DUP1 @@ -32,8 +37,7 @@ global init_access_lists: // Store @U256_MAX at the beginning of the segment PUSH @SEGMENT_ACCESSED_STORAGE_KEYS // ctx == virt == 0 DUP1 - PUSH @U256_MAX - MSTORE_GENERAL + %mstore_u256_max // Store @SEGMENT_ACCESSED_STORAGE_KEYS at address 3 %add_const(3) DUP1 @@ -195,8 +199,7 @@ global remove_accessed_addresses: MLOAD_GENERAL // stack: next_next_ptr, next_next_ptr_ptr, next_ptr_ptr, addr, retdest SWAP1 - PUSH @U256_MAX - MSTORE_GENERAL + %mstore_u256_max // stack: next_next_ptr, next_ptr_ptr, addr, retdest MSTORE_GENERAL POP @@ -379,9 +382,8 @@ global remove_accessed_storage_keys: MLOAD_GENERAL // stack: next_next_ptr, next_next_ptr_ptr, next_ptr_ptr, addr, key, retdest SWAP1 - PUSH @U256_MAX - MSTORE_GENERAL + %mstore_u256_max // stack: next_next_ptr, next_ptr_ptr, addr, key, retdest MSTORE_GENERAL %pop2 - JUMP \ No newline at end of file + JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/create_receipt.asm b/evm_arithmetization/src/cpu/kernel/asm/core/create_receipt.asm index 742c4784c..bd378abde 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/create_receipt.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/create_receipt.asm @@ -208,9 +208,18 @@ process_receipt_after_write: %mpt_insert_receipt_trie // stack: new_cum_gas, txn_nb, num_nibbles, retdest - // We don't need to reset the bloom filter segment as we only process a single transaction. - // TODO: Revert in case we add back support for multi-txn proofs. - + // Now, we set the Bloom filter back to 0. We proceed by chunks of 32 bytes. + PUSH @SEGMENT_TXN_BLOOM // ctx == offset == 0 + %rep 8 + // stack: addr, new_cum_gas, txn_nb, num_nibbles, retdest + PUSH 0 // we will fill the memory segment with zeroes + SWAP1 + // stack: addr, 0, new_cum_gas, txn_nb, num_nibbles, retdest + MSTORE_32BYTES_32 + // stack: new_addr, new_cum_gas, txn_nb, num_nibbles, retdest + %endrep + POP + // stack: new_cum_gas, txn_nb, num_nibbles, retdest %stack (new_cum_gas, txn_nb, num_nibbles, retdest) -> (retdest, new_cum_gas) JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm b/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm index 6e29af030..a2a2742ec 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm @@ -20,8 +20,11 @@ global exception_jumptable: // exception 5: stack overflow JUMPTABLE exc_stack_overflow - // exceptions 6 and 7: unused - JUMPTABLE panic + // exception 6: end of segmented proof. + // This reuses the exceptions logic but is part of any valid segment execution. + JUMPTABLE exc_stop + + // exceptions 7: unused JUMPTABLE panic @@ -165,6 +168,92 @@ global exc_stack_overflow_check_stack_length: %jumpi(fault_exception) PANIC +global exc_stop: + // Here, we need to check that the final registers have the correct value. + // stack: trap_info + PUSH @FINAL_REGISTERS_ADDR + // stack: addr_registers, trap_info + PUSH 3 + // If the current `stack_len` is 3, then the stack was empty before the exception and there's no stack top. + %stack_length + SUB + // First, check the stack length. + // stack: stack_len-3 = stack_len_before_exc, addr_registers, trap_info + DUP2 %add_const(2) + MLOAD_GENERAL + // stack: stored_stack_length, stack_len_before_exc, addr_registers, trap_info + DUP2 %assert_eq + + // Now, check that we end up with the correct stack_top. + // stack: stack_len_before_exc, addr_registers, trap_info + DUP1 PUSH 0 LT + // stack: 0 < stack_len_before_exc, stack_len_before_exc, addr_registers, trap_info + PUSH 1 DUP3 SUB + // stack: stack_len_before_exc - 1, 0 < stack_len_before_exc, stack_len_before_exc, addr_registers, trap_info + MUL + // If the previous stack length is 0, we load the first value in the stack segment: + // we do not need to constrain the value in that case, so this is just to avoid a jumpi. + // Not having a `jumpi` provides a constant number of operations, which is better for segmentation. + // stack: (stack_len_before_exc - 1) * (stack_len_before_exc != 0), stack_len_before_exc, addr_registers, trap_info + PUSH @SEGMENT_STACK + GET_CONTEXT + %build_address + // stack: stack_top_before_exc_addr, stack_len_before_exc, addr_registers, trap_info + MLOAD_GENERAL + // stack: stack_top_before_exc, stack_len_before_exc, addr_registers, trap_info + DUP3 %add_const(3) + MLOAD_GENERAL + // stack: stored_stack_top, stack_top_before_exc, stack_len_before_exc, addr_registers, trap_info + SUB MUL + // stack: (stored_stack_top - stack_top_before_exc) * stack_len_before_exc, addr_registers, trap_info + %assert_zero + + // Check the program counter. + // stack: addr_registers, trap_info + DUP2 %as_u32 + // stack: program_counter, addr_registers, trap_info + DUP2 + MLOAD_GENERAL + // stack: public_pc, program_counter, addr_registers, trap_info + %assert_eq + + // Check is_kernel_mode. + // stack: addr_registers, trap_info + DUP2 %shr_const(32) + %as_u32 + // stack: is_kernel_mode, addr_registers, trap_info + DUP2 %increment + MLOAD_GENERAL + %assert_eq + + // Check the gas used. + // stack: addr_registers, trap_info + SWAP1 %shr_const(192) + %as_u32 + // stack: gas_used, addr_registers + DUP2 %add_const(5) + MLOAD_GENERAL + %assert_eq + + // Check the context. + // stack: addr_registers + %add_const(4) + MLOAD_GENERAL + %shl_const(64) + // stack: stored_context + GET_CONTEXT + %assert_eq + // stack: (empty) + // The following two instructions are needed to not have failing constraints. + // `ISZERO` pops and pushes, which means that there is no need to read the next top of the stack after it. + // If we don't have it, there is a read of the top of the stack in padding rows, which have all channels disabled, + // thus making the constraints fail. + PUSH 1 + ISZERO + +global halt_final: + // Just for halting. Nothing is executed when this is reached. + PANIC // Given the exception trap info, load the opcode that caused the exception %macro opcode_from_exp_trap_info diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/process_txn.asm b/evm_arithmetization/src/cpu/kernel/asm/core/process_txn.asm index c6d10eb40..31540fbad 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/process_txn.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/process_txn.asm @@ -450,22 +450,22 @@ global invalid_txn: POP %mload_txn_field(@TXN_FIELD_GAS_LIMIT) PUSH 0 - %jump(txn_after) + %jump(txn_loop_after) global invalid_txn_1: %pop2 %mload_txn_field(@TXN_FIELD_GAS_LIMIT) PUSH 0 - %jump(txn_after) + %jump(txn_loop_after) global invalid_txn_2: %pop3 %mload_txn_field(@TXN_FIELD_GAS_LIMIT) PUSH 0 - %jump(txn_after) + %jump(txn_loop_after) global invalid_txn_3: %pop4 %mload_txn_field(@TXN_FIELD_GAS_LIMIT) PUSH 0 - %jump(txn_after) + %jump(txn_loop_after) diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/terminate.asm b/evm_arithmetization/src/cpu/kernel/asm/core/terminate.asm index d1a366ede..1d406097c 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/terminate.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/terminate.asm @@ -253,7 +253,7 @@ global terminate_common: // Go back to the parent context. %mload_context_metadata(@CTX_METADATA_PARENT_CONTEXT) - SET_CONTEXT + %set_and_prune_ctx %decrement_call_depth // stack: (empty) diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/util.asm b/evm_arithmetization/src/cpu/kernel/asm/core/util.asm index 26478d0da..053c6159c 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/util.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/util.asm @@ -97,3 +97,16 @@ %mload_context_metadata(@CTX_METADATA_STACK_SIZE) // stack: stack_length %endmacro + +%macro set_and_prune_ctx + // stack: context + PUSH 1 ADD + SET_CONTEXT + // stack: (empty) +%endmacro + +%macro mstore_u256_max + // stack: addr + PUSH @U256_MAX + MSTORE_GENERAL +%endmacro \ No newline at end of file diff --git a/evm_arithmetization/src/cpu/kernel/asm/global_exit_root.asm b/evm_arithmetization/src/cpu/kernel/asm/global_exit_root.asm index ffcc377a5..94c81fdf4 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/global_exit_root.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/global_exit_root.asm @@ -7,77 +7,42 @@ global set_global_exit_roots: // stack: (empty) - PUSH start_txn + PUSH txn_loop // stack: retdest + PUSH @GLOBAL_EXIT_ROOT_MANAGER_L2_STATE_KEY + %addr_to_state_key PROVER_INPUT(ger) - // stack: num_ger, retdest + // stack: num_ger, state_key, retdest PUSH 0 ger_loop: - // stack: i, num_ger, retdest + // stack: i, num_ger, state_key, retdest DUP2 DUP2 EQ %jumpi(ger_loop_end) PROVER_INPUT(ger) - // stack: timestamp, i, num_ger, retdest + // stack: timestamp, i, num_ger, state_key, retdest PUSH @GLOBAL_EXIT_ROOT_STORAGE_POS PROVER_INPUT(ger) - // stack: root, GLOBAL_EXIT_ROOT_STORAGE_POS, timestamp, i, num_ger, retdest + // stack: root, GLOBAL_EXIT_ROOT_STORAGE_POS, timestamp, i, num_ger, state_key, retdest PUSH @SEGMENT_KERNEL_GENERAL - // stack: addr, root, GLOBAL_EXIT_ROOT_STORAGE_POS, timestamp, i, num_ger, retdest + // stack: addr, root, GLOBAL_EXIT_ROOT_STORAGE_POS, timestamp, i, num_ger, state_key, retdest MSTORE_32BYTES_32 - // stack: addr, GLOBAL_EXIT_ROOT_STORAGE_POS, timestamp, i, num_ger, retdest + // stack: addr, GLOBAL_EXIT_ROOT_STORAGE_POS, timestamp, i, num_ger, state_key, retdest MSTORE_32BYTES_32 - // stack: addr, timestamp, i, num_ger, retdest + // stack: addr, timestamp, i, num_ger, state_key, retdest POP - // stack: timestamp, i, num_ger, retdest + // stack: timestamp, i, num_ger, state_key, retdest PUSH 64 PUSH @SEGMENT_KERNEL_GENERAL - // stack: addr, len, timestamp, i, num_ger, retdest + // stack: addr, len, timestamp, i, num_ger, state_key, retdest KECCAK_GENERAL - // stack: slot, timestamp, i, num_ger, retdest - -write_timestamp_to_storage: - // stack: slot, timestamp, i, num_ger, retdest - // First we write the value to MPT data, and get a pointer to it. - %get_trie_data_size - // stack: value_ptr, slot, timestamp, i, num_ger, retdest - SWAP2 - // stack: timestamp, slot, value_ptr, i, num_ger, retdest - %append_to_trie_data - // stack: slot, value_ptr, i, num_ger, retdest - - // Next, call mpt_insert on the current account's storage root. - %stack (slot, value_ptr) -> (slot, value_ptr, after_timestamp_storage_insert) + // stack: slot, timestamp, i, num_ger, state_key, retdest %slot_to_storage_key - // stack: storage_key, value_ptr, after_timestamp_storage_insert - PUSH 64 // storage_key has 64 nibbles - %get_storage_trie(@GLOBAL_EXIT_ROOT_MANAGER_L2_STATE_KEY) - // stack: storage_root_ptr, 64, storage_key, value_ptr, after_timestamp_storage_insert - %stack (storage_root_ptr, num_nibbles, storage_key) -> (storage_root_ptr, num_nibbles, storage_key, after_read, storage_root_ptr, num_nibbles, storage_key) - %jump(mpt_read) -after_read: - // If the current value is non-zero, do nothing. - // stack: current_value_ptr, storage_root_ptr, 64, storage_key, value_ptr, after_timestamp_storage_insert - %mload_trie_data %jumpi(do_nothing) - // stack: storage_root_ptr, 64, storage_key, value_ptr, after_timestamp_storage_insert - %jump(mpt_insert) - -after_timestamp_storage_insert: - // stack: new_storage_root_ptr, i, num_ger, retdest - %get_account_data(@GLOBAL_EXIT_ROOT_MANAGER_L2_STATE_KEY) - // stack: account_ptr, new_storage_root_ptr - // Update the copied account with our new storage root pointer. - %add_const(2) - // stack: account_storage_root_ptr_ptr, new_storage_root_ptr - %mstore_trie_data - - // stack: i, num_ger, retdest + // stack: slot_key, timestamp, i, num_ger, state_key, retdest + DUP5 + // stack: state_key, slot_key, timestamp, i, num_ger, state_key, retdest + %insert_slot_with_value_from_keys + // stack: i, num_ger, state_key, retdest %increment %jump(ger_loop) ger_loop_end: - // stack: i, num_ger, retdest - %pop2 JUMP - -do_nothing: - // stack: storage_root_ptr, 64, storage_key, value_ptr, after_timestamp_storage_insert, i, num_ger, retdest - %pop7 - // stack: retdest - JUMP + // stack: i, num_ger, state_key, retdest + %pop3 JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/journal/account_destroyed.asm b/evm_arithmetization/src/cpu/kernel/asm/journal/account_destroyed.asm index 3806a891d..d62f3d422 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/journal/account_destroyed.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/journal/account_destroyed.asm @@ -16,7 +16,10 @@ revert_account_destroyed_contd: SWAP1 // Remove `prev_balance` from `target`'s balance. // stack: target, address, prev_balance, retdest - %mpt_read_state_trie + %read_accounts_linked_list + // stack: target_payload_ptr, address, prev_balance, retdest + DUP1 + %assert_nonzero %add_const(1) // stack: target_balance_ptr, address, prev_balance, retdest DUP3 @@ -25,8 +28,11 @@ revert_account_destroyed_contd: SUB SWAP1 %mstore_trie_data // Set `address`'s balance to `prev_balance`. // stack: address, prev_balance, retdest - %mpt_read_state_trie - %add_const(1) + %read_accounts_linked_list + // stack: account_payload_ptr, prev_balance, retdest + DUP1 + %assert_nonzero + %increment + // stack: account_balance_payload_ptr, prev_balance, retdest %mstore_trie_data JUMP - diff --git a/evm_arithmetization/src/cpu/kernel/asm/journal/code_change.asm b/evm_arithmetization/src/cpu/kernel/asm/journal/code_change.asm index 5bb637c72..0fc33f9dd 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/journal/code_change.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/journal/code_change.asm @@ -9,7 +9,9 @@ global revert_code_change: POP %journal_load_2 // stack: address, prev_codehash, retdest - %mpt_read_state_trie + %read_accounts_linked_list + // stack: account_ptr, prev_codehash, retdest + DUP1 %assert_nonzero // stack: account_ptr, prev_codehash, retdest %add_const(3) // stack: codehash_ptr, prev_codehash, retdest diff --git a/evm_arithmetization/src/cpu/kernel/asm/journal/nonce_change.asm b/evm_arithmetization/src/cpu/kernel/asm/journal/nonce_change.asm index 3ab8f1367..0c1198e52 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/journal/nonce_change.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/journal/nonce_change.asm @@ -9,9 +9,10 @@ global revert_nonce_change: POP %journal_load_2 // stack: address, prev_nonce, retdest - %mpt_read_state_trie - // stack: nonce_ptr, prev_nonce retdest + %read_accounts_linked_list + // stack: payload_ptr, prev_nonce, retdest + DUP1 %assert_nonzero + // stack: nonce_ptr, prev_nonce, retdest %mstore_trie_data // stack: retdest JUMP - diff --git a/evm_arithmetization/src/cpu/kernel/asm/journal/storage_change.asm b/evm_arithmetization/src/cpu/kernel/asm/journal/storage_change.asm index 752674d1e..695975c1f 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/journal/storage_change.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/journal/storage_change.asm @@ -11,47 +11,14 @@ global revert_storage_change: // stack: address, slot, prev_value, retdest DUP3 ISZERO %jumpi(delete) // stack: address, slot, prev_value, retdest - SWAP1 %slot_to_storage_key - // stack: storage_key, address, prev_value, retdest - PUSH 64 // storage_key has 64 nibbles - // stack: 64, storage_key, address, prev_value, retdest - DUP3 %mpt_read_state_trie - DUP1 ISZERO %jumpi(panic) - // stack: account_ptr, 64, storage_key, address, prev_value, retdest - %add_const(2) - // stack: storage_root_ptr_ptr, 64, storage_key, address, prev_value, retdest - %mload_trie_data - %get_trie_data_size - DUP6 %append_to_trie_data - %stack (prev_value_ptr, storage_root_ptr, num_nibbles, storage_key, address, prev_value, retdest) -> - (storage_root_ptr, num_nibbles, storage_key, prev_value_ptr, new_storage_root, address, retdest) - %jump(mpt_insert) + %insert_slot_with_value + JUMP delete: // stack: address, slot, prev_value, retdest SWAP2 POP - %stack (slot, address, retdest) -> (slot, new_storage_root, address, retdest) + // stack: slot, address, retdest %slot_to_storage_key - // stack: storage_key, new_storage_root, address, retdest - PUSH 64 // storage_key has 64 nibbles - // stack: 64, storage_key, new_storage_root, address, retdest - DUP4 %mpt_read_state_trie - DUP1 ISZERO %jumpi(panic) - // stack: account_ptr, 64, storage_key, new_storage_root, address, retdest - %add_const(2) - // stack: storage_root_ptr_ptr, 64, storage_key, new_storage_root, address, retdest - %mload_trie_data - // stack: storage_root_ptr, 64, storage_key, new_storage_root, address, retdest - %jump(mpt_delete) - -new_storage_root: - // stack: new_storage_root_ptr, address, retdest - DUP2 %mpt_read_state_trie - // stack: account_ptr, new_storage_root_ptr, address, retdest - - // Update account with our new storage root pointer. - %add_const(2) - // stack: account_storage_root_ptr_ptr, new_storage_root_ptr, address, retdest - %mstore_trie_data - // stack: address, retdest - POP JUMP + SWAP1 %addr_to_state_key + // stack: addr_key, slot_key, retdest + %jump(remove_slot) diff --git a/evm_arithmetization/src/cpu/kernel/asm/main.asm b/evm_arithmetization/src/cpu/kernel/asm/main.asm index 5d6d96799..f2663347a 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/main.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/main.asm @@ -1,17 +1,66 @@ -global main: - // First, hash the kernel code - %mload_global_metadata(@GLOBAL_METADATA_KERNEL_LEN) - PUSH 0 - // stack: addr, len - KECCAK_GENERAL - // stack: hash - %mload_global_metadata(@GLOBAL_METADATA_KERNEL_HASH) - // stack: expected_hash, hash +global init: + PUSH @SEGMENT_REGISTERS_STATES + // stack: addr_registers + // First, set the registers correctly and verify their values. + PUSH 2 + %stack_length SUB + // stack: prev_stack_len, addr_registers + // First, check the stack length. + DUP1 + DUP3 %add_const(2) + // stack: stack_len_addr, prev_stack_len, prev_stack_len, addr_registers + MLOAD_GENERAL + %assert_eq + + // Now, we want to check the stack top. For this, we load + // the value at offset (prev_stack_len - 1) * (stack_len > 0), + // since we do not constrain the stack top when the stack is empty. + // stack: prev_stack_len, addr_registers + DUP1 PUSH 0 LT + // stack: 0 < prev_stack_len, prev_stack_len, addr_registers + PUSH 1 DUP3 SUB + // stack: prev_stack_len - 1, 0 < prev_stack_len, prev_stack_len, addr_registers + MUL + PUSH @SEGMENT_STACK + GET_CONTEXT + %build_address + // stack: stack_top_addr, prev_stack_len, addr_registers + MLOAD_GENERAL + + // stack: stack_top, prev_stack_len, addr_registers + DUP3 %add_const(3) + MLOAD_GENERAL + // stack: pv_stack_top, stack_top, prev_stack_len, addr_registers + SUB + // If the stack length was previously 0, we do not need to check the previous stack top. + MUL + // stack: (pv_stack_top - stack_top) * prev_stack_len, addr_registers + %assert_zero + + // Check the context. + GET_CONTEXT + // stack: context, addr_registers + DUP2 %add_const(4) + MLOAD_GENERAL %shl_const(64) + // stack: stored_context, context, addr_registers %assert_eq - // Initialise the shift table - %shift_table_init + // Construct `kexit_info`. + DUP1 MLOAD_GENERAL + // stack: program_counter, addr_registers + DUP2 %increment + MLOAD_GENERAL + // stack: is_kernel, program_counter, addr_registers + %shl_const(32) ADD + // stack: is_kernel << 32 + program_counter, addr_registers + SWAP1 %add_const(5) MLOAD_GENERAL + // stack: gas_used, is_kernel << 32 + program_counter + %shl_const(192) ADD + // stack: kexit_info = gas_used << 192 + is_kernel << 32 + program_counter + // Now, we set the PC, is_kernel and gas_used to the correct values and continue the execution. + EXIT_KERNEL +global main: // Initialize accessed addresses and storage keys lists %init_access_lists @@ -21,14 +70,23 @@ global main: // Initialize the RLP DATA pointer to its initial position, // skipping over the preinitialized empty node. PUSH @INITIAL_TXN_RLP_ADDR + %add_const(@MAX_RLP_BLOB_SIZE) %mstore_global_metadata(@GLOBAL_METADATA_RLP_DATA_SIZE) // Encode constant nodes %initialize_rlp_segment + + // Initialize trie data size. + PROVER_INPUT(trie_ptr::trie_data_size) + %mstore_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) + +global store_initial: + // Store the initial accounts and slots for hashing later + %store_initial_accounts + %store_initial_slots - // Initialize the state, transaction and receipt trie root pointers. - PROVER_INPUT(trie_ptr::state) - %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) +global after_store_initial: + // Initialize the transaction and receipt trie root pointers. PROVER_INPUT(trie_ptr::txn) %mstore_global_metadata(@GLOBAL_METADATA_TXN_TRIE_ROOT) PROVER_INPUT(trie_ptr::receipt) @@ -37,73 +95,160 @@ global main: global hash_initial_tries: // We compute the length of the trie data segment in `mpt_hash` so that we // can check the value provided by the prover. - // We initialize the segment length with 1 because the segment contains - // the null pointer `0` when the tries are empty. - PUSH 1 - %mpt_hash_state_trie %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE) %assert_eq - // stack: trie_data_len + // The trie data segment is already written by the linked lists + + // First, we compute the initial size of the trie data segment. + PUSH @ACCOUNTS_LINKED_LISTS_NODE_SIZE + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + %mload_global_metadata(@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE) + SUB + // stack: accounts_ll_full_len, accounts_ll_node_size + DIV + %decrement + // stack: actual_nb_accounts + // The initial payloads are written twice, and each payload requires 4 elements. + PUSH 8 MUL + %increment + // stack: init_trie_data_len + %mpt_hash_txn_trie %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_BEFORE) %assert_eq // stack: trie_data_len %mpt_hash_receipt_trie %mload_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_BEFORE) %assert_eq // stack: trie_data_full_len - %mstore_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) - - // If txn_idx == 0, update the beacon_root and exit roots. - %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_BEFORE) - ISZERO - %jumpi(set_beacon_root) + // Check that the trie data length is correct. + %mload_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) + %assert_eq -global start_txn: +global start_txns: // stack: (empty) %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_BEFORE) // stack: txn_nb DUP1 %scalar_to_rlp // stack: txn_counter, txn_nb DUP1 %num_bytes %mul_const(2) - // stack: num_nibbles, txn_counter, txn_nb - %increment_bounded_rlp - // stack: txn_counter, num_nibbles, next_txn_counter, next_num_nibbles, txn_nb + SWAP1 + // stack: txn_counter, num_nibbles, txn_nb %mload_global_metadata(@GLOBAL_METADATA_BLOCK_GAS_USED_BEFORE) + // stack: init_gas_used, txn_counter, num_nibbles, txn_nb - // stack: init_gas_used, txn_counter, num_nibbles, next_txn_counter, next_num_nibbles, txn_nb + // If txn_idx == 0, update the beacon_root and exit roots. + %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_BEFORE) + ISZERO + %jumpi(set_beacon_root) - // If the prover has no txn for us to process, halt. - PROVER_INPUT(no_txn) + // stack: init_gas_used, txn_counter, num_nibbles, txn_nb +global txn_loop: + // If the prover has no more txns for us to process, halt. + PROVER_INPUT(end_of_txns) %jumpi(execute_withdrawals) // Call route_txn. When we return, we will process the txn receipt. - PUSH txn_after - // stack: retdest, prev_gas_used, txn_counter, num_nibbles, next_txn_counter, next_num_nibbles, txn_nb - DUP4 DUP4 + PUSH txn_loop_after + // stack: retdest, prev_gas_used, txn_counter, num_nibbles, txn_nb + %stack(retdest, prev_gas_used, txn_counter, num_nibbles) -> (txn_counter, num_nibbles, retdest, prev_gas_used, txn_counter, num_nibbles) %jump(route_txn) -global txn_after: - // stack: success, leftover_gas, cur_cum_gas, prev_txn_counter, prev_num_nibbles, txn_counter, num_nibbles, txn_nb +global txn_loop_after: + // stack: success, leftover_gas, cur_cum_gas, prev_txn_counter, prev_num_nibbles, txn_nb + DUP5 DUP5 %increment_bounded_rlp + // stack: txn_counter, num_nibbles, success, leftover_gas, cur_cum_gas, prev_txn_counter, prev_num_nibbles, txn_nb + %stack (txn_counter, num_nibbles, success, leftover_gas, cur_cum_gas, prev_txn_counter, prev_num_nibbles) -> (success, leftover_gas, cur_cum_gas, prev_txn_counter, prev_num_nibbles, txn_counter, num_nibbles) %process_receipt + // stack: new_cum_gas, txn_counter, num_nibbles, txn_nb SWAP3 %increment SWAP3 - %jump(execute_withdrawals_post_stack_op) + + // Re-initialize memory values before processing the next txn. + %reinitialize_memory_pre_txn + + // stack: new_cum_gas, txn_counter, num_nibbles, new_txn_number + %jump(txn_loop) global execute_withdrawals: - // stack: cum_gas, txn_counter, num_nibbles, next_txn_counter, next_num_nibbles, txn_nb - %stack (cum_gas, txn_counter, num_nibbles, next_txn_counter, next_num_nibbles) -> (cum_gas, txn_counter, num_nibbles) -execute_withdrawals_post_stack_op: + // stack: cum_gas, txn_counter, num_nibbles, txn_nb %withdrawals global perform_final_checks: // stack: cum_gas, txn_counter, num_nibbles, txn_nb // Check that we end up with the correct `cum_gas`, `txn_nb` and bloom filter. %mload_global_metadata(@GLOBAL_METADATA_BLOCK_GAS_USED_AFTER) %assert_eq - DUP3 %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_AFTER) %assert_eq + DUP3 + %mload_global_metadata(@GLOBAL_METADATA_TXN_NUMBER_AFTER) %assert_eq %pop3 - PUSH 1 // initial trie data length -global check_state_trie: - %mpt_hash_state_trie %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER) %assert_eq + + // We set a dummy value as an initial trie data length, + // since the final transaction and receipt tries have already been + // added to `GLOBAL_METADATA_TRIE_DATA_SIZE`. + PUSH 1 + global check_txn_trie: %mpt_hash_txn_trie %mload_global_metadata(@GLOBAL_METADATA_TXN_TRIE_DIGEST_AFTER) %assert_eq global check_receipt_trie: %mpt_hash_receipt_trie %mload_global_metadata(@GLOBAL_METADATA_RECEIPT_TRIE_DIGEST_AFTER) %assert_eq +global check_state_trie: + // First, check initial trie. + // We pop the dummy trie data length that was computed. + POP + // Now, we get the trie data size so we can add the values from the + // initial trie data size and check that the value stored in + // `GLOBAL_METADATA_TRIE_DATA_SIZE` is correct. + %get_trie_data_size + // stack: trie_data_len + PROVER_INPUT(trie_ptr::state) + + %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + + PROVER_INPUT(trie_ptr::trie_data_size) + %mstore_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) + + %set_initial_tries + %mpt_hash_state_trie + + // stack: init_state_hash, trie_data_len + // Check that the initial trie is correct. + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_BEFORE) + %assert_eq + // Check that the stored trie data length is correct. + %mload_global_metadata(@GLOBAL_METADATA_TRIE_DATA_SIZE) + %assert_eq + + // We set a dummy value as an initial trie data length, + // as we do not need to compute the actual trie data length here. + PUSH 1 +global check_final_state_trie: + %set_final_tries + %mpt_hash_state_trie %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_DIGEST_AFTER) %assert_eq // We don't need the trie data length here. POP + + // We have reached the end of the execution, so we set the pruning flag to 1 for context 0. + PUSH 1 + SET_CONTEXT + %jump(halt) + +%macro reinitialize_memory_pre_txn + // Reinitialize accessed addresses and storage keys lists + %init_access_lists + + // Reinitialize transient storage + %init_transient_storage_len + + // Reinitialize global metadata + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_CONTRACT_CREATION) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_IS_PRECOMPILE_FROM_EOA) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_LOGS_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_LOGS_DATA_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_LOGS_PAYLOAD_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_JOURNAL_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_JOURNAL_DATA_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_REFUND_COUNTER) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_SELFDESTRUCT_LIST_LEN) + + // Reinitialize `chain_id` for legacy transactions and `to` transaction field + PUSH 0 %mstore_txn_field(@TXN_FIELD_CHAIN_ID_PRESENT) + PUSH 0 %mstore_txn_field(@TXN_FIELD_TO) + + %reset_blob_versioned_hashes +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/memory/memset.asm b/evm_arithmetization/src/cpu/kernel/asm/memory/memset.asm index 2cd50e85a..a4f2c9385 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/memory/memset.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/memory/memset.asm @@ -46,3 +46,10 @@ memset_bytes_empty: %pop2 // stack: retdest JUMP + + +%macro memset + %stack (dst, count) -> (dst, count, %%after) + %jump(memset) +%%after: +%endmacro \ No newline at end of file diff --git a/evm_arithmetization/src/cpu/kernel/asm/memory/syscalls.asm b/evm_arithmetization/src/cpu/kernel/asm/memory/syscalls.asm index 9454090b6..b87556975 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/memory/syscalls.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/memory/syscalls.asm @@ -119,6 +119,15 @@ calldataload_large_offset: codecopy_within_bounds: // stack: total_size, segment, src_ctx, kexit_info, dest_offset, offset, size POP + // stack: segment, src_ctx, kexit_info, dest_offset, offset, size + GET_CONTEXT + %stack (context, segment, src_ctx, kexit_info, dest_offset, offset, size) -> + (src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, size, codecopy_after, src_ctx, kexit_info) + %build_address + SWAP3 %build_address + // stack: DST, SRC, size, codecopy_after, src_ctx, kexit_info + %jump(memcpy_bytes) + wcopy_within_bounds: // TODO: rework address creation to have less stack manipulation overhead // stack: segment, src_ctx, kexit_info, dest_offset, offset, size @@ -139,7 +148,15 @@ wcopy_empty: codecopy_large_offset: // stack: total_size, src_ctx, kexit_info, dest_offset, offset, size - %pop2 + POP + // offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros. + // stack: src_ctx, kexit_info, dest_offset, offset, size + GET_CONTEXT + %stack (context, src_ctx, kexit_info, dest_offset, offset, size) -> + (context, @SEGMENT_MAIN_MEMORY, dest_offset, size, codecopy_after, src_ctx, kexit_info) + %build_address + %jump(memset) + wcopy_large_offset: // offset is larger than the size of the {CALLDATA,CODE,RETURNDATA}. So we just have to write zeros. // stack: kexit_info, dest_offset, offset, size @@ -149,6 +166,24 @@ wcopy_large_offset: %build_address %jump(memset) +codecopy_after: + // stack: src_ctx, kexit_info + DUP1 GET_CONTEXT + // stack: ctx, src_ctx, src_ctx, kexit_info + // If ctx == src_ctx, it's a CODECOPY, and we don't need to prune the context. + EQ + // stack: ctx == src_ctx, src_ctx, kexit_info + %jumpi(codecopy_no_prune) + // stack: src_ctx, kexit_info + %prune_context + // stack: kexit_info + EXIT_KERNEL + +codecopy_no_prune: + // stack: src_ctx, kexit_info + POP + EXIT_KERNEL + wcopy_after: // stack: kexit_info EXIT_KERNEL @@ -341,9 +376,37 @@ mcopy_empty: GET_CONTEXT %stack (context, new_dest_offset, copy_size, extra_size, segment, src_ctx, kexit_info, dest_offset, offset, size) -> - (src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size) + (src_ctx, segment, offset, @SEGMENT_MAIN_MEMORY, dest_offset, context, copy_size, codecopy_large_offset, copy_size, src_ctx, kexit_info, new_dest_offset, offset, extra_size) %build_address SWAP3 %build_address - // stack: DST, SRC, copy_size, wcopy_large_offset, kexit_info, new_dest_offset, offset, extra_size + // stack: DST, SRC, copy_size, codecopy_large_offset, copy_size, src_ctx, kexit_info, new_dest_offset, offset, extra_size %jump(memcpy_bytes) %endmacro + +// Adds stale_ctx to the list of stale contexts. You need to return to a previous, older context with +// a SET_CONTEXT instruction. By assumption, stale_ctx is greater than the current context. +%macro prune_context + // stack: stale_ctx + GET_CONTEXT + // stack: curr_ctx, stale_ctx + // When we go to stale_ctx, we want its stack to contain curr_ctx so that we can immediately + // call SET_CONTEXT. For that, we need a stack length of 1, and store curr_ctx in Segment::Stack[0]. + PUSH @SEGMENT_STACK + DUP3 ADD + // stack: stale_ctx_stack_addr, curr_ctx, stale_ctx + DUP2 + // stack: curr_ctx, stale_ctx_stack_addr, curr_ctx, stale_ctx + MSTORE_GENERAL + // stack: curr_ctx, stale_ctx + PUSH @CTX_METADATA_STACK_SIZE + DUP3 ADD + // stack: stale_ctx_stack_size_addr, curr_ctx, stale_ctx + PUSH 1 + MSTORE_GENERAL + // stack: curr_ctx, stale_ctx + POP + SET_CONTEXT + // We're now in stale_ctx, with stack: curr_ctx + %set_and_prune_ctx + // We're now in curr_ctx, with an empty stack. +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/accounts.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/accounts.asm index 1f60a3f75..e1fa20c08 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/accounts.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/accounts.asm @@ -1,43 +1,39 @@ -// Return a pointer to the current account's data in the state trie. -%macro current_account_data - %address %mpt_read_state_trie +%macro clone_account // stack: account_ptr - // account_ptr should be non-null as long as the prover provided the proper - // Merkle data. But a bad prover may not have, and we don't want return a - // null pointer for security reasons. - DUP1 ISZERO %jumpi(panic) - // stack: account_ptr -%endmacro - -// Returns a pointer to the root of the storage trie associated with the current account. -%macro current_storage_trie - // stack: (empty) - %current_account_data - // stack: account_ptr - %add_const(2) - // stack: storage_root_ptr_ptr + %get_trie_data_size + // stack: cloned_account_ptr + SWAP1 + DUP1 + // Balance %mload_trie_data - // stack: storage_root_ptr -%endmacro - -// Return a pointer to the provided account's data in the state trie. -%macro get_account_data(addr) - PUSH $addr %mpt_read_state_trie - // stack: account_ptr - // account_ptr should be non-null as long as the prover provided the proper - // Merkle data. But a bad prover may not have, and we don't want return a - // null pointer for security reasons. - DUP1 ISZERO %jumpi(panic) - // stack: account_ptr + %append_to_trie_data + %increment + // Nonce + %increment + DUP1 + %mload_trie_data + %append_to_trie_data + // Storage trie root + %increment + DUP1 + %mload_trie_data + %append_to_trie_data + // Codehash + %increment + %mload_trie_data + %append_to_trie_data + // stack: cloned_account_ptr %endmacro -// Returns a pointer to the root of the storage trie associated with the provided account. -%macro get_storage_trie(key) - // stack: (empty) - %get_account_data($key) - // stack: account_ptr - %add_const(2) - // stack: storage_root_ptr_ptr +// The slot_ptr cannot be 0, because `insert_slot` +// is only called in `revert_storage_change` (where the case `slot_ptr = 0` +// is dealt with differently), and in `storage_write`, +// where writing 0 actually corresponds to a `delete`. +%macro clone_slot + // stack: slot_ptr + %get_trie_data_size + // stack: cloned_slot_ptr, slot_ptr + SWAP1 %mload_trie_data - // stack: storage_root_ptr -%endmacro \ No newline at end of file + %append_to_trie_data +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/delete/delete.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/delete/delete.asm index 913ba1fcf..c878ae812 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/delete/delete.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/delete/delete.asm @@ -24,17 +24,13 @@ mpt_delete_leaf: SWAP1 JUMP global delete_account: - %stack (address, retdest) -> (address, delete_account_save, retdest) %addr_to_state_key - // stack: key, delete_account_save, retdest - PUSH 64 - // stack: 64, key, delete_account_save, retdest - %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) - // stack: state_root_prt, 64, key, delete_account_save, retdest - %jump(mpt_delete) -delete_account_save: - // stack: updated_state_root_ptr, retdest - %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + DUP1 + %remove_account_from_linked_list + // stack: addr_to_state_key, retdest + + // Now we also need to remove all the storage nodes associated with the deleted account. + %remove_all_account_slots JUMP %macro delete_account @@ -42,4 +38,4 @@ delete_account_save: %jump(delete_account) %%after: // stack: (empty) -%endmacro \ No newline at end of file +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm index 8d8c7a419..1da7c6a31 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/hash/hash_trie_specific.asm @@ -101,10 +101,6 @@ global encode_account: DUP3 %add_const(2) %mload_trie_data // storage_root_ptr = value[2] // stack: storage_root_ptr, cur_len, rlp_pos_5, value_ptr, cur_len, retdest - - PUSH debug_after_hash_storage_trie - POP - // Hash storage trie. %mpt_hash_storage_trie // stack: storage_root_digest, new_len, rlp_pos_5, value_ptr, cur_len, retdest @@ -352,4 +348,3 @@ global encode_storage_value: // stack: rlp_addr', cur_len, retdest %stack (rlp_addr, cur_len, retdest) -> (retdest, rlp_addr, cur_len) JUMP - diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm index 71f78ec5b..e1e82b562 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/insert/insert_trie_specific.asm @@ -6,15 +6,7 @@ // TODO: Have this take an address and do %mpt_insert_state_trie? To match mpt_read_state_trie. global mpt_insert_state_trie: // stack: key, value_ptr, retdest - %stack (key, value_ptr) - -> (key, value_ptr, mpt_insert_state_trie_save) - PUSH 64 // num_nibbles - %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) - // stack: state_root_ptr, num_nibbles, key, value_ptr, mpt_insert_state_trie_save, retdest - %jump(mpt_insert) -mpt_insert_state_trie_save: - // stack: updated_node_ptr, retdest - %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %insert_account_with_overwrite JUMP %macro mpt_insert_state_trie diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/final_tries.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/final_tries.asm new file mode 100644 index 000000000..9db07083d --- /dev/null +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/final_tries.asm @@ -0,0 +1,238 @@ +// Given a pointer `root_ptr` to the root of a trie, insert all accounts in +// the accounts_linked_list starting at `account_ptr_ptr` as well as the +// respective storage slots in `storage_ptr_ptr`. +// Pre stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest +// Post stack: new_root_ptr. +global insert_all_accounts: + // stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + SWAP2 + DUP3 + MLOAD_GENERAL + // stack: key, storage_ptr_ptr, root_ptr, account_ptr_ptr, retdest + DUP1 + %eq_const(@U256_MAX) + %jumpi(no_more_accounts) + // stack: key, storage_ptr_ptr, root_ptr, account_ptr_ptr, retdest + DUP4 + %increment + MLOAD_GENERAL + // stack: account_ptr, key, storage_ptr_ptr, root_ptr, account_ptr_ptr, retdest + %add_const(2) + DUP1 + %mload_trie_data + // stack: storage_root_ptr, storage_root_ptr_ptr, key, storage_ptr_ptr, root_ptr, account_ptr_ptr, retdest + %stack + (storage_root_ptr, storage_root_ptr_ptr, key, storage_ptr_ptr) -> + (key, storage_ptr_ptr, storage_root_ptr, after_insert_all_slots, storage_root_ptr_ptr, key) + %jump(insert_all_slots) + +after_insert_all_slots: + // stack: storage_ptr_ptr', storage_root_ptr', storage_root_ptr_ptr, key, root_ptr, account_ptr_ptr, retdest + SWAP2 + %mstore_trie_data + // stack: storage_ptr_ptr', key, root_ptr, account_ptr_ptr, retdest + DUP4 + %increment + MLOAD_GENERAL + %stack + (payload_ptr, storage_ptr_ptr_p, key, root_ptr, account_ptr_ptr) -> + (root_ptr, 64, key, payload_ptr, after_insert_account, account_ptr_ptr, storage_ptr_ptr_p) + %jump(mpt_insert) +after_insert_account: + // stack: root_ptr', account_ptr_ptr, storage_ptr_ptr', retdest + SWAP1 + %next_account + // stack: account_ptr_ptr', root_ptr', storage_ptr_ptr', retdest + %jump(insert_all_accounts) + +no_more_accounts: + // stack: key, storage_ptr_ptr, root_ptr, account_ptr_ptr, retdest + %stack (key, storage_ptr_ptr, root_ptr, account_ptr_ptr, retdest) ->(retdest, root_ptr) + JUMP + +// Insert all slots before the account key changes +// Pre stack: addr, storage_ptr_ptr, root_ptr, retdest +// Post stack: storage_ptr_ptr', root_ptr' +global insert_all_slots: + DUP2 + MLOAD_GENERAL + DUP2 + EQ // Check that the node addres is the same as `addr` + %jumpi(insert_next_slot) + // The addr has changed, meaning that we've inserted all slots for addr + // stack: addr, storage_ptr_ptr, root_ptr, retdest + %stack (addr, storage_ptr_ptr, root_ptr, retdest) -> (retdest, storage_ptr_ptr, root_ptr) + JUMP + +insert_next_slot: + // stack: addr, storage_ptr_ptr, root_ptr, retdest + DUP2 + %increment + MLOAD_GENERAL + // stack: key, addr, storage_ptr_ptr, root_ptr, retdest + DUP3 + %add_const(2) + MLOAD_GENERAL + // stack: value, key, addr, storage_ptr_ptr, root_ptr, retdest + // If the value is 0, then payload_ptr = 0, and we don't need to insert a value in the `TrieData` segment. + DUP1 ISZERO %jumpi(insert_with_payload_ptr) + %get_trie_data_size // payload_ptr + SWAP1 %append_to_trie_data // append the value to the trie data segment +insert_with_payload_ptr: + %stack (payload_ptr, key, addr, storage_ptr_ptr, root_ptr) -> (root_ptr, 64, key, payload_ptr, after_insert_slot, storage_ptr_ptr, addr) + %jump(mpt_insert) +after_insert_slot: + // stack: root_ptr', storage_ptr_ptr, addr, retdest + SWAP1 + %next_slot + // stack: storage_ptr_ptr', root_ptr', addr + %stack (storage_ptr_ptr_p, root_ptr_p, addr) -> (addr, storage_ptr_ptr_p, root_ptr_p) + %jump(insert_all_slots) + +// Delete all the accounts, referenced by the respective nodes in the linked list starting at +// `account_ptr_ptr`, which where deleted from the initial state. Delete also all slots of non-deleted accounts +// deleted from the storage trie. +// Pre stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest +// Post stack: new_root_ptr. +global delete_removed_accounts: + // stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + DUP1 + // We assume that the size of the initial accounts linked list, containing the accounts + // of the initial state, was stored at `@GLOBAL_METADATA_INITIAL_ACCOUNTS_LINKED_LIST_LEN`. + %mload_global_metadata(@GLOBAL_METADATA_INITIAL_ACCOUNTS_LINKED_LIST_LEN) + // The initial accounts linked list was stored at addresses smaller than `@GLOBAL_METADATA_INITIAL_ACCOUNTS_LINKED_LIST_LEN`. + // If we also know that `@SEGMENT_ACCOUNT_LINKED_LIST <= account_ptr_ptr`, for deleting node at `addr_ptr_ptr` it + // suffices to check that `account_ptr_ptr` != `@GLOBAL_METADATA_INITIAL_ACCOUNTS_LINKED_LIST_LEN` + EQ + %jumpi(delete_removed_accounts_end) + // stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + DUP1 + %next_account + %eq_const(@U256_MAX) // If the next node pointer is @U256_MAX, the node was deleted + %jumpi(delete_account) + // The account is still there so we need to delete any removed slot. + // stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + DUP1 + MLOAD_GENERAL + // stack: key, account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + DUP2 + %add_const(2) + MLOAD_GENERAL // get initial payload_ptr + %add_const(2) // storage_root_ptr_ptr = payload_ptr + 2 + %mload_trie_data + // stack: storage_root_ptr, key, account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + DUP3 + %increment + MLOAD_GENERAL // get dynamic payload_ptr + %add_const(2) // storage_root_ptr_ptr = dyn_payload_ptr + 2 + %stack + (storage_root_ptr_ptr, storage_root_ptr, key, account_ptr_ptr, root_ptr, storage_ptr_ptr) -> + (key, storage_root_ptr, storage_ptr_ptr, after_delete_removed_slots, storage_root_ptr_ptr, account_ptr_ptr, root_ptr) + %jump(delete_removed_slots) +after_delete_removed_slots: + // stack: storage_root_ptr', storage_ptr_ptr', storage_root_ptr_ptr, account_ptr_ptr, root_ptr, retdest + SWAP1 SWAP2 + // stack: storage_root_ptr_ptr, storage_root_ptr', storage_ptr_ptr', account_ptr_ptr, root_ptr, retdest + %mstore_trie_data + // stack: storage_ptr_ptr', account_ptr_ptr, root_ptr, retdest + SWAP1 + %add_const(@ACCOUNTS_LINKED_LISTS_NODE_SIZE) // The next account in memory + // stack: account_ptr_ptr', storage_ptr_ptr', root_ptr, retdest + SWAP1 SWAP2 SWAP1 + %jump(delete_removed_accounts) + +delete_removed_accounts_end: + // stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + %stack (account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest) -> (retdest, root_ptr) + JUMP +delete_account: + // stack: account_ptr_ptr, root_ptr, storage_ptr_ptr, retdest + DUP1 + MLOAD_GENERAL + %stack (key, account_ptr_ptr, root_ptr) -> (root_ptr, 64, key, after_mpt_delete, account_ptr_ptr) + // Pre stack: node_ptr, num_nibbles, key, retdest + // Post stack: updated_node_ptr + %jump(mpt_delete) +after_mpt_delete: + // stack: root_ptr', account_ptr_ptr, storage_ptr_ptr, retdest + SWAP1 + %add_const(@ACCOUNTS_LINKED_LISTS_NODE_SIZE) + %jump(delete_removed_accounts) + +// Delete all slots in `storage_ptr_ptr` with address == `addr` and +// `storage_ptr_ptr` < `@GLOBAL_METADATA_INITIAL_STORAGE_LINKED_LIST_LEN`. +// Pre stack: addr, root_ptr, storage_ptr_ptr, retdest +// Post stack: new_root_ptr, storage_ptr_ptr'. +delete_removed_slots: + // stack: addr, root_ptr, storage_ptr_ptr, retdest + DUP3 + MLOAD_GENERAL + // stack: address, addr, root_ptr, storage_ptr_ptr, retdest + DUP2 + EQ + // stack: loaded_address == addr, addr, root_ptr, storage_ptr_ptr, retdest + %mload_global_metadata(@GLOBAL_METADATA_INITIAL_STORAGE_LINKED_LIST_LEN) + DUP5 + LT + MUL // AND + // stack: loaded_address == addr AND storage_ptr_ptr < GLOBAL_METADATA_INITIAL_STORAGE_LINKED_LIST_LEN, addr, root_ptr, storage_ptr_ptr, retdest + // jump if we either change the address or reach the end of the initial linked list + %jumpi(maybe_delete_this_slot) + // If we are here we have deleted all the slots for this key + %stack (addr, root_ptr, storage_ptr_ptr, retdest) -> (retdest, root_ptr, storage_ptr_ptr) + JUMP +maybe_delete_this_slot: + // stack: addr, root_ptr, storage_ptr_ptr, retdest + DUP3 + %next_slot + %eq_const(@U256_MAX) // Check if the node was deleted + %jumpi(delete_this_slot) + // The slot was not deleted, so we skip it. + // stack: addr, root_ptr, storage_ptr_ptr, retdest + SWAP2 + %add_const(@STORAGE_LINKED_LISTS_NODE_SIZE) + SWAP2 + %jump(delete_removed_slots) +delete_this_slot: + // stack: addr, root_ptr, storage_ptr_ptr, retdest + DUP3 + %increment + MLOAD_GENERAL + %stack (key, addr, root_ptr, storage_ptr_ptr) -> (root_ptr, 64, key, after_mpt_delete_slot, addr, storage_ptr_ptr) + %jump(mpt_delete) +after_mpt_delete_slot: + // stack: root_ptr', addr, storage_ptr_ptr + SWAP2 + %add_const(@STORAGE_LINKED_LISTS_NODE_SIZE) + %stack (storage_ptr_ptr_p, addr, root_ptr_p) -> (addr, root_ptr_p, storage_ptr_ptr_p) + %jump(delete_removed_slots) + +global set_final_tries: + PUSH set_final_tries_after + PUSH @SEGMENT_STORAGE_LINKED_LIST + %add_const(@STORAGE_LINKED_LISTS_NODE_SIZE) // Skip the first node. + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + %add_const(@ACCOUNTS_LINKED_LISTS_NODE_SIZE) // Skip the first node. + %jump(delete_removed_accounts) +set_final_tries_after: + // stack: new_state_root + PUSH set_final_tries_after_after SWAP1 + // stack: new_state_root, set_final_tries_after_after + PUSH @SEGMENT_STORAGE_LINKED_LIST + %next_slot + SWAP1 + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + %next_account + %jump(insert_all_accounts) +set_final_tries_after_after: + //stack: new_state_root + %mstore_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + JUMP + +%macro set_final_tries + // stack: (empty) + PUSH %%after + %jump(set_final_tries) +%%after: +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/initial_tries.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/initial_tries.asm new file mode 100644 index 000000000..795453667 --- /dev/null +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/initial_tries.asm @@ -0,0 +1,197 @@ +// Set the payload pointers of the leaves in the trie with root at `node_ptr` +// to mem[payload_ptr_ptr] + step*i, +// for i =0..n_leaves. This is used to constraint the +// initial state and account tries payload pointers such that they are exactly +// those of the initial accounts and linked lists. +// Pre stack: node_ptr, account_ptr_ptr, storage_ptr_ptr, retdest +// Post stack: account_ptr_ptr, storage_ptr_ptr +global mpt_set_payload: + // stack: node_ptr, account_ptr_ptr, storage_ptr_ptr, retdest + DUP1 %mload_trie_data + // stack: node_type, node_ptr, account_ptr_ptr, storage_ptr_ptr, retdest + // Increment node_ptr, so it points to the node payload instead of its type. + SWAP1 %increment SWAP1 + // stack: node_type, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(skip) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(set_payload_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(set_payload_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(set_payload_leaf) + DUP1 %eq_const(@MPT_NODE_HASH) %jumpi(skip) + PANIC + +skip: + // stack: node_type, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + %stack (node_type, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest) -> (retdest, account_ptr_ptr, storage_ptr_ptr) + JUMP + +%macro mpt_set_payload + %stack(node_ptr, account_ptr_ptr, storage_ptr_ptr) -> (node_ptr, account_ptr_ptr, storage_ptr_ptr, %%after) + %jump(mpt_set_payload) +%%after: +%endmacro + +%macro set_initial_tries + PUSH %%after + PUSH @SEGMENT_STORAGE_LINKED_LIST + %add_const(8) // The first node is the special node, of size 5, so the first value is at position 5 + 3. + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + %add_const(6) // The first node is the special node, of size 4, so the first payload is at position 4 + 2. + %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) + %jump(mpt_set_payload) +%%after: + // We store account_ptr_ptr - 2, i.e. a pointer to the first node not in the initial state. + %sub_const(2) + %mstore_global_metadata(@GLOBAL_METADATA_INITIAL_ACCOUNTS_LINKED_LIST_LEN) + // We store storage_ptr_ptr - 3, i.e. a pointer to the first node not in the initial state. + %sub_const(3) + %mstore_global_metadata(@GLOBAL_METADATA_INITIAL_STORAGE_LINKED_LIST_LEN) +%endmacro + +// Pre stack: node_ptr, storage_ptr_ptr, retdest +// Post stack: storage_ptr_ptr +global mpt_set_storage_payload: + // stack: node_ptr, storage_ptr_ptr, retdest + DUP1 %mload_trie_data + // stack: node_type, node_ptr, storage_ptr_ptr, retdest + // Increment node_ptr, so it points to the node payload instead of its type. + SWAP1 %increment SWAP1 + // stack: node_type, after_node_type, storage_ptr_ptr, retdest + + DUP1 %eq_const(@MPT_NODE_EMPTY) %jumpi(storage_skip) + DUP1 %eq_const(@MPT_NODE_BRANCH) %jumpi(set_payload_storage_branch) + DUP1 %eq_const(@MPT_NODE_EXTENSION) %jumpi(set_payload_storage_extension) + DUP1 %eq_const(@MPT_NODE_LEAF) %jumpi(set_payload_storage_leaf) + +storage_skip: + // stack: node_type, after_node_type, storage_ptr_ptr, retdest + %stack (node_type, after_node_type, storage_ptr_ptr, retdest) -> (retdest, storage_ptr_ptr) + JUMP + +%macro mpt_set_storage_payload + %stack(node_ptr, storage_ptr_ptr) -> (node_ptr, storage_ptr_ptr, %%after) + %jump(mpt_set_storage_payload) +%%after: +%endmacro + +set_payload_branch: + // stack: node_type, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + POP + + // Call mpt_set_payload on each child + %rep 16 + %stack + (child_ptr_ptr, account_ptr_ptr, storage_ptr_ptr) -> + (child_ptr_ptr, account_ptr_ptr, storage_ptr_ptr, child_ptr_ptr) + // stack: child_ptr_ptr, account_ptr_ptr, storage_ptr_ptr, child_ptr_ptr, retdest + %mload_trie_data + // stack: child_ptr, account_ptr_ptr, storage_ptr_ptr, child_ptr_ptr, retdest + %mpt_set_payload + // stack: account_ptr_ptr', storage_ptr_ptr', child_ptr_ptr, retdest + SWAP1 + SWAP2 + %increment + %endrep + // stack: child_ptr_ptr', account_ptr_ptr', storage_ptr_ptr', retdest + %stack (child_ptr_ptr, account_ptr_ptr, storage_ptr_ptr, retdest) -> (retdest, account_ptr_ptr, storage_ptr_ptr) + JUMP + +set_payload_storage_branch: + // stack: node_type, child_ptr_ptr, storage_ptr_ptr, retdest + POP + + // Call mpt_set_storage_payload on each child + %rep 16 + %stack + (child_ptr_ptr, storage_ptr_ptr) -> + (child_ptr_ptr, storage_ptr_ptr, child_ptr_ptr) + // stack: child_ptr_ptr, storage_ptr_ptr, child_ptr_ptr, retdest + %mload_trie_data + // stack: child_ptr, storage_ptr_ptr, child_ptr_ptr, retdest + %mpt_set_storage_payload + // stack: storage_ptr_ptr', child_ptr_ptr, retdest + SWAP1 + %increment + %endrep + // stack: child_ptr_ptr', storage_ptr_ptr', retdest + %stack (child_ptr_ptr, storage_ptr_ptr, retdest) -> (retdest, storage_ptr_ptr) + JUMP + +set_payload_extension: + // stack: node_type, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + POP + // stack: after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + %add_const(2) %mload_trie_data + // stack: child_ptr, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + %jump(mpt_set_payload) + +set_payload_storage_extension: + // stack: node_type, after_node_type, storage_ptr_ptr, retdest + POP + // stack: after_node_type, storage_ptr_ptr, retdest + %add_const(2) %mload_trie_data + // stack: child_ptr, storage_ptr_ptr, retdest + %jump(mpt_set_storage_payload) + +set_payload_leaf: + // stack: node_type, after_node_type, account_ptr_ptr, storage_ptr_ptr, retdest + POP + %add_const(2) // The payload pointer starts at index 3, after num_nibbles and packed_nibbles. + DUP1 + // stack: payload_ptr_ptr, payload_ptr_ptr, account_ptr_ptr, storage_ptr_ptr, retdest + %mload_trie_data + // stack: account_ptr, payload_ptr_ptr, account_ptr_ptr, storage_ptr_ptr, retdest + %add_const(2) + %mload_trie_data // storage_root_ptr = account[2] + + // stack: storage_root_ptr, payload_ptr_ptr, account_ptr_ptr, storage_ptr_ptr, retdest + %stack + (storage_root_ptr, payload_ptr_ptr, account_ptr_ptr, storage_ptr_ptr) -> + (storage_root_ptr, storage_ptr_ptr, after_set_storage_payload, storage_root_ptr, payload_ptr_ptr, account_ptr_ptr) + %jump(mpt_set_storage_payload) +after_set_storage_payload: + // stack: storage_ptr_ptr', storage_root_ptr, payload_ptr_ptr, account_ptr_ptr, retdest + DUP4 + MLOAD_GENERAL // load the next payload pointer in the linked list + DUP1 %add_const(2) // new_storage_root_ptr_ptr = payload_ptr[2] + // stack: new_storage_root_ptr_ptr, new_payload_ptr, storage_root_ptr, storage_ptr_ptr', payload_ptr_ptr, account_ptr_ptr, retdest + // Load also the old "dynamic" payload for storing the storage_root_ptr + DUP6 %decrement + MLOAD_GENERAL + %add_const(2) // dyn_storage_root_ptr_ptr = dyn_paylod_ptr[2] + %stack + (dyn_storage_root_ptr_ptr, new_storage_root_ptr_ptr, new_payload_ptr, storage_ptr_ptr_p, storage_root_ptr, payload_ptr_ptr, account_ptr_ptr) -> + (new_storage_root_ptr_ptr, storage_root_ptr, dyn_storage_root_ptr_ptr, storage_root_ptr, payload_ptr_ptr, new_payload_ptr, account_ptr_ptr, storage_ptr_ptr_p) + %mstore_trie_data // The initial account pointer in the linked list has no storage root so we need to manually set it. + %mstore_trie_data // The dynamic account pointer in the linked list has no storage root so we need to manually set it. + %mstore_trie_data // Set the leaf payload pointing to next account in the linked list. + // stack: account_ptr_ptr, storage_ptr_ptr', retdest + %add_const(@ACCOUNTS_LINKED_LISTS_NODE_SIZE) // The next pointer is at distance `ACCOUNTS_LINKED_LISTS_NODE_SIZE` + // stack: payload_ptr_ptr', storage_ptr_ptr', retdest + SWAP1 + SWAP2 + JUMP + +set_payload_storage_leaf: + // stack: node_type, after_node_type, storage_ptr_ptr, retdest + POP + // stack: after_node_type, storage_ptr_ptr, retdest + %add_const(2) // The value pointer starts at index 3, after num_nibbles and packed_nibbles. + // stack: value_ptr_ptr, storage_ptr_ptr, retdest + DUP2 MLOAD_GENERAL + // stack: value, value_ptr_ptr, storage_ptr_ptr, retdest + // If value == 0, then value_ptr = 0, and we don't need to append the value to the `TrieData` segment. + DUP1 ISZERO %jumpi(set_payload_storage_leaf_end) + %get_trie_data_size + // stack: value_ptr, value, value_ptr_ptr, storage_ptr_ptr, retdest + SWAP1 + %append_to_trie_data +set_payload_storage_leaf_end: + // stack: value_ptr, value_ptr_ptr, storage_ptr_ptr, retdest + SWAP1 + %mstore_trie_data + // stack: storage_ptr_ptr, retdest + %add_const(@STORAGE_LINKED_LISTS_NODE_SIZE) // The next pointer is at distance `STORAGE_LINKED_LISTS_NODE_SIZE` + // stack: storage_ptr_ptr', retdest + SWAP1 + JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm new file mode 100644 index 000000000..f48f33186 --- /dev/null +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm @@ -0,0 +1,916 @@ +/// Linked lists for accounts and storage slots. +/// The accounts linked list is stored in SEGMENT_ACCOUNTS_LINKED_LIST while the slots +/// are stored in SEGMENT_STORAGE_LINKED_LIST. The length of +/// the segments is stored in the associated global metadata. +/// Both arrays are stored in the kernel memory (context=0). +/// Searching and inserting is done by guessing the predecessor in the list. +/// If the address/storage key isn't found in the array, it is inserted +/// at the correct location. These linked lists are used to keep track of +/// inserted and deleted accounts/slots during the execution, so that the +/// initial and final MPT state tries can be reconstructed at the end of the execution. +/// An empty account linked list is written as +/// [@U256_MAX, _, _, @SEGMENT_ACCOUNTS_LINKED_LIST] in SEGMENT_ACCOUNTS_LINKED_LIST. +/// The linked list is preinitialized by appending accounts to the segment. Each account is encoded +/// using 4 values. +/// The values at the respective positions are: +/// - 0: The account key +/// - 1: A ptr to the payload (the account values) +/// - 2: A ptr to the initial payload. +/// - 3: A ptr (in segment @SEGMENT_ACCOUNTS_LINKED_LIST) to the next node in the list. +/// Similarly, an empty storage linked list is written as +/// [@U256_MAX, _, _, _, @SEGMENT_ACCOUNTS_LINKED_LIST] in SEGMENT_ACCOUNTS_LINKED_LIST. +/// The linked list is preinitialized by appending storage slots to the segment. +/// Each slot is encoded using 5 values. +/// The values at the respective positions are: +/// - 0: The account key +/// - 1: The slot key +/// - 2: The slot value. +/// - 3: The initial slot value. +/// - 4: A ptr (in segment @SEGMENT_ACCOUNTS_LINKED_LIST) to the next node in the list. + +%macro store_initial_accounts + PUSH %%after + %jump(store_initial_accounts) +%%after: +%endmacro + +/// Iterates over the initial account linked list and shallow copies +/// the accounts, storing a pointer to the copied account in the node. +/// Computes the length of `SEGMENT_ACCOUNTS_LINKED_LIST` and +/// stores it in `GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE`. +global store_initial_accounts: + // stack: retdest + PUSH @ACCOUNTS_LINKED_LISTS_NODE_SIZE + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + ADD + // stack: cur_len, retdest + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + %next_account +loop_store_initial_accounts: + // stack: current_node_ptr, cur_len, retdest + %get_trie_data_size + DUP2 + MLOAD_GENERAL + // stack: current_addr_key, cpy_ptr, current_node_ptr, cur_len, retdest + %eq_const(@U256_MAX) + %jumpi(store_initial_accounts_end) + DUP2 + %increment + MLOAD_GENERAL + // stack: nonce_ptr, cpy_ptr, current_node_ptr, cur_len, retdest + DUP1 + %mload_trie_data // nonce + %append_to_trie_data + %increment + // stack: balance_ptr, cpy_ptr, current_node_ptr, cur_len, retdest + DUP1 + %mload_trie_data // balance + %append_to_trie_data + %increment // The storage_root_ptr is not really necessary + // stack: storage_root_ptr_ptr, cpy_ptr, current_node_ptr, cur_len, retdest + DUP1 + %mload_trie_data // storage_root_ptr + %append_to_trie_data + %increment + // stack: code_hash_ptr, cpy_ptr, current_node_ptr, cur_len, retdest + %mload_trie_data // code_hash + %append_to_trie_data + // stack: cpy_ptr, current_node_ptr, cur_len, retdest + DUP2 + %add_const(2) + SWAP1 + MSTORE_GENERAL // Store cpy_ptr + // stack: current_node_ptr, cur_len, retdest + SWAP1 PUSH @ACCOUNTS_LINKED_LISTS_NODE_SIZE + ADD + SWAP1 + // stack: current_node_ptr, cur_len', retdest + %next_account + %jump(loop_store_initial_accounts) + +store_initial_accounts_end: + %pop2 + // stack: cur_len, retdest + %mstore_global_metadata(@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE) + JUMP + +%macro insert_account_with_overwrite + %stack (addr_key, ptr) -> (addr_key, ptr, %%after) + %jump(insert_account_with_overwrite) +%%after: +%endmacro + +// Multiplies the value at the top of the stack, denoted by ptr/4, by 4 +// and aborts if ptr/4 <= mem[@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE]/4. +// Also checks that ptr >= @SEGMENT_ACCOUNTS_LINKED_LIST. +// This way, 4*ptr/4 must be pointing to the beginning of a node. +// TODO: Maybe we should check here if the node has been deleted. +%macro get_valid_account_ptr + // stack: ptr/4 + // Check that the pointer is greater than the segment. + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + DUP2 + %mul_const(4) + // stack: ptr, @SEGMENT_ACCOUNTS_LINKED_LIST, ptr/4 + %increment %assert_gt + // stack: ptr/4 + DUP1 + PUSH 4 + %mload_global_metadata(@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE) + // By construction, both @SEGMENT_ACCOUNTS_LINKED_LIST and the unscaled list len + // must be multiples of 4 + DIV + // stack: @SEGMENT_ACCOUNTS_LINKED_LIST/4 + accounts_linked_list_len/4, ptr/4, ptr/4 + %assert_gt + %mul_const(4) +%endmacro + +global insert_account_with_overwrite: + // stack: addr_key, payload_ptr, retdest + PROVER_INPUT(linked_list::insert_account) + // stack: pred_ptr/4, addr_key, payload_ptr, retdest + %get_valid_account_ptr + // stack: pred_ptr, addr_key, payload_ptr, retdest + DUP1 + MLOAD_GENERAL + DUP1 + // stack: pred_addr_key, pred_addr_key, pred_ptr, addr_key, payload_ptr, retdest + DUP4 GT + DUP3 %eq_const(@SEGMENT_ACCOUNTS_LINKED_LIST) + ADD // OR + // If the predesessor is strictly smaller or the predecessor is the special + // node with key @U256_MAX (and hence we're inserting a new minimum), then + // we need to insert a new node. + %jumpi(insert_new_account) + // stack: pred_addr_key, pred_ptr, addr_key, payload_ptr, retdest + // If we are here we know that addr <= pred_addr. But this is only possible if pred_addr == addr. + DUP3 + %assert_eq + + // stack: pred_ptr, addr_key, payload_ptr, retdest + // Check that this is not a deleted node + DUP1 + %add_const(@ACCOUNTS_NEXT_NODE_PTR) + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, account_found_with_overwrite) + // The storage key is not in the list. + PANIC + +account_found_with_overwrite: + // The address was already in the list + // stack: pred_ptr, addr_key, payload_ptr, retdest + // Load the payload pointer + %increment + // stack: payload_ptr_ptr, addr_key, payload_ptr, retdest + DUP3 MSTORE_GENERAL + %pop2 + JUMP + +insert_new_account: + // stack: pred_addr_key, pred_ptr, addr_key, payload_ptr, retdest + POP + // get the value of the next address + %add_const(@ACCOUNTS_NEXT_NODE_PTR) + // stack: next_ptr_ptr, addr_key, payload_ptr, retdest + %mload_global_metadata(@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE) + DUP2 + MLOAD_GENERAL + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, payload_ptr, retdest + // Check that this is not a deleted node + DUP1 + %eq_const(@U256_MAX) + %assert_zero + DUP1 + MLOAD_GENERAL + // stack: next_addr_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, payload_ptr, retdest + DUP5 + // Here, (addr_key > pred_addr_key) || (pred_ptr == @SEGMENT_ACCOUNTS_LINKED_LIST). + // We should have (addr_key < next_addr_key), meaning the new value can be inserted between pred_ptr and next_ptr. + %assert_lt + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, payload_ptr, retdest + SWAP2 + DUP2 + // stack: new_ptr, next_ptr_ptr, new_ptr, next_ptr, addr_key, payload_ptr, retdest + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr_key, payload_ptr, retdest + DUP1 + DUP4 + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr_key, payload_ptr, retdest + %increment + DUP1 + DUP5 + MSTORE_GENERAL + // stack: new_ptr + 1, next_ptr, addr_key, payload_ptr, retdest + %increment + DUP1 + DUP5 + %clone_account + MSTORE_GENERAL + %increment + DUP1 + // stack: new_next_ptr, new_next_ptr, next_ptr, addr_key, payload_ptr, retdest + SWAP2 + MSTORE_GENERAL + // stack: new_next_ptr, addr_key, payload_ptr, retdest + %increment + %mstore_global_metadata(@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE) + // stack: addr_key, payload_ptr, retdest + %pop2 + JUMP + + +/// Searches the account addr in the linked list. +/// Returns 0 if the account was not found or `original_ptr` if it was already present. +global search_account: + // stack: addr_key, retdest + PROVER_INPUT(linked_list::insert_account) + // stack: pred_ptr/4, addr_key, retdest + %get_valid_account_ptr + // stack: pred_ptr, addr_key, retdest + DUP1 + MLOAD_GENERAL + DUP1 + // stack: pred_addr_key, pred_addr_key, pred_ptr, addr_key, retdest + DUP4 GT + DUP3 %eq_const(@SEGMENT_ACCOUNTS_LINKED_LIST) + ADD // OR + // If the predesessor is strictly smaller or the predecessor is the special + // node with key @U256_MAX (and hence we're inserting a new minimum), then + // we need to insert a new node. + %jumpi(account_not_found) + // stack: pred_addr_key, pred_ptr, addr_key, retdest + // If we are here we know that addr_key <= pred_addr_key. But this is only possible if pred_addr == addr. + DUP3 + %assert_eq + + // stack: pred_ptr, addr_key, retdest + // Check that this is not a deleted node + DUP1 + %add_const(@ACCOUNTS_NEXT_NODE_PTR) + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, account_found) + // The storage key is not in the list. + PANIC + +account_found: + // The address was already in the list + // stack: pred_ptr, addr_key, retdest + // Load the payload pointer + %increment + MLOAD_GENERAL + // stack: orig_payload_ptr, addr_key, retdest + %stack (orig_payload_ptr, addr_key, retdest) -> (retdest, orig_payload_ptr) + JUMP + +account_not_found: + // stack: pred_addr_key, pred_ptr, addr_key, retdest + %stack (pred_addr_key, pred_ptr, addr_key, retdest) -> (retdest, 0) + JUMP + +%macro remove_account_from_linked_list + PUSH %%after + SWAP1 + %jump(remove_account) +%%after: +%endmacro + +/// Removes the address and its value from the access list. +/// Panics if the key is not in the list. +global remove_account: + // stack: addr_key, retdest + PROVER_INPUT(linked_list::remove_account) + // stack: pred_ptr/4, addr_key, retdest + %get_valid_account_ptr + // stack: pred_ptr, addr_key, retdest + %add_const(@ACCOUNTS_NEXT_NODE_PTR) + // stack: next_ptr_ptr, addr_key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_ptr, next_ptr_ptr, addr_key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_addr_key, next_ptr, next_ptr_ptr, addr_key, retdest + DUP4 + %assert_eq + // stack: next_ptr, next_ptr_ptr, addr_key, retdest + %add_const(@ACCOUNTS_NEXT_NODE_PTR) + // stack: next_next_ptr_ptr, next_ptr_ptr, addr_key, key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_next_ptr, next_next_ptr_ptr, next_ptr_ptr, addr_key, retdest + SWAP1 + %mstore_u256_max + // stack: next_next_ptr, next_ptr_ptr, addr_key, retdest + MSTORE_GENERAL + POP + JUMP + + +// +// +// STORAGE linked list +// +// + +%macro store_initial_slots + PUSH %%after + %jump(store_initial_slots) +%%after: +%endmacro + + +/// Iterates over the initial account linked list and shallow copies +/// the accounts, storing a pointer to the copied account in the node. +/// Computes the length of `SEGMENT_STORAGE_LINKED_LIST` and +/// checks against `GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE`. +global store_initial_slots: + // stack: retdest + PUSH @STORAGE_LINKED_LISTS_NODE_SIZE + PUSH @SEGMENT_STORAGE_LINKED_LIST + ADD + // stack: cur_len, retdest + PUSH @SEGMENT_STORAGE_LINKED_LIST + %next_slot + +loop_store_initial_slots: + // stack: current_node_ptr, cur_len, retdest + DUP1 + MLOAD_GENERAL + // stack: current_addr_key, current_node_ptr, cur_len, retdest + %eq_const(@U256_MAX) + %jumpi(store_initial_slots_end) + DUP1 + %add_const(2) + MLOAD_GENERAL + // stack: value, current_node_ptr, cur_len, retdest + DUP2 + %add_const(@STORAGE_COPY_PAYLOAD_PTR) + // stack: cpy_value_ptr, value, current_node_ptr, cur_len, retdest + SWAP1 + MSTORE_GENERAL // Store cpy_value + // stack: current_node_ptr, cur_len, retdest + SWAP1 PUSH @STORAGE_LINKED_LISTS_NODE_SIZE + ADD + SWAP1 + // stack: current_node_ptr, cur_len', retdest + %next_slot + %jump(loop_store_initial_slots) + +store_initial_slots_end: + POP + // stack: cur_len, retdest + %mstore_global_metadata(@GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE) + JUMP + + +%macro insert_slot + %stack (addr_key, key, ptr) -> (addr_key, key, ptr, %%after) + %jump(insert_slot) +%%after: + // stack: value_ptr +%endmacro + +%macro insert_slot_no_return + %insert_slot +%endmacro + +// Multiplies the value at the top of the stack, denoted by ptr/5, by 5 +// and aborts if ptr/5 >= (mem[@GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE] - @SEGMENT_STORAGE_LINKED_LIST)/5. +// This way, @SEGMENT_STORAGE_LINKED_LIST + 5*ptr/5 must be pointing to the beginning of a node. +// TODO: Maybe we should check here if the node has been deleted. +%macro get_valid_slot_ptr + // stack: ptr/5 + DUP1 + PUSH 5 + PUSH @SEGMENT_STORAGE_LINKED_LIST + // stack: segment, 5, ptr/5, ptr/5 + %mload_global_metadata(@GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE) + SUB + // stack: accessed_strg_keys_len, 5, ptr/5, ptr/5 + // By construction, the unscaled list len must be multiple of 5 + DIV + // stack: accessed_strg_keys_len/5, ptr/5, ptr/5 + %assert_gt + %mul_const(5) + %add_const(@SEGMENT_STORAGE_LINKED_LIST) +%endmacro + +/// Inserts the pair (address_key, storage_key) and a new payload pointer into the linked list if it is not already present, +/// or modifies its payload if it was already present. +global insert_slot_with_value: + // stack: addr_key, key, value, retdest + PROVER_INPUT(linked_list::insert_slot) + // stack: pred_ptr/5, addr_key, key, value, retdest + %get_valid_slot_ptr + + // stack: pred_ptr, addr_key, key, value, retdest + DUP1 + MLOAD_GENERAL + DUP1 + // stack: pred_addr_key, pred_addr_key, pred_ptr, addr_key, key, value, retdest + DUP4 + GT + DUP3 %eq_const(@SEGMENT_STORAGE_LINKED_LIST) + ADD // OR + // If the predesessor is strictly smaller or the predecessor is the special + // node with key @U256_MAX (and hence we're inserting a new minimum), then + // we need to insert a new node. + %jumpi(insert_new_slot_with_value) + // stack: pred_addr_key, pred_ptr, addr_key, key, payload_ptr, retdest + // If we are here we know that addr <= pred_addr. But this is only possible if pred_addr == addr. + DUP3 + %assert_eq + // stack: pred_ptr, addr_key, key, value, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: pred_key, pred_ptr, addr_key, key, value, retdest + DUP1 DUP5 + GT + %jumpi(insert_new_slot_with_value) + // stack: pred_key, pred_ptr, addr_key, key, value, retdest + DUP4 + // We know that key <= pred_key. It must hold that pred_key == key. + %assert_eq + + // stack: pred_ptr, addr_key, key, value, retdest + // Check that this is not a deleted node + DUP1 + %add_const(@STORAGE_NEXT_NODE_PTR) + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, slot_found_write_value) + // The storage key is not in the list. + PANIC + +insert_new_slot_with_value: + // stack: pred_addr or pred_key, pred_ptr, addr_key, key, value, retdest + POP + // get the value of the next address + %add_const(@STORAGE_NEXT_NODE_PTR) + // stack: next_ptr_ptr, addr_key, key, value, retdest + %mload_global_metadata(@GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE) + DUP2 + MLOAD_GENERAL + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + // Check that this is not a deleted node + DUP1 + %eq_const(@U256_MAX) + %assert_zero + DUP1 + MLOAD_GENERAL + // stack: next_addr_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + DUP1 + DUP6 + // Here, (addr_key > pred_addr_key) || (pred_ptr == @SEGMENT_ACCOUNTS_LINKED_LIST). + // We should have (addr_key < next_addr_key), meaning the new value can be inserted between pred_ptr and next_ptr. + LT + %jumpi(next_node_ok_with_value) + // If addr_key <= next_addr_key, then it addr must be equal to next_addr + // stack: next_addr_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + DUP5 + %assert_eq + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: next_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + DUP1 // This is added just to have the correct stack in next_node_ok + DUP7 + // The next key must be strictly larger + %assert_lt + +next_node_ok_with_value: + // stack: next_addr or next_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + POP + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, key, value, retdest + SWAP2 + DUP2 + // stack: new_ptr, next_ptr_ptr, new_ptr, next_ptr, addr_key, key, value, retdest + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr_key, key, value, retdest + // Write the address in the new node + DUP1 + DUP4 + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr_key, key, value, retdest + // Write the key in the new node + %increment + DUP1 + DUP5 + MSTORE_GENERAL + // stack: new_ptr + 1, next_ptr, addr_key, key, value, retdest + // Write the value in the linked list. + %increment + DUP1 %increment + // stack: new_ptr+3, new_value_ptr, next_ptr, addr_key, key, value, retdest + %stack (new_cloned_value_ptr, new_value_ptr, next_ptr, addr_key, key, value, retdest) + -> (value, new_cloned_value_ptr, value, new_value_ptr, new_cloned_value_ptr, next_ptr, retdest) + MSTORE_GENERAL // Store copied value. + MSTORE_GENERAL // Store value. + + // stack: new_ptr + 3, next_ptr, retdest + %increment + DUP1 + // stack: new_next_ptr_ptr, new_next_ptr_ptr, next_ptr, retdest + SWAP2 + MSTORE_GENERAL + // stack: new_next_ptr_ptr, retdest + %increment + %mstore_global_metadata(@GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE) + // stack: retdest + JUMP + +slot_found_write_value: + // stack: pred_ptr, addr_key, key, value, retdest + %add_const(2) + %stack (payload_ptr, addr_key, key, value) -> (value, payload_ptr) + MSTORE_GENERAL + // stack: retdest + JUMP + +%macro insert_slot_with_value + // stack: addr, slot, value + %addr_to_state_key + SWAP1 + %slot_to_storage_key + %stack (slot_key, addr_key, value) -> (addr_key, slot_key, value, %%after) + %jump(insert_slot_with_value) +%%after: + // stack: (empty) +%endmacro + +%macro insert_slot_with_value_from_keys + // stack: addr_key, slot_key, value + %stack (addr_key, slot_key, value) -> (addr_key, slot_key, value, %%after) + %jump(insert_slot_with_value) +%%after: + // stack: (empty) +%endmacro + +/// Inserts the pair (address_key, storage_key) and payload pointer into the linked list if it is not already present, +/// or modifies its payload if it was already present. +/// Returns `payload_ptr` if the storage key was inserted, `original_ptr` if it was already present. +global insert_slot: + // stack: addr_key, key, payload_ptr, retdest + PROVER_INPUT(linked_list::insert_slot) + // stack: pred_ptr/5, addr_key, key, payload_ptr, retdest + %get_valid_slot_ptr + + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + DUP1 + MLOAD_GENERAL + DUP1 + // stack: pred_addr_key, pred_addr_key, pred_ptr, addr_key, key, payload_ptr, retdest + DUP4 + GT + DUP3 %eq_const(@SEGMENT_STORAGE_LINKED_LIST) + ADD // OR + // If the predesessor is strictly smaller or the predecessor is the special + // node with key @U256_MAX (and hence we're inserting a new minimum), then + // we need to insert a new node. + %jumpi(insert_new_slot) + // stack: pred_addr_key, pred_ptr, addr_key, key, payload_ptr, retdest + // If we are here we know that addr <= pred_addr. But this is only possible if pred_addr == addr. + DUP3 + %assert_eq + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: pred_key, pred_ptr, addr_key, key, payload_ptr, retdest + DUP1 DUP5 + GT + %jumpi(insert_new_slot) + // stack: pred_key, pred_ptr, addr_key, key, payload_ptr, retdest + DUP4 + // We know that key <= pred_key. It must hold that pred_key == key. + %assert_eq + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + // Check that this is not a deleted node + DUP1 + %add_const(@STORAGE_NEXT_NODE_PTR) + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, slot_found_write) + // The storage key is not in the list. + PANIC + +slot_found_write: + // The slot was already in the list + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + // Load the the payload pointer and access counter + %add_const(2) + DUP1 + MLOAD_GENERAL + // stack: orig_payload_ptr, pred_ptr + 2, addr_key, key, payload_ptr, retdest + SWAP1 + DUP5 + MSTORE_GENERAL // Store the new payload + %stack (orig_payload_ptr, addr_key, key, payload_ptr, retdest) -> (retdest, orig_payload_ptr) + JUMP +insert_new_slot: + // stack: pred_addr or pred_key, pred_ptr, addr_key, key, payload_ptr, retdest + POP + // get the value of the next address + %add_const(@STORAGE_NEXT_NODE_PTR) + // stack: next_ptr_ptr, addr_key, key, payload_ptr, retdest + %mload_global_metadata(@GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE) + DUP2 + MLOAD_GENERAL + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + // Check that this is not a deleted node + DUP1 + %eq_const(@U256_MAX) + %assert_zero + DUP1 + MLOAD_GENERAL + // stack: next_addr_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + DUP1 + DUP6 + // Here, (addr_key > pred_addr_key) || (pred_ptr == @SEGMENT_ACCOUNTS_LINKED_LIST). + // We should have (addr_key < next_addr_key), meaning the new value can be inserted between pred_ptr and next_ptr. + LT + %jumpi(next_node_ok) + // If addr_key <= next_addr_key, then it addr must be equal to next_addr + // stack: next_addr_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + DUP5 + %assert_eq + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: next_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + DUP1 // This is added just to have the correct stack in next_node_ok + DUP7 + // The next key must be strictly larger + %assert_lt +next_node_ok: + // stack: next_addr or next_key, next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + POP + // stack: next_ptr, new_ptr, next_ptr_ptr, addr_key, key, payload_ptr, retdest + SWAP2 + DUP2 + // stack: new_ptr, next_ptr_ptr, new_ptr, next_ptr, addr_key, key, payload_ptr, retdest + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr_key, key, payload_ptr, retdest + // Write the address in the new node + DUP1 + DUP4 + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr_key, key, payload_ptr, retdest + // Write the key in the new node + %increment + DUP1 + DUP5 + MSTORE_GENERAL + // stack: new_ptr + 1, next_ptr, addr_key, key, payload_ptr, retdest + // Store payload_ptr + %increment + DUP1 + DUP6 + MSTORE_GENERAL + + // stack: new_ptr + 2, next_ptr, addr_key, key, payload_ptr, retdest + // Store the copy of payload_ptr + %increment + DUP1 + DUP6 + %clone_slot + MSTORE_GENERAL + // stack: new_ptr + 3, next_ptr, addr_key, key, payload_ptr, retdest + %increment + DUP1 + // stack: new_next_ptr, new_next_ptr, next_ptr, addr_key, key, payload_ptr, retdest + SWAP2 + MSTORE_GENERAL + // stack: new_next_ptr, addr_key, key, payload_ptr, retdest + %increment + %mstore_global_metadata(@GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE) + // stack: addr_key, key, payload_ptr, retdest + %stack (addr_key, key, payload_ptr, retdest) -> (retdest, payload_ptr) + JUMP + +/// Searches the pair (address_key, storage_key) in the storage the linked list. +/// Returns `payload_ptr` if the storage key was inserted, `original_ptr` if it was already present. +global search_slot: + // stack: addr_key, key, payload_ptr, retdest + PROVER_INPUT(linked_list::insert_slot) + // stack: pred_ptr/5, addr_key, key, payload_ptr, retdest + %get_valid_slot_ptr + + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + DUP1 + MLOAD_GENERAL + DUP1 + // stack: pred_addr_key, pred_addr_key, pred_ptr, addr_key, key, payload_ptr, retdest + DUP4 + GT + DUP3 %eq_const(@SEGMENT_STORAGE_LINKED_LIST) + ADD // OR + // If the predesessor is strictly smaller or the predecessor is the special + // node with key @U256_MAX (and hence we're inserting a new minimum), then + // the slot was not found + %jumpi(slot_not_found) + // stack: pred_addr_key, pred_ptr, addr_key, key, payload_ptr, retdest + // If we are here we know that addr <= pred_addr. But this is only possible if pred_addr == addr. + DUP3 + %assert_eq + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: pred_key, pred_ptr, addr_key, key, payload_ptr, retdest + DUP1 DUP5 + GT + %jumpi(slot_not_found) + // stack: pred_key, pred_ptr, addr_key, key, payload_ptr, retdest + DUP4 + // We know that key <= pred_key. It must hold that pred_key == key. + %assert_eq + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + // Check that this is not a deleted node + DUP1 + %add_const(@STORAGE_NEXT_NODE_PTR) + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, slot_found_no_write) + // The storage key is not in the list. + PANIC +slot_not_found: + // stack: pred_addr_or_pred_key, pred_ptr, addr_key, key, payload_ptr, retdest + %stack (pred_addr_or_pred_key, pred_ptr, addr_key, key, payload_ptr, retdest) + -> (retdest, payload_ptr) + JUMP + +slot_found_no_write: + // The slot was already in the list + // stack: pred_ptr, addr_key, key, payload_ptr, retdest + // Load the the payload pointer and access counter + %add_const(2) + MLOAD_GENERAL + // stack: orig_value, addr_key, key, payload_ptr, retdest + %stack (orig_value, addr_key, key, payload_ptr, retdest) -> (retdest, orig_value) + JUMP + +%macro search_slot + // stack: state_key, storage_key, ptr + %stack (state_key, storage_key, ptr) -> (state_key, storage_key, ptr, %%after) + %jump(search_slot) +%%after: + // stack: ptr +%endmacro + +%macro remove_slot + %stack (key, addr_key) -> (addr_key, key, %%after) + %jump(remove_slot) +%%after: +%endmacro + +/// Removes the storage key and its value from the list. +/// Panics if the key is not in the list. +global remove_slot: + // stack: addr_key, key, retdest + PROVER_INPUT(linked_list::remove_slot) + // stack: pred_ptr/5, addr_key, key, retdest + %get_valid_slot_ptr + // stack: pred_ptr, addr_key, key, retdest + %add_const(@STORAGE_NEXT_NODE_PTR) + // stack: next_ptr_ptr, addr_key, key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_ptr, next_ptr_ptr, addr_key, key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_addr_key, next_ptr, next_ptr_ptr, addr_key, key, retdest + DUP4 + %assert_eq + // stack: next_ptr, next_ptr_ptr, addr_key, key, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: next_key, next_ptr, next_ptr_ptr, addr_key, key, retdest + DUP5 + %assert_eq + // stack: next_ptr, next_ptr_ptr, addr_key, key, retdest + %add_const(@STORAGE_NEXT_NODE_PTR) + // stack: next_next_ptr_ptr, next_ptr_ptr, addr_key, key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_next_ptr, next_next_ptr_ptr, next_ptr_ptr, addr_key, key, retdest + // Mark the next node as deleted + SWAP1 + %mstore_u256_max + // stack: next_next_ptr, next_ptr_ptr, addr_key, key, retdest + MSTORE_GENERAL + %pop2 + JUMP + +/// Called when an account is deleted: it deletes all slots associated with the account. +global remove_all_account_slots: + // stack: addr_key, retdest + PROVER_INPUT(linked_list::remove_address_slots) + // pred_ptr/5, retdest + %get_valid_slot_ptr + // stack: pred_ptr, addr_key, retdest + // First, check that the previous address is not `addr` + DUP1 MLOAD_GENERAL + // stack: pred_addr_key, pred_ptr, addr_key, retdest + DUP3 EQ %jumpi(panic) + // stack: pred_ptr, addr_key, retdest + DUP1 + +// Now, while the next address is `addr`, remove the next slot. +remove_all_slots_loop: + // stack: pred_ptr, pred_ptr, addr_key, retdest + %add_const(@STORAGE_NEXT_NODE_PTR) DUP1 MLOAD_GENERAL + // stack: cur_ptr, cur_ptr_ptr, pred_ptr, addr_key, retdest + DUP1 %eq_const(@U256_MAX) %jumpi(remove_all_slots_end) + DUP1 %add_const(@STORAGE_NEXT_NODE_PTR) MLOAD_GENERAL + // stack: next_ptr, cur_ptr, cur_ptr_ptr, pred_ptr, addr_key, retdest + SWAP1 DUP1 + // stack: cur_ptr, cur_ptr, next_ptr, cur_ptr_ptr, pred_ptr, addr_key, retdest + MLOAD_GENERAL + DUP6 EQ ISZERO %jumpi(remove_all_slots_pop_and_end) + + // Remove slot: update the value in cur_ptr_ptr, and set cur_ptr+4 to @U256_MAX. + // stack: cur_ptr, next_ptr, cur_ptr_ptr, pred_ptr, addr_key, retdest + SWAP2 SWAP1 + // stack: next_ptr, cur_ptr_ptr, cur_ptr, pred_ptr, addr_key, retdest + MSTORE_GENERAL + // stack: cur_ptr, pred_ptr, addr_key, retdest + %add_const(@STORAGE_NEXT_NODE_PTR) + %mstore_u256_max + // stack: pred_ptr, addr_key, retdest + DUP1 + %jump(remove_all_slots_loop) + +remove_all_slots_pop_and_end: + POP +remove_all_slots_end: + // stack: next_ptr, cur_ptr_ptr, pred_ptr, addr_key, retdest + %pop4 JUMP + +%macro remove_all_account_slots + %stack (addr_key) -> (addr_key, %%after) + %jump(remove_all_account_slots) +%%after: +%endmacro + +%macro read_accounts_linked_list + %stack (addr) -> (addr, %%after) + %addr_to_state_key + %jump(search_account) +%%after: + // stack: account_ptr +%endmacro + +%macro read_storage_linked_list + // stack: slot + %slot_to_storage_key + %address + %addr_to_state_key + %stack (addr_key, key) -> (addr_key, key, 0, %%after) + %jump(search_slot) +%%after: + // stack: slot_ptr +%endmacro + +%macro read_storage_linked_list_w_addr + // stack: slot, address + %slot_to_storage_key + SWAP1 + %addr_to_state_key + %stack (addr_key, key) -> (addr_key, key, 0, %%after) + %jump(search_slot) +%%after: + // stack: slot_ptr +%endmacro + +%macro first_account + // stack: empty + PUSH @SEGMENT_ACCOUNTS_LINKED_LIST + %next_account +%endmacro + +%macro next_account + // stack: node_ptr + %add_const(@ACCOUNTS_NEXT_NODE_PTR) + MLOAD_GENERAL + // stack: next_node_ptr +%endmacro + +%macro first_slot + // stack: empty + PUSH @SEGMENT_STORAGE_LINKED_LIST + %next_slot +%endmacro + +%macro next_slot + // stack: node_ptr + %add_const(@STORAGE_NEXT_NODE_PTR) + MLOAD_GENERAL + // stack: next_node_ptr +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/read.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/read.asm index 3741049fe..148a7897d 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/read.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/read.asm @@ -1,15 +1,13 @@ // Given an address, return a pointer to the associated account data, which // consists of four words (nonce, balance, storage_root, code_hash), in the -// state trie. Returns null if the address is not found. +// trie_data segment. Return null if the address is not found. global mpt_read_state_trie: // stack: addr, retdest - %addr_to_state_key -global mpt_read_state_trie_from_key: - // stack: key, retdest - PUSH 64 // num_nibbles - %mload_global_metadata(@GLOBAL_METADATA_STATE_TRIE_ROOT) // node_ptr - // stack: node_ptr, num_nibbles, key, retdest - %jump(mpt_read) + %read_accounts_linked_list + // stack: account_ptr, retdest + SWAP1 + // stack: retdest, account_ptr + JUMP // Convenience macro to call mpt_read_state_trie and return where we left off. %macro mpt_read_state_trie @@ -18,13 +16,6 @@ global mpt_read_state_trie_from_key: %%after: %endmacro -// Convenience macro to call mpt_read_state_trie_from_key and return where we left off. -%macro mpt_read_state_trie_from_key - %stack (key) -> (key, %%after) - %jump(mpt_read_state_trie_from_key) -%%after: -%endmacro - // Read a value from a MPT. // // Arguments: diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_read.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_read.asm index db9fe4222..d4a7ca36a 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_read.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_read.asm @@ -5,15 +5,7 @@ %endmacro global sload_current: - %stack (slot) -> (slot, after_storage_read) - %slot_to_storage_key - // stack: storage_key, after_storage_read - PUSH 64 // storage_key has 64 nibbles - %current_storage_trie - // stack: storage_root_ptr, 64, storage_key, after_storage_read - %jump(mpt_read) - -global after_storage_read: + %read_storage_linked_list // stack: value_ptr, retdest DUP1 %jumpi(storage_key_exists) @@ -22,13 +14,6 @@ global after_storage_read: %stack (value_ptr, retdest) -> (retdest, 0) JUMP -global storage_key_exists: - // stack: value_ptr, retdest - %mload_trie_data - // stack: value, retdest - SWAP1 - JUMP - // Read a word from the current account's storage trie. // // Pre stack: kexit_info, slot diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_write.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_write.asm index 22c5d29de..589c44094 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_write.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/storage/storage_write.asm @@ -111,33 +111,11 @@ sstore_after_refund: // stack: slot, value, kexit_info DUP2 ISZERO %jumpi(sstore_delete) - // First we write the value to MPT data, and get a pointer to it. - %get_trie_data_size - // stack: value_ptr, slot, value, kexit_info - SWAP2 - // stack: value, slot, value_ptr, kexit_info - %append_to_trie_data - // stack: slot, value_ptr, kexit_info - - // Next, call mpt_insert on the current account's storage root. - %stack (slot, value_ptr) -> (slot, value_ptr, after_storage_insert) - %slot_to_storage_key - // stack: storage_key, value_ptr, after_storage_insert, kexit_info - PUSH 64 // storage_key has 64 nibbles - %current_storage_trie - // stack: storage_root_ptr, 64, storage_key, value_ptr, after_storage_insert, kexit_info - %jump(mpt_insert) - -after_storage_insert: - // stack: new_storage_root_ptr, kexit_info - %current_account_data - // stack: account_ptr, new_storage_root_ptr, kexit_info - - // Update the copied account with our new storage root pointer. - %add_const(2) - // stack: account_storage_root_ptr_ptr, new_storage_root_ptr, kexit_info - %mstore_trie_data - // stack: kexit_info + + // stack: slot, value, kexit_info + %address + %insert_slot_with_value + EXIT_KERNEL sstore_noop: @@ -148,12 +126,11 @@ sstore_noop: // Delete the slot from the storage trie. sstore_delete: // stack: slot, value, kexit_info - SWAP1 POP - PUSH after_storage_insert SWAP1 - // stack: slot, after_storage_insert, kexit_info + %address + %addr_to_state_key + // stack: addr_key, slot, value, kexit_info + SWAP2 POP + // stack: slot, addr_key, kexit_info %slot_to_storage_key - // stack: storage_key, after_storage_insert, kexit_info - PUSH 64 // storage_key has 64 nibbles - %current_storage_trie - // stack: storage_root_ptr, 64, storage_key, after_storage_insert, kexit_info - %jump(mpt_delete) + %remove_slot + EXIT_KERNEL diff --git a/evm_arithmetization/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm b/evm_arithmetization/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm index 6958cff9f..2f5f07500 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/rlp/increment_bounded_rlp.asm @@ -2,8 +2,8 @@ // its number of nibbles when required. Shouldn't be // called with rlp_index > 0x82 ff ff global increment_bounded_rlp: - // stack: num_nibbles, rlp_index, retdest - DUP2 + // stack: rlp_index, num_nibbles, retdest + DUP1 %eq_const(0x80) %jumpi(case_0x80) DUP1 @@ -14,19 +14,19 @@ global increment_bounded_rlp: %jumpi(case_0x81ff) // If rlp_index != 0x80 and rlp_index != 0x7f and rlp_index != 0x81ff // we only need to add one and keep the number of nibbles - DUP2 %increment DUP2 - %stack (next_num_nibbles, next_rlp_index, num_nibbles, rlp_index, retdest) -> (retdest, rlp_index, num_nibbles, next_rlp_index, next_num_nibbles) + %increment + %stack (next_rlp_index, next_num_nibbles, retdest) -> (retdest, next_rlp_index, next_num_nibbles) JUMP case_0x80: - %stack (num_nibbles, rlp_index, retdest) -> (retdest, 0x80, 2, 0x01, 2) + %stack (num_nibbles, rlp_index, retdest) -> (retdest, 0x01, 2) JUMP case_0x7f: - %stack (num_nibbles, rlp_index, retdest) -> (retdest, 0x7f, 2, 0x8180, 4) + %stack (num_nibbles, rlp_index, retdest) -> (retdest, 0x8180, 4) JUMP case_0x81ff: - %stack (num_nibbles, rlp_index, retdest) -> (retdest, 0x81ff, 4, 0x820100, 6) + %stack (num_nibbles, rlp_index, retdest) -> (retdest, 0x820100, 6) JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/transactions/common_decoding.asm b/evm_arithmetization/src/cpu/kernel/asm/transactions/common_decoding.asm index 613131c3e..5109f29af 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/transactions/common_decoding.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/transactions/common_decoding.asm @@ -305,24 +305,17 @@ after_read: POP JUMP - sload_with_addr: - %stack (slot, addr) -> (slot, addr, after_storage_read) - %slot_to_storage_key - // stack: storage_key, addr, after_storage_read - PUSH 64 // storage_key has 64 nibbles - %stack (n64, storage_key, addr, after_storage_read) -> (addr, n64, storage_key, after_storage_read) - %mpt_read_state_trie - // stack: account_ptr, 64, storage_key, after_storage_read - DUP1 ISZERO %jumpi(ret_zero) // TODO: Fix this. This should never happen. - // stack: account_ptr, 64, storage_key, after_storage_read - %add_const(2) - // stack: storage_root_ptr_ptr - %mload_trie_data - // stack: storage_root_ptr, 64, storage_key, after_storage_read - %jump(mpt_read) - -ret_zero: - // stack: account_ptr, 64, storage_key, after_storage_read, retdest - %pop4 - PUSH 0 SWAP1 JUMP + %read_storage_linked_list_w_addr + // stack: value_ptr, retdest + DUP1 %jumpi(storage_key_exists) + // Storage key not found. Return default value_ptr = 0, + // which derefs to 0 since @SEGMENT_TRIE_DATA[0] = 0. + %stack (value, retdest) -> (retdest, 0) + + JUMP + +global storage_key_exists: + // stack: value, retdest + SWAP1 + JUMP diff --git a/evm_arithmetization/src/cpu/kernel/asm/transactions/type_3.asm b/evm_arithmetization/src/cpu/kernel/asm/transactions/type_3.asm index 286ff6b06..e63bb7ca1 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/transactions/type_3.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/transactions/type_3.asm @@ -184,3 +184,15 @@ store_origin: DIV // stack: len %endmacro + +%macro reset_blob_versioned_hashes + // stack: (empty) + // Reset the stored hashes + %mload_global_metadata(@GLOBAL_METADATA_BLOB_VERSIONED_HASHES_LEN) + PUSH @SEGMENT_TXN_BLOB_VERSIONED_HASHES // ctx 0 + %memset + // Reset the global metadata + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_BLOB_VERSIONED_HASHES_RLP_LEN) + PUSH 0 %mstore_global_metadata(@GLOBAL_METADATA_BLOB_VERSIONED_HASHES_LEN) + // stack: (empty) +%endmacro \ No newline at end of file diff --git a/evm_arithmetization/src/cpu/kernel/constants/global_metadata.rs b/evm_arithmetization/src/cpu/kernel/constants/global_metadata.rs index 8f18ec60a..c81d4f4b2 100644 --- a/evm_arithmetization/src/cpu/kernel/constants/global_metadata.rs +++ b/evm_arithmetization/src/cpu/kernel/constants/global_metadata.rs @@ -102,6 +102,19 @@ pub(crate) enum GlobalMetadata { KernelHash, KernelLen, + /// The address of the next available address in + /// Segment::AccountsLinkedList + AccountsLinkedListNextAvailable, + /// The address of the next available address in + /// Segment::StorageLinkedList + StorageLinkedListNextAvailable, + /// Length of the `AccountsLinkedList` segment after insertion of the + /// initial accounts. + InitialAccountsLinkedListLen, + /// Length of the `StorageLinkedList` segment after insertion of the + /// initial storage slots. + InitialStorageLinkedListLen, + /// The length of the transient storage segment. TransientStorageLen, @@ -114,7 +127,7 @@ pub(crate) enum GlobalMetadata { } impl GlobalMetadata { - pub(crate) const COUNT: usize = 55; + pub(crate) const COUNT: usize = 59; /// Unscales this virtual offset by their respective `Segment` value. pub(crate) const fn unscale(&self) -> usize { @@ -174,6 +187,10 @@ impl GlobalMetadata { Self::CreatedContractsLen, Self::KernelHash, Self::KernelLen, + Self::AccountsLinkedListNextAvailable, + Self::StorageLinkedListNextAvailable, + Self::InitialAccountsLinkedListLen, + Self::InitialStorageLinkedListLen, Self::TransientStorageLen, Self::BlobVersionedHashesRlpStart, Self::BlobVersionedHashesRlpLen, @@ -235,6 +252,16 @@ impl GlobalMetadata { Self::CreatedContractsLen => "GLOBAL_METADATA_CREATED_CONTRACTS_LEN", Self::KernelHash => "GLOBAL_METADATA_KERNEL_HASH", Self::KernelLen => "GLOBAL_METADATA_KERNEL_LEN", + Self::AccountsLinkedListNextAvailable => { + "GLOBAL_METADATA_ACCOUNTS_LINKED_LIST_NEXT_AVAILABLE" + } + Self::StorageLinkedListNextAvailable => { + "GLOBAL_METADATA_STORAGE_LINKED_LIST_NEXT_AVAILABLE" + } + Self::InitialAccountsLinkedListLen => { + "GLOBAL_METADATA_INITIAL_ACCOUNTS_LINKED_LIST_LEN" + } + Self::InitialStorageLinkedListLen => "GLOBAL_METADATA_INITIAL_STORAGE_LINKED_LIST_LEN", Self::TransientStorageLen => "GLOBAL_METADATA_TRANSIENT_STORAGE_LEN", Self::BlobVersionedHashesRlpStart => "GLOBAL_METADATA_BLOB_VERSIONED_HASHES_RLP_START", Self::BlobVersionedHashesRlpLen => "GLOBAL_METADATA_BLOB_VERSIONED_HASHES_RLP_LEN", diff --git a/evm_arithmetization/src/cpu/kernel/constants/mod.rs b/evm_arithmetization/src/cpu/kernel/constants/mod.rs index 65d6418b3..dbbc1c0d7 100644 --- a/evm_arithmetization/src/cpu/kernel/constants/mod.rs +++ b/evm_arithmetization/src/cpu/kernel/constants/mod.rs @@ -55,6 +55,10 @@ pub(crate) fn evm_constants() -> HashMap { c.insert(name.into(), U256::from(value)); } + for (name, value) in LINKED_LISTS_CONSTANTS { + c.insert(name.into(), U256::from(value)); + } + c.insert(MAX_NONCE.0.into(), U256::from(MAX_NONCE.1)); c.insert(CALL_STACK_LIMIT.0.into(), U256::from(CALL_STACK_LIMIT.1)); c.insert( @@ -109,7 +113,7 @@ pub(crate) fn evm_constants() -> HashMap { c } -const MISC_CONSTANTS: [(&str, [u8; 32]); 5] = [ +const MISC_CONSTANTS: [(&str, [u8; 32]); 6] = [ // Base for limbs used in bignum arithmetic. ( "BIGNUM_LIMB_BASE", @@ -134,6 +138,13 @@ const MISC_CONSTANTS: [(&str, [u8; 32]); 5] = [ "INITIAL_TXN_RLP_ADDR", hex!("0000000000000000000000000000000000000000000000000000000b00000001"), ), + // Address where the final registers start. It is the offset 6 within the + // SEGMENT_REGISTERS_STATES. + // *Note*: Changing this will break some tests. + ( + "FINAL_REGISTERS_ADDR", + hex!("0000000000000000000000000000000000000000000000000000002100000006"), + ), // Scaled boolean value indicating that we are in kernel mode, to be used within `kexit_info`. // It is equal to 2^32. ( @@ -348,6 +359,14 @@ const CODE_SIZE_LIMIT: [(&str, u64); 3] = [ const MAX_NONCE: (&str, u64) = ("MAX_NONCE", 0xffffffffffffffff); const CALL_STACK_LIMIT: (&str, u64) = ("CALL_STACK_LIMIT", 1024); +const LINKED_LISTS_CONSTANTS: [(&str, u16); 5] = [ + ("ACCOUNTS_LINKED_LISTS_NODE_SIZE", 4), + ("STORAGE_LINKED_LISTS_NODE_SIZE", 5), + ("ACCOUNTS_NEXT_NODE_PTR", 3), + ("STORAGE_NEXT_NODE_PTR", 4), + ("STORAGE_COPY_PAYLOAD_PTR", 3), +]; + /// Cancun-related constants /// See and /// . diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 38fbe2320..a42bc3a1e 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -12,16 +12,18 @@ use ethereum_types::{BigEndianHash, U256}; use log::Level; use mpt_trie::partial_trie::PartialTrie; use plonky2::field::types::Field; +use serde::{Deserialize, Serialize}; use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::debug_inputs; -use crate::generation::mpt::load_all_mpts; +use crate::generation::mpt::{load_linked_lists_and_txn_and_receipt_mpts, TrieRootPtrs}; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::state::{ - all_withdrawals_prover_inputs_reversed, GenerationState, GenerationStateCheckpoint, + all_ger_prover_inputs_reversed, all_withdrawals_prover_inputs_reversed, GenerationState, + GenerationStateCheckpoint, }; use crate::generation::{state::State, GenerationInputs}; use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; @@ -29,7 +31,9 @@ use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::memory::segments::Segment; use crate::util::h2u; use crate::witness::errors::ProgramError; -use crate::witness::memory::{MemoryAddress, MemoryOp, MemoryOpKind, MemorySegmentState}; +use crate::witness::memory::{ + MemoryAddress, MemoryContextState, MemoryOp, MemoryOpKind, MemorySegmentState, MemoryState, +}; use crate::witness::operation::Operation; use crate::witness::state::RegistersState; use crate::witness::transition::{ @@ -58,6 +62,8 @@ pub(crate) struct Interpreter { /// Holds the value of the clock: the clock counts the number of operations /// in the execution. pub(crate) clock: usize, + /// Log of the maximal number of CPU cycles in one segment execution. + max_cpu_len_log: Option, } /// Simulates the CPU execution from `state` until the program counter reaches @@ -71,8 +77,12 @@ pub(crate) fn simulate_cpu_and_get_user_jumps( None => { let halt_pc = KERNEL.global_labels[final_label]; let initial_context = state.registers.context; - let mut interpreter = - Interpreter::new_with_state_and_halt_condition(state, halt_pc, initial_context); + let mut interpreter = Interpreter::new_with_state_and_halt_condition( + state, + halt_pc, + initial_context, + None, + ); log::debug!("Simulating CPU for jumpdest analysis."); @@ -80,43 +90,97 @@ pub(crate) fn simulate_cpu_and_get_user_jumps( log::trace!("jumpdest table = {:?}", interpreter.jumpdest_table); + let clock = interpreter.get_clock(); + interpreter .generation_state .set_jumpdest_analysis_inputs(interpreter.jumpdest_table); - log::debug!("Simulated CPU for jumpdest analysis halted."); + log::debug!( + "Simulated CPU for jumpdest analysis halted after {:?} cycles.", + clock + ); + interpreter.generation_state.jumpdest_table } } } +/// State data required to initialize the state passed to the prover. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ExtraSegmentData { + pub(crate) bignum_modmul_result_limbs: Vec, + pub(crate) rlp_prover_inputs: Vec, + pub(crate) withdrawal_prover_inputs: Vec, + pub(crate) ger_prover_inputs: Vec, + pub(crate) trie_root_ptrs: TrieRootPtrs, + pub(crate) jumpdest_table: Option>>, + pub(crate) next_txn_index: usize, +} + +pub(crate) fn set_registers_and_run( + registers: RegistersState, + interpreter: &mut Interpreter, +) -> anyhow::Result<(RegistersState, Option)> { + interpreter.generation_state.registers = registers; + interpreter.generation_state.registers.program_counter = KERNEL.global_labels["init"]; + interpreter.generation_state.registers.is_kernel = true; + interpreter.clock = 0; + + // Write initial registers. + [ + registers.program_counter.into(), + (registers.is_kernel as usize).into(), + registers.stack_len.into(), + registers.stack_top, + registers.context.into(), + registers.gas_used.into(), + ] + .iter() + .enumerate() + .for_each(|(i, reg_content)| { + interpreter.generation_state.memory.set( + MemoryAddress::new(0, Segment::RegistersStates, i), + *reg_content, + ) + }); + + interpreter.run() +} + impl Interpreter { /// Returns an instance of `Interpreter` given `GenerationInputs`, and /// assuming we are initializing with the `KERNEL` code. pub(crate) fn new_with_generation_inputs( initial_offset: usize, initial_stack: Vec, - inputs: GenerationInputs, + inputs: &GenerationInputs, + max_cpu_len_log: Option, ) -> Self { - debug_inputs(&inputs); + debug_inputs(inputs); - let mut result = Self::new(initial_offset, initial_stack); + let mut result = Self::new(initial_offset, initial_stack, max_cpu_len_log); result.initialize_interpreter_state(inputs); result } - pub(crate) fn new(initial_offset: usize, initial_stack: Vec) -> Self { + pub(crate) fn new( + initial_offset: usize, + initial_stack: Vec, + max_cpu_len_log: Option, + ) -> Self { let mut interpreter = Self { - generation_state: GenerationState::new(GenerationInputs::default(), &KERNEL.code) + generation_state: GenerationState::new(&GenerationInputs::default(), &KERNEL.code) .expect("Default inputs are known-good"), // `DEFAULT_HALT_OFFSET` is used as a halting point for the interpreter, // while the label `halt` is the halting label in the kernel. - halt_offsets: vec![DEFAULT_HALT_OFFSET, KERNEL.global_labels["halt"]], + halt_offsets: vec![DEFAULT_HALT_OFFSET, KERNEL.global_labels["halt_final"]], halt_context: None, opcode_count: [0; 256], jumpdest_table: HashMap::new(), is_jumpdest_analysis: false, clock: 0, + max_cpu_len_log, }; interpreter.generation_state.registers.program_counter = initial_offset; let initial_stack_len = initial_stack.len(); @@ -137,6 +201,7 @@ impl Interpreter { state: &GenerationState, halt_offset: usize, halt_context: usize, + max_cpu_len_log: Option, ) -> Self { Self { generation_state: state.soft_clone(), @@ -146,36 +211,55 @@ impl Interpreter { jumpdest_table: HashMap::new(), is_jumpdest_analysis: true, clock: 0, + max_cpu_len_log, } } /// Initializes the interpreter state given `GenerationInputs`. - pub(crate) fn initialize_interpreter_state(&mut self, inputs: GenerationInputs) { - let kernel_hash = KERNEL.code_hash; - let kernel_code_len = KERNEL.code.len(); + pub(crate) fn initialize_interpreter_state(&mut self, inputs: &GenerationInputs) { + // Initialize registers. + let registers_before = RegistersState::new(); + self.generation_state.registers = RegistersState { + program_counter: self.generation_state.registers.program_counter, + is_kernel: self.generation_state.registers.is_kernel, + ..registers_before + }; + let tries = &inputs.tries; - // Set state's inputs. - self.generation_state.inputs = inputs.clone(); + // Set state's inputs. We trim unnecessary components. + self.generation_state.inputs = inputs.trim(); // Initialize the MPT's pointers. - let (trie_root_ptrs, trie_data) = - load_all_mpts(tries).expect("Invalid MPT data for preinitialization"); + let (trie_root_ptrs, state_leaves, storage_leaves, trie_data) = + load_linked_lists_and_txn_and_receipt_mpts(&inputs.tries) + .expect("Invalid MPT data for preinitialization"); + let trie_roots_after = &inputs.trie_roots_after; self.generation_state.trie_root_ptrs = trie_root_ptrs; // Initialize the `TrieData` segment. - let preinit_trie_data_segment = MemorySegmentState { - content: trie_data.iter().map(|&elt| Some(elt)).collect::>(), + let preinit_trie_data_segment = MemorySegmentState { content: trie_data }; + let preinit_accounts_ll_segment = MemorySegmentState { + content: state_leaves, + }; + let preinit_storage_ll_segment = MemorySegmentState { + content: storage_leaves, }; self.insert_preinitialized_segment(Segment::TrieData, preinit_trie_data_segment); + self.insert_preinitialized_segment( + Segment::AccountsLinkedList, + preinit_accounts_ll_segment, + ); + self.insert_preinitialized_segment(Segment::StorageLinkedList, preinit_storage_ll_segment); // Update the RLP and withdrawal prover inputs. - let rlp_prover_inputs = - all_rlp_prover_inputs_reversed(inputs.clone().signed_txn.as_ref().unwrap_or(&vec![])); + let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); let withdrawal_prover_inputs = all_withdrawals_prover_inputs_reversed(&inputs.withdrawals); + let ger_prover_inputs = all_ger_prover_inputs_reversed(&inputs.global_exit_roots); self.generation_state.rlp_prover_inputs = rlp_prover_inputs; self.generation_state.withdrawal_prover_inputs = withdrawal_prover_inputs; + self.generation_state.ger_prover_inputs = ger_prover_inputs; // Set `GlobalMetadata` values. let metadata = &inputs.block_metadata; @@ -216,7 +300,7 @@ impl Interpreter { (GlobalMetadata::TxnNumberBefore, inputs.txn_number_before), ( GlobalMetadata::TxnNumberAfter, - inputs.txn_number_before + if inputs.signed_txn.is_some() { 1 } else { 0 }, + inputs.txn_number_before + inputs.signed_txns.len(), ), ( GlobalMetadata::StateTrieRootDigestBefore, @@ -242,8 +326,8 @@ impl Interpreter { GlobalMetadata::ReceiptTrieRootDigestAfter, h2u(trie_roots_after.receipts_root), ), - (GlobalMetadata::KernelHash, h2u(kernel_hash)), - (GlobalMetadata::KernelLen, kernel_code_len.into()), + (GlobalMetadata::KernelHash, h2u(KERNEL.code_hash)), + (GlobalMetadata::KernelLen, KERNEL.code.len().into()), ]; self.set_global_metadata_multi_fields(&global_metadata_to_set); @@ -252,12 +336,7 @@ impl Interpreter { let final_block_bloom_fields = (0..8) .map(|i| { ( - MemoryAddress::new_u256s( - U256::zero(), - (Segment::GlobalBlockBloom.unscale()).into(), - i.into(), - ) - .expect("This cannot panic as `virt` fits in a `u32`"), + MemoryAddress::new(0, Segment::GlobalBlockBloom, i), metadata.block_bloom[i], ) }) @@ -269,18 +348,33 @@ impl Interpreter { let block_hashes_fields = (0..256) .map(|i| { ( - MemoryAddress::new_u256s( - U256::zero(), - (Segment::BlockHashes.unscale()).into(), - i.into(), - ) - .expect("This cannot panic as `virt` fits in a `u32`"), + MemoryAddress::new(0, Segment::BlockHashes, i), h2u(inputs.block_hashes.prev_hashes[i]), ) }) .collect::>(); self.set_memory_multi_addresses(&block_hashes_fields); + + // Write initial registers. + let registers_before = [ + registers_before.program_counter.into(), + (registers_before.is_kernel as usize).into(), + registers_before.stack_len.into(), + registers_before.stack_top, + registers_before.context.into(), + registers_before.gas_used.into(), + ]; + let registers_before_fields = (0..registers_before.len()) + .map(|i| { + ( + MemoryAddress::new(0, Segment::RegistersStates, i), + registers_before[i], + ) + }) + .collect::>(); + + self.set_memory_multi_addresses(®isters_before_fields); } /// Applies all memory operations since the last checkpoint. The memory @@ -309,8 +403,8 @@ impl Interpreter { Ok(()) } - pub(crate) fn run(&mut self) -> Result<(), anyhow::Error> { - self.run_cpu()?; + pub(crate) fn run(&mut self) -> Result<(RegistersState, Option), anyhow::Error> { + let (final_registers, final_mem) = self.run_cpu(self.max_cpu_len_log)?; #[cfg(debug_assertions)] { @@ -322,7 +416,13 @@ impl Interpreter { } println!("Total: {}", self.opcode_count.into_iter().sum::()); } - Ok(()) + + Ok((final_registers, final_mem)) + } + + /// Returns the max number of CPU cycles. + pub(crate) fn get_max_cpu_len_log(&self) -> Option { + self.max_cpu_len_log } pub(crate) fn code(&self) -> &MemorySegmentState { @@ -406,7 +506,7 @@ impl Interpreter { } impl State for Interpreter { - //// Returns a `GenerationStateCheckpoint` to save the current registers and + /// Returns a `GenerationStateCheckpoint` to save the current registers and /// reset memory operations to the empty vector. fn checkpoint(&mut self) -> GenerationStateCheckpoint { self.generation_state.traces.memory_ops = vec![]; @@ -502,6 +602,56 @@ impl State for Interpreter { self.halt_offsets.clone() } + fn get_active_memory(&self) -> Option { + let mut memory_state = MemoryState { + contexts: vec![ + MemoryContextState::default(); + self.generation_state.memory.contexts.len() + ], + ..self.generation_state.memory.clone() + }; + + // Only copy memory from non-stale contexts + for (ctx_idx, ctx) in self.generation_state.memory.contexts.iter().enumerate() { + if !self + .get_generation_state() + .stale_contexts + .contains(&ctx_idx) + { + memory_state.contexts[ctx_idx] = ctx.clone(); + } + } + + memory_state.preinitialized_segments = + self.generation_state.memory.preinitialized_segments.clone(); + + Some(memory_state) + } + + fn update_interpreter_final_registers(&mut self, final_registers: RegistersState) { + { + let registers_after = [ + final_registers.program_counter.into(), + (final_registers.is_kernel as usize).into(), + final_registers.stack_len.into(), + final_registers.stack_top, + final_registers.context.into(), + final_registers.gas_used.into(), + ]; + + let length = registers_after.len(); + let registers_after_fields = (0..length) + .map(|i| { + ( + MemoryAddress::new(0, Segment::RegistersStates, length + i), + registers_after[i], + ) + }) + .collect::>(); + self.set_memory_multi_addresses(®isters_after_fields); + } + } + fn try_perform_instruction(&mut self) -> Result { let registers = self.generation_state.registers; let (mut row, opcode) = self.base_row(); @@ -817,7 +967,7 @@ mod tests { 0x60, 0xff, 0x60, 0x0, 0x52, 0x60, 0, 0x51, 0x60, 0x1, 0x51, 0x60, 0x42, 0x60, 0x27, 0x53, ]; - let mut interpreter: Interpreter = Interpreter::new(0, vec![]); + let mut interpreter: Interpreter = Interpreter::new(0, vec![], None); interpreter.set_code(1, code.to_vec()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs index 125760ee5..da6a9a378 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs @@ -5,7 +5,7 @@ use ethereum_types::{Address, BigEndianHash, H256, U256}; use hex_literal::hex; use keccak_hash::keccak; use mpt_trie::nibbles::Nibbles; -use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; +use mpt_trie::partial_trie::{HashedPartialTrie, Node, PartialTrie}; use plonky2::field::goldilocks_field::GoldilocksField as F; use plonky2::field::types::Field; use rand::{thread_rng, Rng}; @@ -15,20 +15,92 @@ use crate::cpu::kernel::constants::context_metadata::ContextMetadata::{self, Gas use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::tests::mpt::nibbles_64; -use crate::generation::mpt::{load_all_mpts, AccountRlp}; +use crate::generation::mpt::{ + load_linked_lists_and_txn_and_receipt_mpts, load_state_mpt, AccountRlp, +}; use crate::generation::TrieInputs; use crate::memory::segments::Segment; +use crate::util::h2u; use crate::witness::memory::MemoryAddress; use crate::witness::operation::CONTEXT_SCALING_FACTOR; -use crate::Node; pub(crate) fn initialize_mpts( interpreter: &mut Interpreter, trie_inputs: &TrieInputs, ) { // Load all MPTs. - let (trie_root_ptrs, trie_data) = - load_all_mpts(trie_inputs).expect("Invalid MPT data for preinitialization"); + let (mut trie_root_ptrs, state_leaves, storage_leaves, trie_data) = + load_linked_lists_and_txn_and_receipt_mpts(trie_inputs) + .expect("Invalid MPT data for preinitialization"); + + interpreter.generation_state.memory.contexts[0].segments + [Segment::AccountsLinkedList.unscale()] + .content = state_leaves; + interpreter.generation_state.memory.contexts[0].segments + [Segment::StorageLinkedList.unscale()] + .content = storage_leaves; + interpreter.generation_state.memory.contexts[0].segments[Segment::TrieData.unscale()].content = + trie_data.clone(); + interpreter.generation_state.trie_root_ptrs = trie_root_ptrs.clone(); + + if trie_root_ptrs.state_root_ptr.is_none() { + trie_root_ptrs.state_root_ptr = Some( + load_state_mpt( + &trie_inputs.trim(), + &mut interpreter.generation_state.memory.contexts[0].segments + [Segment::TrieData.unscale()] + .content, + ) + .expect("Invalid MPT data for preinitialization"), + ); + } + + let accounts_len = Segment::AccountsLinkedList as usize + + interpreter.generation_state.memory.contexts[0].segments + [Segment::AccountsLinkedList.unscale()] + .content + .len(); + let storage_len = Segment::StorageLinkedList as usize + + interpreter.generation_state.memory.contexts[0].segments + [Segment::StorageLinkedList.unscale()] + .content + .len(); + let accounts_len_addr = MemoryAddress { + context: 0, + segment: Segment::GlobalMetadata.unscale(), + virt: GlobalMetadata::AccountsLinkedListNextAvailable.unscale(), + }; + let storage_len_addr = MemoryAddress { + context: 0, + segment: Segment::GlobalMetadata.unscale(), + virt: GlobalMetadata::StorageLinkedListNextAvailable.unscale(), + }; + let initial_accounts_len_addr = MemoryAddress { + context: 0, + segment: Segment::GlobalMetadata.unscale(), + virt: GlobalMetadata::InitialAccountsLinkedListLen.unscale(), + }; + let initial_storage_len_addr = MemoryAddress { + context: 0, + segment: Segment::GlobalMetadata.unscale(), + virt: GlobalMetadata::InitialStorageLinkedListLen.unscale(), + }; + let trie_data_len_addr = MemoryAddress { + context: 0, + segment: Segment::GlobalMetadata.unscale(), + virt: GlobalMetadata::TrieDataSize.unscale(), + }; + let trie_data_len = interpreter.generation_state.memory.contexts[0].segments + [Segment::TrieData.unscale()] + .content + .len(); + interpreter.set_memory_multi_addresses(&[ + (accounts_len_addr, accounts_len.into()), + (storage_len_addr, storage_len.into()), + (trie_data_len_addr, trie_data_len.into()), + (initial_accounts_len_addr, accounts_len.into()), + (initial_storage_len_addr, storage_len.into()), + ]); let state_addr = MemoryAddress::new_bundle((GlobalMetadata::StateTrieRoot as usize).into()).unwrap(); @@ -36,15 +108,15 @@ pub(crate) fn initialize_mpts( MemoryAddress::new_bundle((GlobalMetadata::TransactionTrieRoot as usize).into()).unwrap(); let receipts_addr = MemoryAddress::new_bundle((GlobalMetadata::ReceiptTrieRoot as usize).into()).unwrap(); - let len_addr = - MemoryAddress::new_bundle((GlobalMetadata::TrieDataSize as usize).into()).unwrap(); - let to_set = [ - (state_addr, trie_root_ptrs.state_root_ptr.into()), + let mut to_set = vec![]; + if let Some(state_root_ptr) = trie_root_ptrs.state_root_ptr { + to_set.push((state_addr, state_root_ptr.into())); + } + to_set.extend([ (txn_addr, trie_root_ptrs.txn_root_ptr.into()), (receipts_addr, trie_root_ptrs.receipt_root_ptr.into()), - (len_addr, trie_data.len().into()), - ]; + ]); interpreter.set_memory_multi_addresses(&to_set); @@ -53,39 +125,29 @@ pub(crate) fn initialize_mpts( interpreter .generation_state .memory - .set(trie_addr, data.into()); - } -} - -// Test account with a given code hash. -fn test_account(code: &[u8]) -> AccountRlp { - AccountRlp { - nonce: U256::from(1111), - balance: U256::from(2222), - storage_root: HashedPartialTrie::from(Node::Empty).hash(), - code_hash: keccak(code), + .set(trie_addr, data.unwrap_or_default()); } } -fn random_code() -> Vec { - let mut rng = thread_rng(); - let num_bytes = rng.gen_range(0..1000); - (0..num_bytes).map(|_| rng.gen()).collect() -} - // Stolen from `tests/mpt/insert.rs` // Prepare the interpreter by inserting the account in the state trie. -fn prepare_interpreter( +pub(crate) fn prepare_interpreter( interpreter: &mut Interpreter, address: Address, account: &AccountRlp, ) -> Result<()> { let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - let mut state_trie: HashedPartialTrie = Default::default(); - let trie_inputs = Default::default(); + let check_state_trie = KERNEL.global_labels["check_final_state_trie"]; + let mut state_trie: HashedPartialTrie = HashedPartialTrie::from(Node::Empty); + let trie_inputs = TrieInputs { + state_trie: HashedPartialTrie::from(Node::Empty), + transactions_trie: HashedPartialTrie::from(Node::Empty), + receipts_trie: HashedPartialTrie::from(Node::Empty), + storage_tries: vec![], + }; initialize_mpts(interpreter, &trie_inputs); + assert_eq!(interpreter.stack(), vec![]); let k = nibbles_64(U256::from_big_endian( keccak(address.to_fixed_bytes()).as_bytes(), @@ -119,6 +181,7 @@ fn prepare_interpreter( .expect("The stack should not overflow"); // key interpreter.run()?; + assert_eq!( interpreter.stack().len(), 0, @@ -126,13 +189,48 @@ fn prepare_interpreter( interpreter.stack() ); - // Now, execute mpt_hash_state_trie. - interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; + // Set initial tries. interpreter .push(0xDEADBEEFu32.into()) .expect("The stack should not overflow"); interpreter - .push(1.into()) // Initial length of the trie data segment, unused. + .push((Segment::StorageLinkedList as usize + 8).into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::AccountsLinkedList as usize + 6).into()) + .expect("The stack should not overflow"); + interpreter + .push(interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot)) + .expect("The stack should not overflow"); + + // Now, set the payload. + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["mpt_set_payload"]; + + interpreter.run()?; + + let acc_ptr = interpreter.pop().expect("The stack should not be empty") - 2; + let storage_ptr = interpreter.pop().expect("The stack should not be empty") - 3; + interpreter.set_global_metadata_field(GlobalMetadata::InitialAccountsLinkedListLen, acc_ptr); + interpreter.set_global_metadata_field(GlobalMetadata::InitialStorageLinkedListLen, storage_ptr); + + // Now, execute `mpt_hash_state_trie`. + state_trie.insert(k, rlp::encode(account).to_vec())?; + let expected_state_trie_hash = state_trie.hash(); + interpreter.set_global_metadata_field( + GlobalMetadata::StateTrieRootDigestAfter, + h2u(expected_state_trie_hash), + ); + + interpreter.generation_state.registers.program_counter = check_state_trie; + interpreter + .halt_offsets + .push(KERNEL.global_labels["check_txn_trie"]); + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter + .push(interpreter.get_global_metadata_field(GlobalMetadata::TrieDataSize)) // Initial trie data segment size, unused. .expect("The stack should not overflow"); interpreter.run()?; @@ -142,21 +240,32 @@ fn prepare_interpreter( "Expected 2 items on stack after hashing, found {:?}", interpreter.stack() ); - let hash = H256::from_uint(&interpreter.stack()[1]); - - state_trie.insert(k, rlp::encode(account).to_vec())?; - let expected_state_trie_hash = state_trie.hash(); - assert_eq!(hash, expected_state_trie_hash); Ok(()) } +// Test account with a given code hash. +fn test_account(code: &[u8]) -> AccountRlp { + AccountRlp { + nonce: U256::from(1111), + balance: U256::from(2222), + storage_root: HashedPartialTrie::from(Node::Empty).hash(), + code_hash: keccak(code), + } +} + +fn random_code() -> Vec { + let mut rng = thread_rng(); + let num_bytes = rng.gen_range(0..1000); + (0..num_bytes).map(|_| rng.gen()).collect() +} + #[test] fn test_extcodesize() -> Result<()> { let code = random_code(); let account = test_account(&code); - let mut interpreter: Interpreter = Interpreter::new(0, vec![]); + let mut interpreter: Interpreter = Interpreter::new(0, vec![], None); let address: Address = thread_rng().gen(); // Prepare the interpreter by inserting the account in the state trie. prepare_interpreter(&mut interpreter, address, &account)?; @@ -178,7 +287,10 @@ fn test_extcodesize() -> Result<()> { HashMap::from([(keccak(&code), code.clone())]); interpreter.run()?; - assert_eq!(interpreter.stack(), vec![code.len().into()]); + assert_eq!( + interpreter.stack(), + vec![U256::one() << CONTEXT_SCALING_FACTOR, code.len().into()] + ); Ok(()) } @@ -188,7 +300,7 @@ fn test_extcodecopy() -> Result<()> { let code = random_code(); let account = test_account(&code); - let mut interpreter: Interpreter = Interpreter::new(0, vec![]); + let mut interpreter: Interpreter = Interpreter::new(0, vec![], None); let address: Address = thread_rng().gen(); // Prepare the interpreter by inserting the account in the state trie. prepare_interpreter(&mut interpreter, address, &account)?; @@ -277,6 +389,52 @@ fn prepare_interpreter_all_accounts( initialize_mpts(interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); + // Copy the initial account and storage pointers + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["store_initial_accounts"]; + interpreter.run()?; + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["store_initial_slots"]; + interpreter.run()?; + + // Set the pointers to the initial payloads. + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::StorageLinkedList as usize + 8).into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::AccountsLinkedList as usize + 6).into()) + .expect("The stack should not overflow"); + interpreter + .push(interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot)) + .expect("The stack should not overflow"); + + // Now, set the payloads in the state trie leaves. + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["mpt_set_payload"]; + + interpreter.run()?; + + assert_eq!( + interpreter.stack().len(), + 2, + "Expected 2 items on stack after setting the initial trie payloads, found {:?}", + interpreter.stack() + ); + + let acc_ptr = interpreter.pop().expect("The stack should not be empty") - 2; + let storage_ptr = interpreter.pop().expect("The stack should not be empty") - 3; + interpreter.set_global_metadata_field(GlobalMetadata::InitialAccountsLinkedListLen, acc_ptr); + interpreter.set_global_metadata_field(GlobalMetadata::InitialStorageLinkedListLen, storage_ptr); + // Switch context and initialize memory with the data we need for the tests. interpreter.generation_state.registers.program_counter = 0; interpreter.set_code(1, code.to_vec()); @@ -333,7 +491,7 @@ fn sstore() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); // Pre-initialize the accessed addresses list. let init_accessed_addresses = KERNEL.global_labels["init_access_lists"]; @@ -365,8 +523,19 @@ fn sstore() -> Result<()> { .hash(), ..AccountRlp::default() }; - // Now, execute mpt_hash_state_trie. - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; + + let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); + expected_state_trie_after.insert(addr_nibbles, rlp::encode(&account_after).to_vec())?; + + let expected_state_trie_hash = expected_state_trie_after.hash(); + + interpreter.set_global_metadata_field( + GlobalMetadata::StateTrieRootDigestAfter, + h2u(expected_state_trie_hash), + ); + + // Now, execute `mpt_hash_state_trie` and check that the hash is correct. + let mpt_hash_state_trie = KERNEL.global_labels["check_final_state_trie"]; interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.set_is_kernel(true); interpreter.set_context(0); @@ -378,27 +547,12 @@ fn sstore() -> Result<()> { .expect("The stack should not overflow"); interpreter.run()?; - assert_eq!( - interpreter.stack().len(), - 2, - "Expected 2 items on stack after hashing, found {:?}", - interpreter.stack() - ); - - let hash = H256::from_uint(&interpreter.stack()[1]); - - let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert(addr_nibbles, rlp::encode(&account_after).to_vec())?; - - let expected_state_trie_hash = expected_state_trie_after.hash(); - assert_eq!(hash, expected_state_trie_hash); Ok(()) } /// Tests an SLOAD within a code similar to the contract code in add11_yml. #[test] fn sload() -> Result<()> { - // We take the same `to` account as in add11_yml. let addr = hex!("095e7baea6a6c7c4c2dfeb977efac326af552d87"); let addr_hashed = keccak(addr); @@ -431,7 +585,7 @@ fn sload() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); // Pre-initialize the accessed addresses list. let init_accessed_addresses = KERNEL.global_labels["init_access_lists"]; @@ -464,7 +618,7 @@ fn sload() -> Result<()> { interpreter .pop() .expect("The stack length should not be empty."); - // Now, execute mpt_hash_state_trie. We check that the state trie has not + // Now, execute `mpt_hash_state_trie`. We check that the state trie has not // changed. let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; @@ -474,7 +628,7 @@ fn sload() -> Result<()> { .push(0xDEADBEEFu32.into()) .expect("The stack should not overflow."); interpreter - .push(1.into()) // Initial length of the trie data segment, unused. + .push(interpreter.get_global_metadata_field(GlobalMetadata::TrieDataSize)) // Initial length of the trie data segment, unused. .expect("The stack should not overflow."); interpreter.run()?; @@ -485,15 +639,6 @@ fn sload() -> Result<()> { interpreter.stack() ); - let trie_data_segment_len = interpreter.stack()[0]; - assert_eq!( - trie_data_segment_len, - interpreter - .get_memory_segment(Segment::TrieData) - .len() - .into() - ); - let hash = H256::from_uint(&interpreter.stack()[1]); let expected_state_trie_hash = state_trie_before.hash(); diff --git a/evm_arithmetization/src/cpu/kernel/tests/add11.rs b/evm_arithmetization/src/cpu/kernel/tests/add11.rs index ae5ac3871..89fbdec80 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/add11.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/add11.rs @@ -180,7 +180,7 @@ fn test_add11_yml() { }; let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -198,9 +198,9 @@ fn test_add11_yml() { }; let initial_stack = vec![]; - let initial_offset = KERNEL.global_labels["main"]; + let initial_offset = KERNEL.global_labels["init"]; let mut interpreter: Interpreter = - Interpreter::new_with_generation_inputs(initial_offset, initial_stack, inputs); + Interpreter::new_with_generation_inputs(initial_offset, initial_stack, &inputs, None); interpreter.set_is_kernel(true); interpreter.run().expect("Proving add11 failed."); @@ -361,7 +361,7 @@ fn test_add11_yml_with_exception() { }; let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -379,9 +379,9 @@ fn test_add11_yml_with_exception() { }; let initial_stack = vec![]; - let initial_offset = KERNEL.global_labels["main"]; + let initial_offset = KERNEL.global_labels["init"]; let mut interpreter: Interpreter = - Interpreter::new_with_generation_inputs(initial_offset, initial_stack, inputs); + Interpreter::new_with_generation_inputs(initial_offset, initial_stack, &inputs, None); interpreter.set_is_kernel(true); interpreter diff --git a/evm_arithmetization/src/cpu/kernel/tests/balance.rs b/evm_arithmetization/src/cpu/kernel/tests/balance.rs index 2b8f8c241..fc4d63347 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/balance.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/balance.rs @@ -1,16 +1,12 @@ use anyhow::Result; use ethereum_types::{Address, BigEndianHash, H256, U256}; -use keccak_hash::keccak; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; use plonky2::field::goldilocks_field::GoldilocksField as F; -use plonky2::field::types::Field; use rand::{thread_rng, Rng}; use crate::cpu::kernel::aggregator::KERNEL; -use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::cpu::kernel::tests::account_code::initialize_mpts; -use crate::cpu::kernel::tests::mpt::nibbles_64; +use crate::cpu::kernel::tests::account_code::prepare_interpreter; use crate::generation::mpt::AccountRlp; use crate::Node; @@ -24,92 +20,13 @@ fn test_account(balance: U256) -> AccountRlp { } } -// Stolen from `tests/mpt/insert.rs` -// Prepare the interpreter by inserting the account in the state trie. -fn prepare_interpreter( - interpreter: &mut Interpreter, - address: Address, - account: &AccountRlp, -) -> Result<()> { - let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; - let mut state_trie: HashedPartialTrie = Default::default(); - let trie_inputs = Default::default(); - - initialize_mpts(interpreter, &trie_inputs); - assert_eq!(interpreter.stack(), vec![]); - - let k = nibbles_64(U256::from_big_endian( - keccak(address.to_fixed_bytes()).as_bytes(), - )); - // Next, execute mpt_insert_state_trie. - interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; - let trie_data = interpreter.get_trie_data_mut(); - if trie_data.is_empty() { - // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. - // Since we don't explicitly set it to 0, we need to do so here. - trie_data.push(Some(0.into())); - } - let value_ptr = trie_data.len(); - trie_data.push(Some(account.nonce)); - trie_data.push(Some(account.balance)); - // In memory, storage_root gets interpreted as a pointer to a storage trie, - // so we have to ensure the pointer is valid. It's easiest to set it to 0, - // which works as an empty node, since trie_data[0] = 0 = MPT_TYPE_EMPTY. - trie_data.push(Some(H256::zero().into_uint())); - trie_data.push(Some(account.code_hash.into_uint())); - let trie_data_len = trie_data.len().into(); - interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, trie_data_len); - interpreter - .push(0xDEADBEEFu32.into()) - .expect("The stack should not overflow"); - interpreter - .push(value_ptr.into()) - .expect("The stack should not overflow"); // value_ptr - interpreter - .push(k.try_into().unwrap()) - .expect("The stack should not overflow"); // key - - interpreter.run()?; - assert_eq!( - interpreter.stack().len(), - 0, - "Expected empty stack after insert, found {:?}", - interpreter.stack() - ); - - // Now, execute mpt_hash_state_trie. - interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; - interpreter - .push(0xDEADBEEFu32.into()) - .expect("The stack should not overflow"); - interpreter - .push(1.into()) // Initial trie data segment size, unused. - .expect("The stack should not overflow"); - interpreter.run()?; - - assert_eq!( - interpreter.stack().len(), - 2, - "Expected 2 items on stack after hashing, found {:?}", - interpreter.stack() - ); - let hash = H256::from_uint(&interpreter.stack()[1]); - - state_trie.insert(k, rlp::encode(account).to_vec())?; - let expected_state_trie_hash = state_trie.hash(); - assert_eq!(hash, expected_state_trie_hash); - - Ok(()) -} - #[test] fn test_balance() -> Result<()> { let mut rng = thread_rng(); let balance = U256(rng.gen()); let account = test_account(balance); - let mut interpreter: Interpreter = Interpreter::new(0, vec![]); + let mut interpreter: Interpreter = Interpreter::new(0, vec![], None); let address: Address = rng.gen(); // Prepare the interpreter by inserting the account in the state trie. prepare_interpreter(&mut interpreter, address, &account)?; diff --git a/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs b/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs index c18ad5f76..efb37bcd3 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs @@ -9,7 +9,6 @@ use itertools::Itertools; use num::{BigUint, One, Zero}; use num_bigint::RandBigInt; use plonky2::field::goldilocks_field::GoldilocksField as F; -use plonky2_util::ceil_div_usize; use rand::Rng; use crate::cpu::kernel::aggregator::KERNEL; @@ -90,7 +89,7 @@ fn max_bignum(bit_size: usize) -> BigUint { } fn bignum_len(a: &BigUint) -> usize { - ceil_div_usize(a.bits() as usize, BIGNUM_LIMB_BITS) + (a.bits() as usize).div_ceil(BIGNUM_LIMB_BITS) } fn run_test(fn_label: &str, memory: Vec, stack: Vec) -> Result<(Vec, Vec)> { @@ -101,7 +100,7 @@ fn run_test(fn_label: &str, memory: Vec, stack: Vec) -> Result<(Vec< initial_stack.push(retdest); initial_stack.reverse(); - let mut interpreter: Interpreter = Interpreter::new(fn_label, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(fn_label, initial_stack, None); interpreter.set_current_general_memory(memory); interpreter.run()?; diff --git a/evm_arithmetization/src/cpu/kernel/tests/blobhash.rs b/evm_arithmetization/src/cpu/kernel/tests/blobhash.rs index 429bd729f..73bb7ee54 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/blobhash.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/blobhash.rs @@ -19,7 +19,7 @@ fn test_valid_blobhash() -> Result<()> { let index = 3; let target_hash = versioned_hashes[index]; - let mut interpreter: Interpreter = Interpreter::new(blobhash_label, vec![]); + let mut interpreter: Interpreter = Interpreter::new(blobhash_label, vec![], None); interpreter .generation_state .memory @@ -60,7 +60,7 @@ fn test_invalid_blobhash() -> Result<()> { let index = 7; let target_hash = U256::zero(); // out of bound indexing yields 0. - let mut interpreter: Interpreter = Interpreter::new(blobhash_label, vec![]); + let mut interpreter: Interpreter = Interpreter::new(blobhash_label, vec![], None); interpreter .generation_state .memory diff --git a/evm_arithmetization/src/cpu/kernel/tests/block_hash.rs b/evm_arithmetization/src/cpu/kernel/tests/block_hash.rs index e7d142da0..9aac4d247 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/block_hash.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/block_hash.rs @@ -20,7 +20,7 @@ fn test_correct_block_hash() -> Result<()> { let hashes: Vec = vec![U256::from_big_endian(&thread_rng().gen::().0); 257]; - let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack, None); interpreter.set_memory_segment(Segment::BlockHashes, hashes[0..256].to_vec()); interpreter.set_global_metadata_field(GlobalMetadata::BlockCurrentHash, hashes[256]); interpreter.set_global_metadata_field(GlobalMetadata::BlockNumber, 256.into()); @@ -49,7 +49,7 @@ fn test_big_index_block_hash() -> Result<()> { let hashes: Vec = vec![U256::from_big_endian(&thread_rng().gen::().0); 257]; - let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack, None); interpreter.set_memory_segment(Segment::BlockHashes, hashes[0..256].to_vec()); interpreter.set_global_metadata_field(GlobalMetadata::BlockCurrentHash, hashes[256]); interpreter.set_global_metadata_field(GlobalMetadata::BlockNumber, cur_block_number.into()); @@ -79,7 +79,7 @@ fn test_small_index_block_hash() -> Result<()> { let hashes: Vec = vec![U256::from_big_endian(&thread_rng().gen::().0); 257]; - let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack, None); interpreter.set_memory_segment(Segment::BlockHashes, hashes[0..256].to_vec()); interpreter.set_global_metadata_field(GlobalMetadata::BlockCurrentHash, hashes[256]); interpreter.set_global_metadata_field(GlobalMetadata::BlockNumber, cur_block_number.into()); @@ -107,7 +107,7 @@ fn test_block_hash_with_overflow() -> Result<()> { let hashes: Vec = vec![U256::from_big_endian(&thread_rng().gen::().0); 257]; - let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(blockhash_label, initial_stack, None); interpreter.set_memory_segment(Segment::BlockHashes, hashes[0..256].to_vec()); interpreter.set_global_metadata_field(GlobalMetadata::BlockCurrentHash, hashes[256]); interpreter.set_global_metadata_field(GlobalMetadata::BlockNumber, cur_block_number.into()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/bls381.rs b/evm_arithmetization/src/cpu/kernel/tests/bls381.rs index 0910ec75c..40a28ac5b 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/bls381.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/bls381.rs @@ -144,7 +144,7 @@ fn test_kzg_peval_precompile() -> Result<()> { stack.reverse(); let verify_kzg_proof = KERNEL.global_labels["verify_kzg_proof"]; - let mut interpreter: Interpreter = Interpreter::new(verify_kzg_proof, stack); + let mut interpreter: Interpreter = Interpreter::new(verify_kzg_proof, stack, None); interpreter.halt_offsets = vec![ KERNEL.global_labels["store_kzg_verification"], KERNEL.global_labels["fault_exception"], diff --git a/evm_arithmetization/src/cpu/kernel/tests/bn254.rs b/evm_arithmetization/src/cpu/kernel/tests/bn254.rs index a253fa815..b71b37c9a 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/bn254.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/bn254.rs @@ -325,7 +325,7 @@ fn test_ecpairing_precompile_invalid_input() -> Result<()> { stack.reverse(); for bytes in ECPAIRING_PRECOMPILE_INVALID_INPUTS.iter() { - let mut interpreter: Interpreter = Interpreter::new(pairing_label, stack.clone()); + let mut interpreter: Interpreter = Interpreter::new(pairing_label, stack.clone(), None); let preloaded_memory = vec![ U256::from_big_endian(&bytes[0..32]), // Px U256::from_big_endian(&bytes[32..64]), // Py diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs index d1f94c5cb..9a52301e3 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs @@ -19,7 +19,7 @@ fn test_init_access_lists() -> Result<()> { // Check the initial state of the access list in the kernel. let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter = Interpreter::::new(init_label, initial_stack); + let mut interpreter = Interpreter::::new(init_label, initial_stack, None); interpreter.run()?; assert!(interpreter.stack().is_empty()); @@ -64,7 +64,7 @@ fn test_list_iterator() -> Result<()> { let init_label = KERNEL.global_labels["init_access_lists"]; let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter = Interpreter::::new(init_label, initial_stack); + let mut interpreter = Interpreter::::new(init_label, initial_stack, None); interpreter.run()?; // test the list iterator @@ -73,15 +73,14 @@ fn test_list_iterator() -> Result<()> { .get_addresses_access_list() .expect("Since we called init_access_lists there must be a list"); - let Some((pos_0, next_val_0, _)) = list.next() else { + let Some([next_val_0, _]) = list.next() else { return Err(anyhow::Error::msg("Couldn't get value")); }; - assert_eq!(pos_0, 0); assert_eq!(next_val_0, U256::MAX); - let Some((pos_0, _, _)) = list.next() else { + let Some([_, pos_0]) = list.next() else { return Err(anyhow::Error::msg("Couldn't get value")); }; - assert_eq!(pos_0, 0); + assert_eq!(pos_0, U256::from(Segment::AccessedAddresses as usize)); Ok(()) } @@ -91,7 +90,7 @@ fn test_insert_address() -> Result<()> { // Test for address already in list. let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter = Interpreter::::new(init_label, initial_stack); + let mut interpreter = Interpreter::::new(init_label, initial_stack, None); interpreter.run()?; let insert_accessed_addresses = KERNEL.global_labels["insert_accessed_addresses"]; @@ -128,7 +127,7 @@ fn test_insert_accessed_addresses() -> Result<()> { // Test for address already in list. let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter = Interpreter::::new(init_access_lists, initial_stack); + let mut interpreter = Interpreter::::new(init_access_lists, initial_stack, None); interpreter.run()?; let insert_accessed_addresses = KERNEL.global_labels["insert_accessed_addresses"]; @@ -215,7 +214,7 @@ fn test_insert_accessed_storage_keys() -> Result<()> { // Test for address already in list. let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter = Interpreter::::new(init_access_lists, initial_stack); + let mut interpreter = Interpreter::::new(init_access_lists, initial_stack, None); interpreter.run()?; let insert_accessed_storage_keys = KERNEL.global_labels["insert_accessed_storage_keys"]; diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/create_addresses.rs b/evm_arithmetization/src/cpu/kernel/tests/core/create_addresses.rs index 79433c823..951884527 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/create_addresses.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/create_addresses.rs @@ -20,7 +20,7 @@ fn test_get_create_address() -> Result<()> { let expected_addr = U256::from_big_endian(&hex!("3f09c73a5ed19289fb9bdc72f1742566df146f56")); let initial_stack = vec![retaddr, nonce, sender]; - let mut interpreter: Interpreter = Interpreter::new(get_create_address, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(get_create_address, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), &[expected_addr]); @@ -106,7 +106,8 @@ fn test_get_create2_address() -> Result<()> { } in create2_test_cases() { let initial_stack = vec![retaddr, salt, U256::from(code_hash.0), sender]; - let mut interpreter: Interpreter = Interpreter::new(get_create2_address, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(get_create2_address, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), &[expected_addr]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/intrinsic_gas.rs b/evm_arithmetization/src/cpu/kernel/tests/core/intrinsic_gas.rs index 42f9ef958..0298affee 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/intrinsic_gas.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/intrinsic_gas.rs @@ -16,13 +16,14 @@ fn test_intrinsic_gas() -> Result<()> { // Contract creation transaction. let initial_stack = vec![0xdeadbeefu32.into()]; - let mut interpreter: Interpreter = Interpreter::new(intrinsic_gas, initial_stack.clone()); + let mut interpreter: Interpreter = + Interpreter::new(intrinsic_gas, initial_stack.clone(), None); interpreter.set_global_metadata_field(GlobalMetadata::ContractCreation, U256::one()); interpreter.run()?; assert_eq!(interpreter.stack(), vec![(GAS_TX + GAS_TXCREATE).into()]); // Message transaction. - let mut interpreter: Interpreter = Interpreter::new(intrinsic_gas, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(intrinsic_gas, initial_stack, None); interpreter.set_txn_field(NormalizedTxnField::To, 123.into()); interpreter.run()?; assert_eq!(interpreter.stack(), vec![GAS_TX.into()]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 61a580de7..b0ef17033 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -47,7 +47,7 @@ fn test_jumpdest_analysis() -> Result<()> { .chain(std::iter::once(true)), ); - let mut interpreter: Interpreter = Interpreter::new(jumpdest_analysis, vec![]); + let mut interpreter: Interpreter = Interpreter::new(jumpdest_analysis, vec![], None); let code_len = code.len(); interpreter.set_code(CONTEXT, code); @@ -134,7 +134,7 @@ fn test_packed_verification() -> Result<()> { U256::one(), ]; let mut interpreter: Interpreter = - Interpreter::new(write_table_if_jumpdest, initial_stack.clone()); + Interpreter::new(write_table_if_jumpdest, initial_stack.clone(), None); interpreter.set_code(CONTEXT, code.clone()); interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); @@ -147,7 +147,7 @@ fn test_packed_verification() -> Result<()> { for i in 1..=32 { code[i] += 1; let mut interpreter: Interpreter = - Interpreter::new(write_table_if_jumpdest, initial_stack.clone()); + Interpreter::new(write_table_if_jumpdest, initial_stack.clone(), None); interpreter.set_code(CONTEXT, code.clone()); interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); @@ -196,7 +196,7 @@ fn test_verify_non_jumpdest() -> Result<()> { // jumpdest for i in 8..code_len - 1 { code[i] += 1; - let mut interpreter: Interpreter = Interpreter::new(verify_non_jumpdest, vec![]); + let mut interpreter: Interpreter = Interpreter::new(verify_non_jumpdest, vec![], None); interpreter.generation_state.registers.context = CONTEXT; interpreter.set_code(CONTEXT, code.clone()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs index 047c2b8f5..56f2c48bc 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/ecc/curve_ops.rs @@ -1,6 +1,5 @@ #[cfg(test)] mod bn { - use anyhow::Result; use ethereum_types::U256; use plonky2::field::goldilocks_field::GoldilocksField as F; @@ -186,7 +185,7 @@ mod bn { let mut initial_stack = u256ify(["0xdeadbeef"])?; initial_stack.push(k); - let mut int: Interpreter = Interpreter::new(glv, initial_stack); + let mut int: Interpreter = Interpreter::new(glv, initial_stack, None); int.run()?; assert_eq!(line, int.stack()); @@ -204,7 +203,7 @@ mod bn { "0x10d7cf0621b6e42c1dbb421f5ef5e1936ca6a87b38198d1935be31e28821d171", "0x11b7d55f16aaac07de9a0ed8ac2e8023570dbaa78571fc95e553c4b3ba627689", ])?; - let mut int: Interpreter = Interpreter::new(precompute, initial_stack); + let mut int: Interpreter = Interpreter::new(precompute, initial_stack, None); int.run()?; let mut computed_table = Vec::new(); @@ -357,7 +356,7 @@ mod secp { let mut initial_stack = u256ify(["0xdeadbeef"])?; initial_stack.push(k); - let mut int: Interpreter = Interpreter::new(glv, initial_stack); + let mut int: Interpreter = Interpreter::new(glv, initial_stack, None); int.run()?; assert_eq!(line, int.stack()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/exp.rs b/evm_arithmetization/src/cpu/kernel/tests/exp.rs index 660ec538a..746ba98b3 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/exp.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/exp.rs @@ -17,7 +17,7 @@ fn test_exp() -> Result<()> { // Random input let initial_stack = vec![0xDEADBEEFu32.into(), b, a]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack.clone()); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack.clone(), None); let stack_with_kernel = run_interpreter::(exp, initial_stack)?.stack(); diff --git a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs new file mode 100644 index 000000000..e2d5fb41d --- /dev/null +++ b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs @@ -0,0 +1,186 @@ +use std::collections::HashMap; + +use ethereum_types::U256; +use keccak_hash::keccak; +use keccak_hash::H256; +use mpt_trie::partial_trie::HashedPartialTrie; +use mpt_trie::partial_trie::PartialTrie; +use plonky2::field::goldilocks_field::GoldilocksField as F; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::generation::state::State; +use crate::generation::TrieInputs; +use crate::generation::NUM_EXTRA_CYCLES_AFTER; +use crate::generation::NUM_EXTRA_CYCLES_BEFORE; +use crate::memory::segments::Segment; +use crate::proof::BlockMetadata; +use crate::proof::TrieRoots; +use crate::testing_utils::beacon_roots_account_nibbles; +use crate::testing_utils::beacon_roots_contract_from_storage; +use crate::testing_utils::ger_account_nibbles; +use crate::testing_utils::init_logger; +use crate::testing_utils::preinitialized_state_and_storage_tries; +use crate::testing_utils::update_beacon_roots_account_storage; +use crate::testing_utils::GLOBAL_EXIT_ROOT_ACCOUNT; +use crate::witness::memory::MemoryAddress; +use crate::witness::state::RegistersState; +use crate::{proof::BlockHashes, GenerationInputs, Node}; + +enum RegistersIdx { + ProgramCounter = 0, + IsKernel = 1, + _StackLen = 2, + _StackTop = 3, + _Context = 4, + _GasUsed = 5, +} + +const REGISTERS_LEN: usize = 6; + +// Test to check NUM_EXTRA_CYCLES_BEFORE and NUM_EXTRA_CYCLES_AFTER +#[test] +fn test_init_exc_stop() { + init_logger(); + + let block_metadata = BlockMetadata { + block_number: 1.into(), + block_timestamp: 0x1234.into(), + ..Default::default() + }; + + let (state_trie_before, storage_tries) = preinitialized_state_and_storage_tries().unwrap(); + let mut beacon_roots_account_storage = storage_tries[0].1.clone(); + let transactions_trie = HashedPartialTrie::from(Node::Empty); + let receipts_trie = HashedPartialTrie::from(Node::Empty); + + let expected_state_trie_after = { + update_beacon_roots_account_storage( + &mut beacon_roots_account_storage, + block_metadata.block_timestamp, + block_metadata.parent_beacon_block_root, + ) + .unwrap(); + let beacon_roots_account = + beacon_roots_contract_from_storage(&beacon_roots_account_storage); + + let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); + expected_state_trie_after + .insert( + beacon_roots_account_nibbles(), + rlp::encode(&beacon_roots_account).to_vec(), + ) + .unwrap(); + expected_state_trie_after + .insert( + ger_account_nibbles(), + rlp::encode(&GLOBAL_EXIT_ROOT_ACCOUNT).to_vec(), + ) + .unwrap(); + expected_state_trie_after + }; + + let mut contract_code = HashMap::new(); + contract_code.insert(keccak(vec![]), vec![]); + + let trie_roots_after = TrieRoots { + state_root: expected_state_trie_after.hash(), + transactions_root: transactions_trie.hash(), + receipts_root: receipts_trie.hash(), + }; + + let inputs = GenerationInputs { + signed_txns: vec![], + withdrawals: vec![], + tries: TrieInputs { + state_trie: state_trie_before, + transactions_trie, + receipts_trie, + storage_tries, + }, + trie_roots_after, + contract_code, + checkpoint_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), + block_metadata, + txn_number_before: 0.into(), + gas_used_before: 0.into(), + gas_used_after: 0.into(), + block_hashes: BlockHashes { + prev_hashes: vec![H256::default(); 256], + cur_hash: H256::default(), + }, + global_exit_roots: vec![], + }; + let initial_stack = vec![]; + let initial_offset = KERNEL.global_labels["init"]; + let mut interpreter: Interpreter = + Interpreter::new_with_generation_inputs(initial_offset, initial_stack, &inputs, None); + interpreter.halt_offsets = vec![KERNEL.global_labels["main"]]; + interpreter.set_is_kernel(true); + interpreter.run().expect("Running dummy init failed."); + + assert_eq!( + interpreter.get_clock(), + NUM_EXTRA_CYCLES_BEFORE, + "NUM_EXTRA_CYCLES_BEFORE is set incorrectly." + ); + + // The registers should not have changed, besides the stack top. + let expected_registers = RegistersState { + stack_top: interpreter.get_registers().stack_top, + check_overflow: interpreter.get_registers().check_overflow, + ..RegistersState::new() + }; + + assert_eq!( + interpreter.get_registers(), + expected_registers, + "Incorrect registers for dummy run." + ); + + let exc_stop_offset = KERNEL.global_labels["exc_stop"]; + + let pc_u256 = U256::from(interpreter.get_registers().program_counter); + let exit_info = pc_u256 + (U256::one() << 32); + interpreter.push(exit_info).unwrap(); + interpreter.get_mut_registers().program_counter = exc_stop_offset; + interpreter.halt_offsets = vec![KERNEL.global_labels["halt_final"]]; + interpreter.set_is_kernel(true); + interpreter.clock = 0; + + // Set the program counter and `is_kernel` at the end of the execution. The + // `registers_before` and `registers_after` are stored contiguously in the + // `RegistersState` segment. We need to update `registers_after` here, hence the + // offset by `RegistersData::SIZE`. + let regs_to_set = [ + ( + MemoryAddress { + context: 0, + segment: Segment::RegistersStates.unscale(), + virt: REGISTERS_LEN + RegistersIdx::ProgramCounter as usize, + }, + pc_u256, + ), + ( + MemoryAddress { + context: 0, + segment: Segment::RegistersStates.unscale(), + virt: REGISTERS_LEN + RegistersIdx::IsKernel as usize, + }, + U256::one(), + ), + ]; + interpreter.set_memory_multi_addresses(®s_to_set); + + interpreter.run().expect("Running dummy exc_stop failed."); + + // The "-2" comes from the fact that: + // - we stop 1 cycle before the max, to allow for one padding row, which is + // needed for CPU STARK. + // - we need one additional cycle to enter `exc_stop`. + assert_eq!( + interpreter.get_clock(), + NUM_EXTRA_CYCLES_AFTER - 2, + "NUM_EXTRA_CYCLES_AFTER is set incorrectly." + ); +} diff --git a/evm_arithmetization/src/cpu/kernel/tests/log.rs b/evm_arithmetization/src/cpu/kernel/tests/log.rs index aa65f1b3e..9a4cad634 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/log.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/log.rs @@ -26,7 +26,7 @@ fn test_log_0() -> Result<()> { U256::from_big_endian(&address.to_fixed_bytes()), ]; - let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack, None); interpreter.set_global_metadata_field(GlobalMetadata::LogsLen, 0.into()); interpreter.set_global_metadata_field(GlobalMetadata::LogsDataLen, 0.into()); @@ -70,7 +70,7 @@ fn test_log_2() -> Result<()> { U256::from_big_endian(&address.to_fixed_bytes()), ]; - let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack, None); interpreter.set_global_metadata_field(GlobalMetadata::LogsLen, 2.into()); interpreter.set_global_metadata_field(GlobalMetadata::LogsDataLen, 5.into()); @@ -134,7 +134,7 @@ fn test_log_4() -> Result<()> { U256::from_big_endian(&address.to_fixed_bytes()), ]; - let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack, None); interpreter.set_global_metadata_field(GlobalMetadata::LogsLen, 2.into()); interpreter.set_global_metadata_field(GlobalMetadata::LogsDataLen, 5.into()); @@ -197,7 +197,7 @@ fn test_log_5() -> Result<()> { U256::from_big_endian(&address.to_fixed_bytes()), ]; - let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(logs_entry, initial_stack, None); interpreter.set_global_metadata_field(GlobalMetadata::LogsLen, 0.into()); interpreter.set_global_metadata_field(GlobalMetadata::LogsDataLen, 0.into()); diff --git a/evm_arithmetization/src/cpu/kernel/tests/mcopy.rs b/evm_arithmetization/src/cpu/kernel/tests/mcopy.rs index 07ca1bda8..87f00bada 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mcopy.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mcopy.rs @@ -22,7 +22,7 @@ fn test_mcopy( let kexit_info = U256::from(0xdeadbeefu32) + (U256::from(u64::from(true)) << 32); let initial_stack = vec![size.into(), offset.into(), dest_offset.into(), kexit_info]; - let mut interpreter: Interpreter = Interpreter::new(sys_mcopy, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(sys_mcopy, initial_stack, None); interpreter.set_context_metadata_field( 0, ContextMetadata::GasLimit, diff --git a/evm_arithmetization/src/cpu/kernel/tests/mod.rs b/evm_arithmetization/src/cpu/kernel/tests/mod.rs index c44eb8454..6219773d0 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mod.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mod.rs @@ -11,6 +11,7 @@ mod core; mod ecc; mod exp; mod hash; +mod init_exc_stop; mod kernel_consistency; mod log; mod mcopy; @@ -59,7 +60,7 @@ pub(crate) fn run_interpreter( initial_offset: usize, initial_stack: Vec, ) -> anyhow::Result> { - let mut interpreter = Interpreter::new(initial_offset, initial_stack); + let mut interpreter = Interpreter::new(initial_offset, initial_stack, None); interpreter.run()?; Ok(interpreter) } @@ -78,7 +79,7 @@ pub(crate) fn run_interpreter_with_memory( let label = KERNEL.global_labels[&memory_init.label]; let mut stack = memory_init.stack; stack.reverse(); - let mut interpreter = Interpreter::new(label, stack); + let mut interpreter = Interpreter::new(label, stack, None); for (pointer, data) in memory_init.memory { for (i, term) in data.iter().enumerate() { interpreter.generation_state.memory.set( diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/delete.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/delete.rs index 15a3a36cd..b954c1f1f 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/delete.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/delete.rs @@ -12,6 +12,8 @@ use crate::cpu::kernel::tests::account_code::initialize_mpts; use crate::cpu::kernel::tests::mpt::{nibbles_64, test_account_1_rlp, test_account_2}; use crate::generation::mpt::AccountRlp; use crate::generation::TrieInputs; +use crate::memory::segments::Segment; +use crate::util::h2u; use crate::Node; #[test] @@ -100,11 +102,46 @@ fn test_state_trie( let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); + // Store initial accounts and storage. + interpreter + .halt_offsets + .push(KERNEL.global_labels["after_store_initial"]); + interpreter.generation_state.registers.program_counter = KERNEL.global_labels["store_initial"]; + interpreter.run().unwrap(); + + assert_eq!(interpreter.stack(), vec![]); + // Set initial tries. + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::StorageLinkedList as usize + 8).into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::AccountsLinkedList as usize + 6).into()) + .expect("The stack should not overflow"); + interpreter + .push(interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot)) + .unwrap(); + + // Now, set the payload. + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["mpt_set_payload"]; + + interpreter.run()?; + + assert_eq!(interpreter.stack_len(), 2); + + let acc_ptr = interpreter.pop().expect("The stack should not be empty") - 2; + let storage_ptr = interpreter.pop().expect("The stack should not be empty") - 3; + interpreter.set_global_metadata_field(GlobalMetadata::InitialAccountsLinkedListLen, acc_ptr); + interpreter.set_global_metadata_field(GlobalMetadata::InitialStorageLinkedListLen, storage_ptr); + // Next, execute mpt_insert_state_trie. interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); @@ -140,6 +177,14 @@ fn test_state_trie( interpreter.stack() ); + // Now, run `set_final_tries` so that the trie roots are correct. + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["set_final_tries"]; + interpreter.run()?; + // Next, execute mpt_delete, deleting the account we just inserted. let state_trie_ptr = interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot); interpreter.generation_state.registers.program_counter = mpt_delete; @@ -159,20 +204,32 @@ fn test_state_trie( let state_trie_ptr = interpreter.pop().expect("The stack should not be empty"); interpreter.set_global_metadata_field(GlobalMetadata::StateTrieRoot, state_trie_ptr); - // Now, execute mpt_hash_state_trie. + // Now, execute `mpt_hash_state_trie`. + let expected_state_trie_hash = state_trie.hash(); + interpreter.set_global_metadata_field( + GlobalMetadata::StateTrieRootDigestAfter, + h2u(expected_state_trie_hash), + ); + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; + interpreter .push(0xDEADBEEFu32.into()) .expect("The stack should not overflow"); interpreter - .push(1.into()) // Initial length of the trie data segment, unused. + .push(interpreter.get_global_metadata_field(GlobalMetadata::TrieDataSize)) // Initial trie data segment size, unused. .expect("The stack should not overflow"); interpreter.run()?; - let state_trie_hash = - H256::from_uint(&interpreter.pop().expect("The stack should not be empty")); - let expected_state_trie_hash = state_trie.hash(); - assert_eq!(state_trie_hash, expected_state_trie_hash); + assert_eq!( + interpreter.stack().len(), + 2, + "Expected 2 items on stack after hashing, found {:?}", + interpreter.stack() + ); + + let hash = interpreter.pop().expect("The stack should not be empty"); + assert_eq!(hash, h2u(expected_state_trie_hash)); Ok(()) } diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/hash.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/hash.rs index 18e3ae1fe..59c5fb384 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/hash.rs @@ -112,12 +112,12 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); - // Now, execute mpt_hash_state_trie. + // Now, execute `mpt_hash_state_trie`. interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter .push(0xDEADBEEFu32.into()) diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/hex_prefix.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/hex_prefix.rs index ac93153f4..5f9c4252d 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/hex_prefix.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/hex_prefix.rs @@ -16,7 +16,7 @@ fn hex_prefix_even_nonterminated() -> Result<()> { let num_nibbles = 6.into(); let rlp_pos = U256::from(Segment::RlpRaw as usize); let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; - let mut interpreter: Interpreter = Interpreter::new(hex_prefix, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(hex_prefix, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), vec![rlp_pos + U256::from(5)]); @@ -44,7 +44,7 @@ fn hex_prefix_odd_terminated() -> Result<()> { let num_nibbles = 5.into(); let rlp_pos = U256::from(Segment::RlpRaw as usize); let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; - let mut interpreter: Interpreter = Interpreter::new(hex_prefix, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(hex_prefix, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), vec![rlp_pos + U256::from(4)]); @@ -71,7 +71,7 @@ fn hex_prefix_odd_terminated_tiny() -> Result<()> { let num_nibbles = 1.into(); let rlp_pos = U256::from(Segment::RlpRaw as usize + 2); let initial_stack = vec![retdest, terminated, packed_nibbles, num_nibbles, rlp_pos]; - let mut interpreter: Interpreter = Interpreter::new(hex_prefix, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(hex_prefix, initial_stack, None); interpreter.run()?; assert_eq!( interpreter.stack(), diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/insert.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/insert.rs index d25138631..e5735624f 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/insert.rs @@ -4,6 +4,7 @@ use mpt_trie::nibbles::Nibbles; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; use plonky2::field::goldilocks_field::GoldilocksField as F; +use super::test_account_1_empty_storage_rlp; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; @@ -13,6 +14,8 @@ use crate::cpu::kernel::tests::mpt::{ }; use crate::generation::mpt::AccountRlp; use crate::generation::TrieInputs; +use crate::memory::segments::Segment; +use crate::util::h2u; use crate::Node; #[test] @@ -25,7 +28,7 @@ fn mpt_insert_leaf_identical_keys() -> Result<()> { let key = nibbles_64(0xABC); let state_trie = Node::Leaf { nibbles: key, - value: test_account_1_rlp(), + value: test_account_1_empty_storage_rlp(), } .into(); test_state_trie(state_trie, key, test_account_2()) @@ -172,14 +175,46 @@ fn test_state_trie( storage_tries: vec![], }; let mpt_insert_state_trie = KERNEL.global_labels["mpt_insert_state_trie"]; - let mpt_hash_state_trie = KERNEL.global_labels["mpt_hash_state_trie"]; + let check_state_trie = KERNEL.global_labels["check_final_state_trie"]; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); + // Store initial accounts and storage. + interpreter + .halt_offsets + .push(KERNEL.global_labels["after_store_initial"]); + interpreter.generation_state.registers.program_counter = KERNEL.global_labels["store_initial"]; + interpreter.run()?; + + // Set initial tries. + interpreter + .push(0xDEADBEEFu32.into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::StorageLinkedList as usize + 8).into()) + .expect("The stack should not overflow"); + interpreter + .push((Segment::AccountsLinkedList as usize + 6).into()) + .expect("The stack should not overflow"); + interpreter + .push(interpreter.get_global_metadata_field(GlobalMetadata::StateTrieRoot)) + .unwrap(); + + // Now, set the payload. + interpreter.generation_state.registers.program_counter = + KERNEL.global_labels["mpt_set_payload"]; + + interpreter.run()?; + + let acc_ptr = interpreter.pop().expect("The stack should not be empty") - 2; + let storage_ptr = interpreter.pop().expect("The stack should not be empty") - 3; + interpreter.set_global_metadata_field(GlobalMetadata::InitialAccountsLinkedListLen, acc_ptr); + interpreter.set_global_metadata_field(GlobalMetadata::InitialStorageLinkedListLen, storage_ptr); + // Next, execute mpt_insert_state_trie. interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); @@ -216,27 +251,31 @@ fn test_state_trie( interpreter.stack() ); - // Now, execute mpt_hash_state_trie. - interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; + // Now, execute `mpt_hash_state_trie` and check the hash value (both are done + // under `check_state_trie`). + state_trie.insert(k, rlp::encode(&account).to_vec())?; + let expected_state_trie_hash = state_trie.hash(); + interpreter.set_global_metadata_field( + GlobalMetadata::StateTrieRootDigestAfter, + h2u(expected_state_trie_hash), + ); + + interpreter.generation_state.registers.program_counter = check_state_trie; interpreter - .push(0xDEADBEEFu32.into()) - .expect("The stack should not overflow"); + .halt_offsets + .push(KERNEL.global_labels["check_txn_trie"]); + interpreter - .push(1.into()) // Initial length of the trie data segment, unused. + .push(interpreter.get_global_metadata_field(GlobalMetadata::TrieDataSize)) // Initial trie data segment size, unused. .expect("The stack should not overflow"); interpreter.run()?; assert_eq!( interpreter.stack().len(), - 2, + 1, "Expected 2 items on stack after hashing, found {:?}", interpreter.stack() ); - let hash = H256::from_uint(&interpreter.stack()[1]); - - state_trie.insert(k, rlp::encode(&account).to_vec())?; - let expected_state_trie_hash = state_trie.hash(); - assert_eq!(hash, expected_state_trie_hash); Ok(()) } diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs new file mode 100644 index 000000000..69f7061bc --- /dev/null +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs @@ -0,0 +1,660 @@ +use std::collections::HashSet; + +use anyhow::Result; +use env_logger::try_init_from_env; +use env_logger::Env; +use env_logger::DEFAULT_FILTER_ENV; +use ethereum_types::{Address, H160, U256}; +use itertools::Itertools; +use plonky2::field::goldilocks_field::GoldilocksField as F; +use rand::{thread_rng, Rng}; + +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::interpreter::Interpreter; +use crate::generation::linked_list::LinkedList; +use crate::memory::segments::Segment; +use crate::witness::memory::MemoryAddress; + +fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); +} + +#[test] +fn test_init_linked_lists() -> Result<()> { + init_logger(); + + let interpreter = Interpreter::::new(0, vec![], None); + + // Check the initial accounts linked list + let acc_addr_list: Vec = (0..4) + .map(|i| { + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, Segment::AccountsLinkedList, i)) + }) + .collect(); + assert_eq!( + vec![ + U256::MAX, + U256::zero(), + U256::zero(), + (Segment::AccountsLinkedList as usize).into(), + ], + acc_addr_list + ); + + // Check the initial storage linked list + let acc_addr_list: Vec = (0..5) + .map(|i| { + interpreter + .generation_state + .memory + .get_with_init(MemoryAddress::new(0, Segment::StorageLinkedList, i)) + }) + .collect(); + assert_eq!( + vec![ + U256::MAX, + U256::zero(), + U256::zero(), + U256::zero(), + (Segment::StorageLinkedList as usize).into(), + ], + acc_addr_list + ); + + Ok(()) +} + +#[test] +fn test_list_iterator() -> Result<()> { + init_logger(); + + let interpreter = Interpreter::::new(0, vec![], None); + + // test the list iterator + let accounts_mem = interpreter + .generation_state + .memory + .get_preinit_memory(Segment::AccountsLinkedList); + let mut accounts_list = + LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap(); + + let Some([addr, ptr, ptr_cpy, scaled_pos_1]) = accounts_list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(addr, U256::MAX); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + assert_eq!(scaled_pos_1, (Segment::AccountsLinkedList as usize).into()); + let Some([addr, ptr, ptr_cpy, scaled_pos_1]) = accounts_list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(addr, U256::MAX); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + assert_eq!(scaled_pos_1, (Segment::AccountsLinkedList as usize).into()); + + let accounts_mem = interpreter + .generation_state + .memory + .get_preinit_memory(Segment::StorageLinkedList); + let mut storage_list = + LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + let Some([addr, key, ptr, ptr_cpy, scaled_pos_1]) = storage_list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(addr, U256::MAX); + assert_eq!(key, U256::zero()); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + assert_eq!(scaled_pos_1, (Segment::StorageLinkedList as usize).into()); + let Some([addr, _key, ptr, ptr_cpy, scaled_pos_1]) = storage_list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(addr, U256::MAX); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + assert_eq!(scaled_pos_1, (Segment::StorageLinkedList as usize).into()); + + Ok(()) +} + +#[test] +fn test_insert_account() -> Result<()> { + init_logger(); + + let mut interpreter = Interpreter::::new(0, vec![], None); + + // Initialize the accounts linked list. + let init_accounts_ll = vec![ + Some(U256::MAX), + Some(0.into()), + Some(0.into()), + Some((Segment::AccountsLinkedList as usize).into()), + ]; + let init_len = init_accounts_ll.len(); + interpreter.generation_state.memory.contexts[0].segments + [Segment::AccountsLinkedList.unscale()] + .content = init_accounts_ll; + interpreter.set_global_metadata_field( + GlobalMetadata::AccountsLinkedListNextAvailable, + (Segment::AccountsLinkedList as usize + init_len).into(), + ); + + let insert_account_label = KERNEL.global_labels["insert_account_with_overwrite"]; + + let retaddr = 0xdeadbeefu32.into(); + let mut rng = thread_rng(); + let address: H160 = rng.gen(); + let payload_ptr = U256::from(5); + + assert!(address != H160::zero(), "Cosmic luck or bad RNG?"); + + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(payload_ptr) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(address.0.as_slice())) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_account_label; + + interpreter.run()?; + + let accounts_mem = interpreter + .generation_state + .memory + .get_preinit_memory(Segment::AccountsLinkedList); + let mut list = + LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap(); + + let Some([addr, ptr, ptr_cpy, scaled_next_pos]) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(addr, U256::from(address.0.as_slice())); + assert_eq!(ptr, payload_ptr); + assert_eq!(ptr_cpy, U256::zero()); // ptr_cpy is zero because the trie_data segment is empty + assert_eq!( + scaled_next_pos, + (Segment::AccountsLinkedList as usize).into() + ); + let Some([addr, ptr, ptr_cpy, scaled_new_pos]) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(addr, U256::MAX); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + assert_eq!( + scaled_new_pos, + (Segment::AccountsLinkedList as usize + 4).into() + ); + Ok(()) +} + +#[test] +fn test_insert_storage() -> Result<()> { + init_logger(); + + let mut interpreter = Interpreter::::new(0, vec![], None); + + // Initialize the storage linked list. + let init_storage_ll = vec![ + Some(U256::MAX), + Some(0.into()), + Some(0.into()), + Some(0.into()), + Some((Segment::StorageLinkedList as usize).into()), + ]; + let init_len = init_storage_ll.len(); + interpreter.generation_state.memory.contexts[0].segments + [Segment::StorageLinkedList.unscale()] + .content = init_storage_ll; + interpreter.set_global_metadata_field( + GlobalMetadata::StorageLinkedListNextAvailable, + (Segment::StorageLinkedList as usize + init_len).into(), + ); + + let insert_account_label = KERNEL.global_labels["insert_slot"]; + + let retaddr = 0xdeadbeefu32.into(); + let mut rng = thread_rng(); + let address: H160 = rng.gen(); + let key: H160 = rng.gen(); + let payload_ptr = U256::from(5); + + assert!(address != H160::zero(), "Cosmic luck or bad RNG?"); + + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(payload_ptr) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(key.0.as_slice())) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(address.0.as_slice())) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_account_label; + + interpreter.run()?; + assert_eq!(interpreter.stack(), &[payload_ptr]); + + let accounts_mem = interpreter + .generation_state + .memory + .get_preinit_memory(Segment::StorageLinkedList); + let mut list = + LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + + let Some([inserted_addr, inserted_key, ptr, ptr_cpy, scaled_next_pos]) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(inserted_addr, U256::from(address.0.as_slice())); + assert_eq!(inserted_key, U256::from(key.0.as_slice())); + assert_eq!(ptr, payload_ptr); + assert_eq!(ptr_cpy, U256::zero()); // ptr_cpy is zero because the trie data segment is empty + assert_eq!( + scaled_next_pos, + (Segment::StorageLinkedList as usize).into() + ); + let Some([inserted_addr, inserted_key, ptr, ptr_cpy, scaled_new_pos]) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(inserted_addr, U256::MAX); + assert_eq!(inserted_key, U256::zero()); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + assert_eq!( + scaled_new_pos, + (Segment::StorageLinkedList as usize + 5).into() + ); + Ok(()) +} + +#[test] +fn test_insert_and_delete_accounts() -> Result<()> { + init_logger(); + + let mut interpreter = Interpreter::::new(0, vec![], None); + + // Initialize the accounts linked list. + let init_accounts_ll = vec![ + Some(U256::MAX), + Some(0.into()), + Some(0.into()), + Some((Segment::AccountsLinkedList as usize).into()), + ]; + let init_len = init_accounts_ll.len(); + interpreter.generation_state.memory.contexts[0].segments + [Segment::AccountsLinkedList.unscale()] + .content = init_accounts_ll; + interpreter.set_global_metadata_field( + GlobalMetadata::AccountsLinkedListNextAvailable, + (Segment::AccountsLinkedList as usize + init_len).into(), + ); + + let insert_account_label = KERNEL.global_labels["insert_account_with_overwrite"]; + + let retaddr = 0xdeadbeefu32.into(); + let n = 10; + let mut addresses = (0..n) + .map(|i| Address::from_low_u64_be(i as u64 + 5)) + .collect::>() + .into_iter() + .collect::>(); + let delta_ptr = 100; + let addr_not_in_list = Address::from_low_u64_be(4); + assert!( + !addresses.contains(&addr_not_in_list), + "Cosmic luck or bad RNG?" + ); + + let offset = Segment::AccountsLinkedList as usize; + // Insert all addresses + for i in 0..n { + let addr = U256::from(addresses[i].0.as_slice()); + interpreter + .push(0xdeadbeefu32.into()) + .expect("The stack should not overflow"); + interpreter + .push(addr + delta_ptr) + .expect("The stack should not overflow"); // ptr = addr + delta_ptr for the sake of the test + interpreter + .push(addr) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_account_label; + interpreter.run()?; + + // The copied ptr is at distance 4, the size of an account, from the previous + // copied ptr. + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from(offset + 4 * (i + 1) + 2)).unwrap(), + ), + (4 * i).into() + ); + } + + // The next free address in Segment::AccounLinkedList must be offset + (n + + // 1)*4. + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from( + GlobalMetadata::AccountsLinkedListNextAvailable as usize + )) + .unwrap(), + ), + U256::from(offset + (n + 1) * 4) + ); + + let search_account_label = KERNEL.global_labels["search_account"]; + // Test for address already in list. + for i in 0..n { + let addr_in_list = U256::from(addresses[i].0.as_slice()); + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(addr_in_list) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = search_account_label; + interpreter.run()?; + + assert_eq!( + interpreter.pop().expect("The stack can't be empty"), + addr_in_list + delta_ptr + ); + } + + // Test for address not in the list. + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(addr_not_in_list.0.as_slice()) + delta_ptr) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(addr_not_in_list.0.as_slice())) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_account_label; + + interpreter.run()?; + + // Now the list of accounts have address 4 + addresses.push(addr_not_in_list); + + // The next free address in Segment::AccounLinkedList must be offset + (n + + // 2)*4. + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from( + GlobalMetadata::AccountsLinkedListNextAvailable as usize + )) + .unwrap(), + ), + U256::from(offset + (n + 2) * 4) + ); + + // Remove all even nodes. + let delete_account_label = KERNEL.global_labels["remove_account"]; + + let mut new_addresses = vec![]; + + for (i, j) in (0..n).tuples() { + // Remove addressese already in list. + let addr_in_list = U256::from(addresses[i].0.as_slice()); + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(addr_in_list) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = delete_account_label; + interpreter.run()?; + assert!(interpreter.stack().is_empty()); + // we add the non deleted addres to new_addresses + new_addresses.push(addresses[j]); + } + // The last address is not removed. + new_addresses.push(*addresses.last().unwrap()); + + // We need to sort the list in order to properly compare + // the linked list with the interpreter's memory. + new_addresses.sort(); + + let accounts_mem = interpreter + .generation_state + .memory + .get_preinit_memory(Segment::AccountsLinkedList); + let list = + LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap(); + + for (i, [addr, ptr, ptr_cpy, _]) in list.enumerate() { + if addr == U256::MAX { + assert_eq!(addr, U256::MAX); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + break; + } + let addr_in_list = U256::from(new_addresses[i].0.as_slice()); + assert_eq!(addr, addr_in_list); + assert_eq!(ptr, addr + delta_ptr); + } + + Ok(()) +} + +#[test] +fn test_insert_and_delete_storage() -> Result<()> { + init_logger(); + + let mut interpreter = Interpreter::::new(0, vec![], None); + + // Initialize the storage linked list. + let init_storage_ll = vec![ + Some(U256::MAX), + Some(0.into()), + Some(0.into()), + Some(0.into()), + Some((Segment::StorageLinkedList as usize).into()), + ]; + let init_len = init_storage_ll.len(); + interpreter.generation_state.memory.contexts[0].segments + [Segment::StorageLinkedList.unscale()] + .content = init_storage_ll; + interpreter.set_global_metadata_field( + GlobalMetadata::StorageLinkedListNextAvailable, + (Segment::StorageLinkedList as usize + init_len).into(), + ); + + let insert_slot_label = KERNEL.global_labels["insert_slot"]; + + let retaddr = 0xdeadbeefu32.into(); + let n = 10; + let mut addresses_and_keys = (0..n) + .map(|i| { + [ + Address::from_low_u64_be(i as u64 + 5), + H160::from_low_u64_be(i as u64 + 6), + ] + }) + .collect::>() + .into_iter() + .collect::>(); + let delta_ptr = 100; + let addr_not_in_list = Address::from_low_u64_be(4); + let key_not_in_list = H160::from_low_u64_be(5); + assert!( + !addresses_and_keys.contains(&[addr_not_in_list, key_not_in_list]), + "Cosmic luck or bad RNG?" + ); + + let offset = Segment::StorageLinkedList as usize; + // Insert all addresses, key pairs + for i in 0..n { + let [addr, key] = addresses_and_keys[i].map(|x| U256::from(x.0.as_slice())); + interpreter + .push(0xdeadbeefu32.into()) + .expect("The stack should not overflow"); + interpreter + .push(addr + delta_ptr) + .expect("The stack should not overflow"); // ptr = addr + delta_ptr for the sake of the test + interpreter + .push(key) + .expect("The stack should not overflow"); + interpreter + .push(addr) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_slot_label; + interpreter.run()?; + assert_eq!( + interpreter.pop().expect("The stack can't be empty"), + addr + delta_ptr + ); + // The ptr_cpy must be 0 + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from(offset + 5 * (i + 1) + 3)).unwrap(), + ), + i.into() + ); + } + + // The next free node in Segment::StorageLinkedList must be at offset + (n + + // 1)*5. + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from( + GlobalMetadata::StorageLinkedListNextAvailable as usize + )) + .unwrap(), + ), + U256::from(offset + (n + 1) * 5) + ); + + // Test for address already in list. + for i in 0..n { + let [addr_in_list, key_in_list] = addresses_and_keys[i].map(|x| U256::from(x.0.as_slice())); + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(addr_in_list + delta_ptr) + .expect("The stack should not overflow"); + interpreter + .push(key_in_list) + .expect("The stack should not overflow"); + interpreter + .push(addr_in_list) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_slot_label; + interpreter.run()?; + + assert_eq!( + interpreter.pop().expect("The stack can't be empty"), + addr_in_list + delta_ptr + ); + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from(offset + 5 * (i + 1) + 3)).unwrap(), + ), + i.into() + ); + } + + // Test for address not in the list. + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(addr_not_in_list.0.as_slice()) + delta_ptr) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(key_not_in_list.0.as_slice())) + .expect("The stack should not overflow"); + interpreter + .push(U256::from(addr_not_in_list.0.as_slice())) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = insert_slot_label; + + interpreter.run()?; + + assert_eq!( + interpreter.pop().expect("The stack can't be empty"), + U256::from(addr_not_in_list.0.as_slice()) + delta_ptr + ); + + // Now the list of accounts have [4, 5] + addresses_and_keys.push([addr_not_in_list, key_not_in_list]); + + // The next free node in Segment::AccounLinkedList must be at offset + (n + + // 2)*5. + assert_eq!( + interpreter.generation_state.memory.get_with_init( + MemoryAddress::new_bundle(U256::from( + GlobalMetadata::StorageLinkedListNextAvailable as usize + )) + .unwrap(), + ), + U256::from(offset + (n + 2) * 5) + ); + + // Remove all even nodes. + let remove_slot_label = KERNEL.global_labels["remove_slot"]; + + let mut new_addresses = vec![]; + + for (i, j) in (0..n).tuples() { + // Test for [address, key] already in list. + let [addr_in_list, key_in_list] = addresses_and_keys[i].map(|x| U256::from(x.0.as_slice())); + interpreter + .push(retaddr) + .expect("The stack should not overflow"); + interpreter + .push(key_in_list) + .expect("The stack should not overflow"); + interpreter + .push(addr_in_list) + .expect("The stack should not overflow"); + interpreter.generation_state.registers.program_counter = remove_slot_label; + interpreter.run()?; + assert!(interpreter.stack().is_empty()); + // we add the non deleted addres to new_addresses + new_addresses.push(addresses_and_keys[j]); + } + // The last address is not removed. + new_addresses.push(*addresses_and_keys.last().unwrap()); + + // We need to sort the list in order to properly compare + // the linked list with the interpreter's memory. + new_addresses.sort(); + + let accounts_mem = interpreter + .generation_state + .memory + .get_preinit_memory(Segment::StorageLinkedList); + let list = LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + + for (i, [addr, key, ptr, ptr_cpy, _]) in list.enumerate() { + if addr == U256::MAX { + assert_eq!(addr, U256::MAX); + assert_eq!(key, U256::zero()); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); + break; + } + let [addr_in_list, key_in_list] = new_addresses[i].map(|x| U256::from(x.0.as_slice())); + assert_eq!(addr, addr_in_list); + assert_eq!(key, key_in_list); + assert_eq!(ptr, addr + delta_ptr); + } + + Ok(()) +} diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/load.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/load.rs index 9aa8a1f0b..9d04700bf 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/load.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/load.rs @@ -25,7 +25,7 @@ fn load_all_mpts_empty() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); @@ -62,7 +62,7 @@ fn load_all_mpts_leaf() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); @@ -70,14 +70,20 @@ fn load_all_mpts_leaf() -> Result<()> { assert_eq!( interpreter.get_trie_data(), vec![ - 0.into(), + 0.into(), // First address is unused, so that 0 can be treated as a null pointer. + // The next four elements correspond to the account stored in the linked list. + test_account_1().nonce, + test_account_1().balance, + 0.into(), // pointer to storage trie root before insertion + test_account_1().code_hash.into_uint(), + // Values used for hashing. type_leaf, 3.into(), 0xABC.into(), - 5.into(), // value ptr + 9.into(), // value ptr test_account_1().nonce, test_account_1().balance, - 9.into(), // pointer to storage trie root + 13.into(), // pointer to storage trie root test_account_1().code_hash.into_uint(), // These last two elements encode the storage trie, which is a hash node. (PartialTrieType::Hash as u32).into(), @@ -108,7 +114,7 @@ fn load_all_mpts_hash() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); @@ -146,7 +152,7 @@ fn load_all_mpts_empty_branch() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); @@ -198,7 +204,7 @@ fn load_all_mpts_ext_to_leaf() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); @@ -208,17 +214,23 @@ fn load_all_mpts_ext_to_leaf() -> Result<()> { interpreter.get_trie_data(), vec![ 0.into(), // First address is unused, so that 0 can be treated as a null pointer. + // The next four elements correspond to the account stored in the linked list. + test_account_1().nonce, + test_account_1().balance, + 0.into(), // pointer to storage trie root before insertion + test_account_1().code_hash.into_uint(), + // Values used for hashing. type_extension, 3.into(), // 3 nibbles 0xABC.into(), // key part - 5.into(), // Pointer to the leaf node immediately below. + 9.into(), // Pointer to the leaf node immediately below. type_leaf, 3.into(), // 3 nibbles 0xDEF.into(), // key part - 9.into(), // value pointer + 13.into(), // value pointer test_account_1().nonce, test_account_1().balance, - 13.into(), // pointer to storage trie root + 17.into(), // pointer to storage trie root test_account_1().code_hash.into_uint(), // These last two elements encode the storage trie, which is a hash node. (PartialTrieType::Hash as u32).into(), @@ -244,7 +256,7 @@ fn load_mpt_txn_trie() -> Result<()> { }; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/mod.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/mod.rs index 84f64bb7b..17ff18a9b 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/mod.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/mod.rs @@ -1,6 +1,7 @@ use ethereum_types::{BigEndianHash, H256, U256}; use mpt_trie::nibbles::Nibbles; use mpt_trie::partial_trie::HashedPartialTrie; +use mpt_trie::partial_trie::PartialTrie; use crate::generation::mpt::AccountRlp; use crate::Node; @@ -9,6 +10,7 @@ mod delete; mod hash; mod hex_prefix; mod insert; +mod linked_list; mod load; mod read; @@ -37,10 +39,23 @@ pub(crate) fn test_account_1() -> AccountRlp { } } +pub(crate) fn test_account_1_empty_storage() -> AccountRlp { + AccountRlp { + nonce: U256::from(1111), + balance: U256::from(2222), + storage_root: HashedPartialTrie::from(Node::Empty).hash(), + code_hash: H256::from_uint(&U256::from(4444)), + } +} + pub(crate) fn test_account_1_rlp() -> Vec { rlp::encode(&test_account_1()).to_vec() } +pub(crate) fn test_account_1_empty_storage_rlp() -> Vec { + rlp::encode(&test_account_1_empty_storage()).to_vec() +} + pub(crate) fn test_account_2() -> AccountRlp { AccountRlp { nonce: U256::from(5555), diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/read.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/read.rs index 9b669a21c..571b45c38 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/read.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/read.rs @@ -21,11 +21,11 @@ fn mpt_read() -> Result<()> { let mpt_read = KERNEL.global_labels["mpt_read"]; let initial_stack = vec![]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_mpts(&mut interpreter, &trie_inputs); assert_eq!(interpreter.stack(), vec![]); - // Now, execute mpt_read on the state trie. + // Now, execute `mpt_read` on the state trie. interpreter.generation_state.registers.program_counter = mpt_read; interpreter .push(0xdeadbeefu32.into()) diff --git a/evm_arithmetization/src/cpu/kernel/tests/packing.rs b/evm_arithmetization/src/cpu/kernel/tests/packing.rs index d487fd66a..79cffaf4d 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/packing.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/packing.rs @@ -16,7 +16,7 @@ fn test_mstore_unpacking() -> Result<()> { let addr = (Segment::TxnData as u64).into(); let initial_stack = vec![retdest, len, value, addr]; - let mut interpreter: Interpreter = Interpreter::new(mstore_unpacking, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(mstore_unpacking, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), vec![addr + U256::from(4)]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/receipt.rs b/evm_arithmetization/src/cpu/kernel/tests/receipt.rs index 0dfdefed0..4c688d3ef 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/receipt.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/receipt.rs @@ -48,7 +48,7 @@ fn test_process_receipt() -> Result<()> { leftover_gas, success, ]; - let mut interpreter: Interpreter = Interpreter::new(process_receipt, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(process_receipt, initial_stack, None); interpreter.set_memory_segment( Segment::LogsData, vec![ @@ -59,6 +59,8 @@ fn test_process_receipt() -> Result<()> { 0.into(), // data_len ], ); + interpreter.set_memory_segment(Segment::TrieData, vec![0.into()]); + interpreter.set_global_metadata_field(GlobalMetadata::TrieDataSize, 1.into()); interpreter.set_txn_field(NormalizedTxnField::GasLimit, U256::from(5000)); interpreter.set_memory_segment(Segment::TxnBloom, vec![0.into(); 256]); interpreter.set_memory_segment(Segment::Logs, vec![0.into()]); @@ -69,9 +71,11 @@ fn test_process_receipt() -> Result<()> { let segment_read = interpreter.get_memory_segment(Segment::TrieData); - // The expected TrieData has the form [payload_len, status, cum_gas_used, - // bloom_filter, logs_payload_len, num_logs, [logs]] - let mut expected_trie_data: Vec = vec![323.into(), success, 2000.into()]; + // The expected TrieData has the form [0, payload_len, status, cum_gas_used, + // bloom_filter, logs_payload_len, num_logs, [logs]]. + // The 0 is always the first element of `TrieSegmentData`, as it corresponds to + // the null pointer. + let mut expected_trie_data: Vec = vec![0.into(), 323.into(), success, 2000.into()]; expected_trie_data.extend( expected_bloom .into_iter() @@ -132,7 +136,7 @@ fn test_receipt_encoding() -> Result<()> { // Address at which the encoding is written. let rlp_addr = U256::from(Segment::RlpRaw as usize); let initial_stack: Vec = vec![retdest, 0.into(), 0.into(), rlp_addr]; - let mut interpreter: Interpreter = Interpreter::new(encode_receipt, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(encode_receipt, initial_stack, None); // Write data to memory. let expected_bloom_bytes = vec![ @@ -252,7 +256,7 @@ fn test_receipt_bloom_filter() -> Result<()> { // Set logs memory and initialize TxnBloom and BlockBloom segments. let initial_stack: Vec = vec![retdest]; - let mut interpreter: Interpreter = Interpreter::new(logs_bloom, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(logs_bloom, initial_stack, None); let mut logs = vec![ 0.into(), // unused addr, @@ -414,7 +418,7 @@ fn test_mpt_insert_receipt() -> Result<()> { receipt.push(num_logs.into()); // num_logs receipt.extend(logs_0.clone()); - let mut interpreter: Interpreter = Interpreter::new(0, vec![]); + let mut interpreter: Interpreter = Interpreter::new(0, vec![], None); initialize_mpts(&mut interpreter, &trie_inputs); // If TrieData is empty, we need to push 0 because the first value is always 0. @@ -570,7 +574,7 @@ fn test_bloom_two_logs() -> Result<()> { ] .into(), ]; - let mut interpreter: Interpreter = Interpreter::new(logs_bloom, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(logs_bloom, initial_stack, None); interpreter.set_memory_segment(Segment::TxnBloom, vec![0.into(); 256]); // Initialize transaction Bloom filter. interpreter.set_memory_segment(Segment::LogsData, logs); interpreter.set_memory_segment(Segment::Logs, vec![0.into(), 4.into()]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/rlp/decode.rs b/evm_arithmetization/src/cpu/kernel/tests/rlp/decode.rs index 13ef498aa..9fa533c71 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/rlp/decode.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/rlp/decode.rs @@ -14,7 +14,8 @@ fn test_decode_rlp_string_len_short() -> Result<()> { 0xDEADBEEFu32.into(), U256::from(Segment::RlpRaw as usize + 2), ]; - let mut interpreter: Interpreter = Interpreter::new(decode_rlp_string_len, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(decode_rlp_string_len, initial_stack, None); // A couple dummy bytes, followed by "0x70" which is its own encoding. interpreter.set_rlp_memory(vec![123, 234, 0x70]); @@ -34,7 +35,8 @@ fn test_decode_rlp_string_len_medium() -> Result<()> { 0xDEADBEEFu32.into(), U256::from(Segment::RlpRaw as usize + 2), ]; - let mut interpreter: Interpreter = Interpreter::new(decode_rlp_string_len, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(decode_rlp_string_len, initial_stack, None); // A couple dummy bytes, followed by the RLP encoding of "1 2 3 4 5". interpreter.set_rlp_memory(vec![123, 234, 0x85, 1, 2, 3, 4, 5]); @@ -54,7 +56,8 @@ fn test_decode_rlp_string_len_long() -> Result<()> { 0xDEADBEEFu32.into(), U256::from(Segment::RlpRaw as usize + 2), ]; - let mut interpreter: Interpreter = Interpreter::new(decode_rlp_string_len, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(decode_rlp_string_len, initial_stack, None); // The RLP encoding of the string "1 2 3 ... 56". interpreter.set_rlp_memory(vec![ @@ -75,7 +78,8 @@ fn test_decode_rlp_list_len_short() -> Result<()> { let decode_rlp_list_len = KERNEL.global_labels["decode_rlp_list_len"]; let initial_stack = vec![0xDEADBEEFu32.into(), U256::from(Segment::RlpRaw as usize)]; - let mut interpreter: Interpreter = Interpreter::new(decode_rlp_list_len, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(decode_rlp_list_len, initial_stack, None); // The RLP encoding of [1, 2, [3, 4]]. interpreter.set_rlp_memory(vec![0xc5, 1, 2, 0xc2, 3, 4]); @@ -92,7 +96,8 @@ fn test_decode_rlp_list_len_long() -> Result<()> { let decode_rlp_list_len = KERNEL.global_labels["decode_rlp_list_len"]; let initial_stack = vec![0xDEADBEEFu32.into(), U256::from(Segment::RlpRaw as usize)]; - let mut interpreter: Interpreter = Interpreter::new(decode_rlp_list_len, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(decode_rlp_list_len, initial_stack, None); // The RLP encoding of [1, ..., 56]. interpreter.set_rlp_memory(vec![ @@ -113,7 +118,7 @@ fn test_decode_rlp_scalar() -> Result<()> { let decode_rlp_scalar = KERNEL.global_labels["decode_rlp_scalar"]; let initial_stack = vec![0xDEADBEEFu32.into(), U256::from(Segment::RlpRaw as usize)]; - let mut interpreter: Interpreter = Interpreter::new(decode_rlp_scalar, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(decode_rlp_scalar, initial_stack, None); // The RLP encoding of "12 34 56". interpreter.set_rlp_memory(vec![0x83, 0x12, 0x34, 0x56]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/rlp/encode.rs b/evm_arithmetization/src/cpu/kernel/tests/rlp/encode.rs index a7591a933..a3cf8699b 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/rlp/encode.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/rlp/encode.rs @@ -14,7 +14,7 @@ fn test_encode_rlp_scalar_small() -> Result<()> { let scalar = 42.into(); let pos = U256::from(Segment::RlpRaw as usize + 2); let initial_stack = vec![retdest, scalar, pos]; - let mut interpreter: Interpreter = Interpreter::new(encode_rlp_scalar, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(encode_rlp_scalar, initial_stack, None); interpreter.run()?; let expected_stack = vec![pos + U256::from(1)]; // pos' = pos + rlp_len = 2 + 1 @@ -37,7 +37,7 @@ fn test_encode_rlp_scalar_medium() -> Result<()> { let scalar = 0x12345.into(); let pos = U256::from(Segment::RlpRaw as usize + 2); let initial_stack = vec![retdest, scalar, pos]; - let mut interpreter: Interpreter = Interpreter::new(encode_rlp_scalar, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(encode_rlp_scalar, initial_stack, None); interpreter.run()?; let expected_stack = vec![pos + U256::from(4)]; // pos' = pos + rlp_len = 2 + 4 @@ -60,7 +60,7 @@ fn test_encode_rlp_160() -> Result<()> { let string = 0x12345.into(); let pos = U256::from(Segment::RlpRaw as usize); let initial_stack = vec![retdest, string, pos, U256::from(20)]; - let mut interpreter: Interpreter = Interpreter::new(encode_rlp_fixed, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(encode_rlp_fixed, initial_stack, None); interpreter.run()?; let expected_stack = vec![pos + U256::from(1 + 20)]; // pos' @@ -80,7 +80,7 @@ fn test_encode_rlp_256() -> Result<()> { let string = 0x12345.into(); let pos = U256::from(Segment::RlpRaw as usize); let initial_stack = vec![retdest, string, pos, U256::from(32)]; - let mut interpreter: Interpreter = Interpreter::new(encode_rlp_fixed, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(encode_rlp_fixed, initial_stack, None); interpreter.run()?; let expected_stack = vec![pos + U256::from(1 + 32)]; // pos' @@ -100,7 +100,8 @@ fn test_prepend_rlp_list_prefix_small() -> Result<()> { let start_pos = U256::from(Segment::RlpRaw as usize + 9); let end_pos = U256::from(Segment::RlpRaw as usize + 9 + 5); let initial_stack = vec![retdest, start_pos, end_pos]; - let mut interpreter: Interpreter = Interpreter::new(prepend_rlp_list_prefix, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(prepend_rlp_list_prefix, initial_stack, None); interpreter.set_rlp_memory(vec![ // Nine 0s to leave room for the longest possible RLP list prefix. 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -129,7 +130,8 @@ fn test_prepend_rlp_list_prefix_large() -> Result<()> { let start_pos = U256::from(Segment::RlpRaw as usize + 9); let end_pos = U256::from(Segment::RlpRaw as usize + 9 + 60); let initial_stack = vec![retdest, start_pos, end_pos]; - let mut interpreter: Interpreter = Interpreter::new(prepend_rlp_list_prefix, initial_stack); + let mut interpreter: Interpreter = + Interpreter::new(prepend_rlp_list_prefix, initial_stack, None); #[rustfmt::skip] interpreter.set_rlp_memory(vec![ diff --git a/evm_arithmetization/src/cpu/kernel/tests/rlp/num_bytes.rs b/evm_arithmetization/src/cpu/kernel/tests/rlp/num_bytes.rs index 2cc9c1bf5..9bcf2cf2b 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/rlp/num_bytes.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/rlp/num_bytes.rs @@ -11,7 +11,7 @@ fn test_num_bytes_0() -> Result<()> { let retdest = 0xDEADBEEFu32.into(); let x = 0.into(); let initial_stack = vec![retdest, x]; - let mut interpreter: Interpreter = Interpreter::new(num_bytes, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(num_bytes, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), vec![1.into()]); @@ -25,7 +25,7 @@ fn test_num_bytes_small() -> Result<()> { let retdest = 0xDEADBEEFu32.into(); let x = 42.into(); let initial_stack = vec![retdest, x]; - let mut interpreter: Interpreter = Interpreter::new(num_bytes, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(num_bytes, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), vec![1.into()]); @@ -39,7 +39,7 @@ fn test_num_bytes_medium() -> Result<()> { let retdest = 0xDEADBEEFu32.into(); let x = 0xAABBCCDDu32.into(); let initial_stack = vec![retdest, x]; - let mut interpreter: Interpreter = Interpreter::new(num_bytes, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(num_bytes, initial_stack, None); interpreter.run()?; assert_eq!(interpreter.stack(), vec![4.into()]); diff --git a/evm_arithmetization/src/cpu/kernel/tests/signed_syscalls.rs b/evm_arithmetization/src/cpu/kernel/tests/signed_syscalls.rs index 4f3a16bec..ca705d171 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/signed_syscalls.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/signed_syscalls.rs @@ -118,7 +118,7 @@ fn run_test(fn_label: &str, expected_fn: fn(U256, U256) -> U256, opname: &str) { for &x in &inputs { for &y in &inputs { let stack = vec![retdest, y, x]; - let mut interpreter: Interpreter = Interpreter::new(fn_label, stack); + let mut interpreter: Interpreter = Interpreter::new(fn_label, stack, None); interpreter.run().unwrap(); assert_eq!(interpreter.stack_len(), 1usize, "unexpected stack size"); let output = interpreter diff --git a/evm_arithmetization/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs b/evm_arithmetization/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs index 745c81624..db0cb20d1 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/transaction_parsing/parse_type_0_txn.rs @@ -19,6 +19,7 @@ fn process_type_0_txn() -> Result<()> { let mut interpreter: Interpreter = Interpreter::new( process_type_0_txn, vec![retaddr, INITIAL_TXN_RLP_ADDR.into()], + None, ); // When we reach process_normalized_txn, we're done with parsing and @@ -82,6 +83,7 @@ fn process_type_0_txn_invalid_sig() -> Result<()> { let mut interpreter: Interpreter = Interpreter::new( process_type_0_txn, vec![retaddr, INITIAL_TXN_RLP_ADDR.into()], + None, ); // Same transaction as `process_type_0_txn()`, with the exception that the `s` diff --git a/evm_arithmetization/src/cpu/kernel/tests/transient_storage.rs b/evm_arithmetization/src/cpu/kernel/tests/transient_storage.rs index 774ac51d7..f584b1322 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/transient_storage.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/transient_storage.rs @@ -58,7 +58,7 @@ fn test_tstore() -> Result<()> { kexit_info, ]; - let mut interpreter: Interpreter = Interpreter::new(sys_tstore, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(sys_tstore, initial_stack, None); initialize_interpreter(&mut interpreter, 100.into()); interpreter.run()?; @@ -104,7 +104,7 @@ fn test_tstore_tload() -> Result<()> { kexit_info, ]; - let mut interpreter: Interpreter = Interpreter::new(sys_tstore, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(sys_tstore, initial_stack, None); initialize_interpreter(&mut interpreter, 200.into()); interpreter.run()?; @@ -162,7 +162,7 @@ fn test_many_tstore_many_tload() -> Result<()> { kexit_info, ]; - let mut interpreter: Interpreter = Interpreter::new(0, initial_stack); + let mut interpreter: Interpreter = Interpreter::new(0, initial_stack, None); initialize_interpreter(&mut interpreter, (10 * 200).into()); for i in 0..10 { @@ -235,9 +235,9 @@ fn test_revert() -> Result<()> { }); let sys_tstore = KERNEL.global_labels["sys_tstore"]; - let mut interpreter = Interpreter::::new(sys_tstore, vec![]); + let mut interpreter = Interpreter::::new(sys_tstore, vec![], None); interpreter.generation_state = - GenerationState::::new(GenerationInputs::default(), &KERNEL.code).unwrap(); + GenerationState::::new(&GenerationInputs::default(), &KERNEL.code).unwrap(); initialize_interpreter(&mut interpreter, (20 * 100).into()); // Store different values at slot 1 diff --git a/evm_arithmetization/src/cpu/kernel/utils.rs b/evm_arithmetization/src/cpu/kernel/utils.rs index adda086e8..082086d17 100644 --- a/evm_arithmetization/src/cpu/kernel/utils.rs +++ b/evm_arithmetization/src/cpu/kernel/utils.rs @@ -1,7 +1,6 @@ use core::fmt::Debug; use ethereum_types::U256; -use plonky2_util::ceil_div_usize; /// Enumerate the length `W` windows of `vec`, and run `maybe_replace` on each /// one. @@ -28,7 +27,7 @@ where } pub(crate) fn u256_to_trimmed_be_bytes(u256: &U256) -> Vec { - let num_bytes = ceil_div_usize(u256.bits(), 8); + let num_bytes = u256.bits().div_ceil(8); // `byte` is little-endian, so we manually reverse it. (0..num_bytes).rev().map(|i| u256.byte(i)).collect() } diff --git a/evm_arithmetization/src/cpu/syscalls_exceptions.rs b/evm_arithmetization/src/cpu/syscalls_exceptions.rs index a97d8eb52..cea7d704a 100644 --- a/evm_arithmetization/src/cpu/syscalls_exceptions.rs +++ b/evm_arithmetization/src/cpu/syscalls_exceptions.rs @@ -14,6 +14,7 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::memory::segments::Segment; +use crate::witness::transition::EXC_STOP_CODE; // Copy the constant but make it `usize`. const BYTES_PER_OFFSET: usize = crate::cpu::kernel::assembler::BYTES_PER_OFFSET as usize; @@ -34,9 +35,6 @@ pub(crate) fn eval_packed( yield_constr.constraint(filter_syscall * (filter_syscall - P::ONES)); yield_constr.constraint(filter_exception * (filter_exception - P::ONES)); - // If exception, ensure we are not in kernel mode - yield_constr.constraint(filter_exception * lv.is_kernel_mode); - // Get the exception code as an value in {0, ..., 7}. let exc_code_bits = lv.general.exception().exc_code_bits; let exc_code: P = exc_code_bits @@ -44,6 +42,12 @@ pub(crate) fn eval_packed( .enumerate() .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) .sum(); + + // All exceptions -- except `exc_stop`, which carries out the final checks of a + // segment execution -- have to be in user mode. + let exc_stop_code = P::Scalar::from_canonical_u8(EXC_STOP_CODE); + yield_constr.constraint(filter_exception * (exc_code - exc_stop_code) * lv.is_kernel_mode); + // Ensure that all bits are either 0 or 1. for bit in exc_code_bits { yield_constr.constraint(filter_exception * bit * (bit - P::ONES)); @@ -116,8 +120,9 @@ pub(crate) fn eval_packed( yield_constr.constraint(total_filter * output[7]); // High limb of gas is zero. // Zero the rest of that register - // output[1] is 0 for exceptions, but not for syscalls - yield_constr.constraint(filter_exception * output[1]); + // output[1] is 0 for exceptions (except for the final halting step), but not + // for syscalls. + yield_constr.constraint(filter_exception * (exc_code - exc_stop_code) * output[1]); for &limb in &output[2..6] { yield_constr.constraint(total_filter * limb); } @@ -143,10 +148,6 @@ pub(crate) fn eval_ext_circuit, const D: usize>( let constr = builder.mul_sub_extension(filter_exception, filter_exception, filter_exception); yield_constr.constraint(builder, constr); - // Ensure that, if exception, we are not in kernel mode - let constr = builder.mul_extension(filter_exception, lv.is_kernel_mode); - yield_constr.constraint(builder, constr); - let exc_code_bits = lv.general.exception().exc_code_bits; let exc_code = exc_code_bits @@ -156,6 +157,14 @@ pub(crate) fn eval_ext_circuit, const D: usize>( builder.mul_const_add_extension(F::from_canonical_u64(1 << i), bit, cumul) }); + // All exceptions -- except `exc_stop`, which carries out the final checks of a + // segment execution -- have to be in user mode. + let opcode_is_exc_stop = + builder.add_const_extension(exc_code, F::NEG_ONE * F::from_canonical_u8(EXC_STOP_CODE)); + let constr = + builder.mul_many_extension([filter_exception, opcode_is_exc_stop, lv.is_kernel_mode]); + + yield_constr.constraint(builder, constr); // Ensure that all bits are either 0 or 1. for bit in exc_code_bits { let constr = builder.mul_sub_extension(bit, bit, bit); @@ -303,7 +312,9 @@ pub(crate) fn eval_ext_circuit, const D: usize>( } // Zero the rest of that register - let constr = builder.mul_extension(filter_exception, output[1]); + // output[1] is 0 for exceptions (except for the final halting step), but not + // for syscalls. + let constr = builder.mul_many_extension([filter_exception, opcode_is_exc_stop, output[1]]); yield_constr.constraint(builder, constr); for &limb in &output[2..6] { let constr = builder.mul_extension(total_filter, limb); diff --git a/evm_arithmetization/src/fixed_recursive_verifier.rs b/evm_arithmetization/src/fixed_recursive_verifier.rs index 3fa6e208f..5a318b0fb 100644 --- a/evm_arithmetization/src/fixed_recursive_verifier.rs +++ b/evm_arithmetization/src/fixed_recursive_verifier.rs @@ -12,7 +12,7 @@ use plonky2::field::extension::Extendable; use plonky2::fri::FriParams; use plonky2::gates::constant::ConstantGate; use plonky2::gates::noop::NoopGate; -use plonky2::hash::hash_types::{RichField, NUM_HASH_OUT_ELTS}; +use plonky2::hash::hash_types::{MerkleCapTarget, RichField, NUM_HASH_OUT_ELTS}; use plonky2::iop::challenger::RecursiveChallenger; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; @@ -36,24 +36,39 @@ use starky::proof::StarkProofWithMetadata; use starky::stark::Stark; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; -use crate::generation::GenerationInputs; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::generation::{GenerationInputs, TrimmedGenerationInputs}; use crate::get_challenges::observe_public_values_target; use crate::proof::{ AllProof, BlockHashesTarget, BlockMetadataTarget, ExtraBlockData, ExtraBlockDataTarget, - PublicValues, PublicValuesTarget, TrieRoots, TrieRootsTarget, + FinalPublicValues, MemCapTarget, PublicValues, PublicValuesTarget, RegistersDataTarget, + TrieRoots, TrieRootsTarget, DEFAULT_CAP_LEN, TARGET_HASH_SIZE, }; -use crate::prover::{check_abort_signal, prove}; +use crate::prover::{check_abort_signal, prove, GenerationSegmentData, SegmentDataIterator}; use crate::recursive_verifier::{ add_common_recursion_gates, add_virtual_public_values, get_memory_extra_looking_sum_circuit, recursive_stark_circuit, set_public_value_targets, PlonkWrapperCircuit, PublicInputs, StarkWrapperCircuit, }; use crate::util::h256_limbs; +use crate::verifier::initial_memory_merkle_cap; /// The recursion threshold. We end a chain of recursive proofs once we reach /// this size. const THRESHOLD_DEGREE_BITS: usize = 13; +#[derive(Clone)] +pub struct ProverOutputData +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub is_dummy: bool, + pub proof_with_pis: ProofWithPublicInputs, + pub public_values: PublicValues, +} + /// Contains all recursive circuits used in the system. For each STARK and each /// initial `degree_bits`, this contains a chain of recursive circuits for /// shrinking that STARK from `degree_bits` to a constant @@ -69,10 +84,14 @@ where /// The EVM root circuit, which aggregates the (shrunk) per-table recursive /// proofs. pub root: RootCircuitData, - /// The aggregation circuit, which verifies two proofs that can either be - /// root or aggregation proofs. - pub aggregation: AggregationCircuitData, - /// The block circuit, which verifies an aggregation root proof and an + /// The segment aggregation circuit, which verifies that two segment proofs + /// that can either be root or aggregation proofs. + pub segment_aggregation: SegmentAggregationCircuitData, + /// The transaction aggregation circuit, which verifies the aggregation of + /// two proofs that can either be a segment aggregation representing a + /// transaction or an aggregation of transactions. + pub txn_aggregation: TxnAggregationCircuitData, + /// The block circuit, which verifies a transaction aggregation proof and an /// optional previous block proof. pub block: BlockCircuitData, /// The two-to-one block aggregation circuit, which verifies two unrelated @@ -155,11 +174,109 @@ where } } -/// Data for the aggregation circuit, which is used to compress two proofs into -/// one. Each inner proof can be either an EVM root proof or another aggregation -/// proof. +/// Data for the segment aggregation circuit, which is used to compress two +/// segment proofs into one. Each inner proof can be either an EVM root proof or +/// another segment aggregation proof. #[derive(Eq, PartialEq, Debug)] -pub struct AggregationCircuitData +pub struct SegmentAggregationCircuitData +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub circuit: CircuitData, + lhs: AggregationChildTarget, + rhs: AggregationChildWithDummyTarget, + public_values: PublicValuesTarget, + cyclic_vk: VerifierCircuitTarget, +} + +impl SegmentAggregationCircuitData +where + F: RichField + Extendable, + C: GenericConfig, +{ + fn to_buffer( + &self, + buffer: &mut Vec, + gate_serializer: &dyn GateSerializer, + generator_serializer: &dyn WitnessGeneratorSerializer, + ) -> IoResult<()> { + buffer.write_circuit_data(&self.circuit, gate_serializer, generator_serializer)?; + buffer.write_target_verifier_circuit(&self.cyclic_vk)?; + self.public_values.to_buffer(buffer)?; + self.lhs.to_buffer(buffer)?; + self.rhs.to_buffer(buffer)?; + Ok(()) + } + + fn from_buffer( + buffer: &mut Buffer, + gate_serializer: &dyn GateSerializer, + generator_serializer: &dyn WitnessGeneratorSerializer, + ) -> IoResult { + let circuit = buffer.read_circuit_data(gate_serializer, generator_serializer)?; + let cyclic_vk = buffer.read_target_verifier_circuit()?; + let public_values = PublicValuesTarget::from_buffer(buffer)?; + let lhs = AggregationChildTarget::from_buffer(buffer)?; + let rhs = AggregationChildWithDummyTarget::from_buffer(buffer)?; + Ok(Self { + circuit, + lhs, + rhs, + public_values, + cyclic_vk, + }) + } +} + +#[derive(Eq, PartialEq, Debug)] +struct AggregationChildWithDummyTarget { + is_agg: BoolTarget, + is_dummy: BoolTarget, + agg_proof: ProofWithPublicInputsTarget, + real_proof: ProofWithPublicInputsTarget, +} + +impl AggregationChildWithDummyTarget { + fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { + buffer.write_target_bool(self.is_agg)?; + buffer.write_target_bool(self.is_dummy)?; + buffer.write_target_proof_with_public_inputs(&self.agg_proof)?; + buffer.write_target_proof_with_public_inputs(&self.real_proof)?; + Ok(()) + } + + fn from_buffer(buffer: &mut Buffer) -> IoResult { + let is_agg = buffer.read_target_bool()?; + let is_dummy = buffer.read_target_bool()?; + let agg_proof = buffer.read_target_proof_with_public_inputs()?; + let real_proof = buffer.read_target_proof_with_public_inputs()?; + Ok(Self { + is_agg, + is_dummy, + agg_proof, + real_proof, + }) + } + + // `len_mem_cap` is the length of the Merkle + // caps for `MemBefore` and `MemAfter`. + fn public_values>( + &self, + builder: &mut CircuitBuilder, + ) -> PublicValuesTarget { + let agg_pv = PublicValuesTarget::from_public_inputs(&self.agg_proof.public_inputs); + let segment_pv = PublicValuesTarget::from_public_inputs(&self.real_proof.public_inputs); + + PublicValuesTarget::select(builder, self.is_agg, agg_pv, segment_pv) + } +} + +/// Data for the transaction aggregation circuit, which is used to compress two +/// proofs into one. Each inner proof can be either a segment aggregation proof +/// or another transaction aggregation proof. +#[derive(Eq, PartialEq, Debug)] +pub struct TxnAggregationCircuitData where F: RichField + Extendable, C: GenericConfig, @@ -171,7 +288,7 @@ where cyclic_vk: VerifierCircuitTarget, } -impl AggregationCircuitData +impl TxnAggregationCircuitData where F: RichField + Extendable, C: GenericConfig, @@ -394,7 +511,9 @@ where let mut buffer = Vec::with_capacity(1 << 34); self.root .to_buffer(&mut buffer, gate_serializer, generator_serializer)?; - self.aggregation + self.segment_aggregation + .to_buffer(&mut buffer, gate_serializer, generator_serializer)?; + self.txn_aggregation .to_buffer(&mut buffer, gate_serializer, generator_serializer)?; self.block .to_buffer(&mut buffer, gate_serializer, generator_serializer)?; @@ -430,7 +549,12 @@ where let mut buffer = Buffer::new(bytes); let root = RootCircuitData::from_buffer(&mut buffer, gate_serializer, generator_serializer)?; - let aggregation = AggregationCircuitData::from_buffer( + let segment_aggregation = SegmentAggregationCircuitData::from_buffer( + &mut buffer, + gate_serializer, + generator_serializer, + )?; + let txn_aggregation = TxnAggregationCircuitData::from_buffer( &mut buffer, gate_serializer, generator_serializer, @@ -475,7 +599,8 @@ where Ok(Self { root, - aggregation, + segment_aggregation, + txn_aggregation, block, two_to_one_block, by_table, @@ -505,6 +630,9 @@ where degree_bits_ranges: &[Range; NUM_TABLES], stark_config: &StarkConfig, ) -> Self { + // Sanity check on the provided config + assert_eq!(DEFAULT_CAP_LEN, 1 << stark_config.fri_config.cap_height); + let arithmetic = RecursiveCircuitsForTable::new( Table::Arithmetic, &all_stark.arithmetic_stark, @@ -554,6 +682,20 @@ where &all_stark.cross_table_lookups, stark_config, ); + let mem_before = RecursiveCircuitsForTable::new( + Table::MemBefore, + &all_stark.mem_before_stark, + degree_bits_ranges[Table::MemBefore as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let mem_after = RecursiveCircuitsForTable::new( + Table::MemAfter, + &all_stark.mem_after_stark, + degree_bits_ranges[Table::MemAfter as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); let by_table = [ arithmetic, @@ -563,15 +705,19 @@ where keccak_sponge, logic, memory, + mem_before, + mem_after, ]; - let root = Self::create_root_circuit(&by_table, stark_config); - let aggregation = Self::create_aggregation_circuit(&root); - let block = Self::create_block_circuit(&aggregation); + let root = Self::create_segment_circuit(&by_table, stark_config); + let segment_aggregation = Self::create_segment_aggregation_circuit(&root); + let txn_aggregation = + Self::create_txn_aggregation_circuit(&segment_aggregation, stark_config); + let block = Self::create_block_circuit(&txn_aggregation); let two_to_one_block = Self::create_two_to_one_block_circuit(&block); - Self { root, - aggregation, + segment_aggregation, + txn_aggregation, block, two_to_one_block, by_table, @@ -598,7 +744,7 @@ where self.block.circuit.verifier_data() } - fn create_root_circuit( + fn create_segment_circuit( by_table: &[RecursiveCircuitsForTable; NUM_TABLES], stark_config: &StarkConfig, ) -> RootCircuitData { @@ -713,6 +859,18 @@ where ); } + let merkle_before = + MemCapTarget::from_public_inputs(&recursive_proofs[*Table::MemBefore].public_inputs); + let merkle_after = + MemCapTarget::from_public_inputs(&recursive_proofs[*Table::MemAfter].public_inputs); + // Connect Memory before and after the execution with + // the public values. + MemCapTarget::connect( + &mut builder, + public_values.mem_before.clone(), + merkle_before, + ); + MemCapTarget::connect(&mut builder, public_values.mem_after.clone(), merkle_after); // We want EVM root proofs to have the exact same structure as aggregation // proofs, so we add public inputs for cyclic verification, even though // they'll be ignored. @@ -732,119 +890,257 @@ where } } - fn create_aggregation_circuit( + fn create_segment_aggregation_circuit( root: &RootCircuitData, - ) -> AggregationCircuitData { + ) -> SegmentAggregationCircuitData { let mut builder = CircuitBuilder::::new(root.circuit.common.config.clone()); let public_values = add_virtual_public_values(&mut builder); let cyclic_vk = builder.add_verifier_data_public_inputs(); - let lhs = Self::add_agg_child(&mut builder, &root.circuit); - let rhs = Self::add_agg_child(&mut builder, &root.circuit); + // The right hand side child might be dummy. + let lhs_segment = Self::add_segment_agg_child(&mut builder, root); + let rhs_segment = Self::add_segment_agg_child_with_dummy( + &mut builder, + root, + lhs_segment.base_proof.clone(), + ); + + let lhs_pv = lhs_segment.public_values(&mut builder); + let rhs_pv = rhs_segment.public_values(&mut builder); + + let is_dummy = rhs_segment.is_dummy; + let one = builder.one(); + let is_not_dummy = builder.sub(one, is_dummy.target); + let is_not_dummy = BoolTarget::new_unsafe(is_not_dummy); + + // Always connect the lhs to the aggregation public values. + TrieRootsTarget::connect( + &mut builder, + public_values.trie_roots_before, + lhs_pv.trie_roots_before, + ); + TrieRootsTarget::connect( + &mut builder, + public_values.trie_roots_after, + lhs_pv.trie_roots_after, + ); + BlockMetadataTarget::connect( + &mut builder, + public_values.block_metadata, + lhs_pv.block_metadata, + ); + BlockHashesTarget::connect( + &mut builder, + public_values.block_hashes, + lhs_pv.block_hashes, + ); + ExtraBlockDataTarget::connect( + &mut builder, + public_values.extra_block_data, + lhs_pv.extra_block_data, + ); + RegistersDataTarget::connect( + &mut builder, + public_values.registers_before.clone(), + lhs_pv.registers_before.clone(), + ); + MemCapTarget::connect( + &mut builder, + public_values.mem_before.clone(), + lhs_pv.mem_before.clone(), + ); + + // If the rhs is a real proof, all the block metadata must be the same for both + // segments. It is also the case for the extra block data. + TrieRootsTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.trie_roots_before, + rhs_pv.trie_roots_before, + ); + TrieRootsTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.trie_roots_after, + rhs_pv.trie_roots_after, + ); + BlockMetadataTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.block_metadata, + rhs_pv.block_metadata, + ); + BlockHashesTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.block_hashes, + rhs_pv.block_hashes, + ); + ExtraBlockDataTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.extra_block_data, + rhs_pv.extra_block_data, + ); + + // If the rhs is a real proof: Connect registers and merkle caps between + // segments. + RegistersDataTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.registers_after.clone(), + rhs_pv.registers_after.clone(), + ); + RegistersDataTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + lhs_pv.registers_after.clone(), + rhs_pv.registers_before.clone(), + ); + MemCapTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + public_values.mem_after.clone(), + rhs_pv.mem_after.clone(), + ); + MemCapTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, + lhs_pv.mem_after.clone(), + rhs_pv.mem_before.clone(), + ); + + // If the rhs is a dummy, then the lhs must be a segment. + let constr = builder.mul(is_dummy.target, lhs_segment.is_agg.target); + builder.assert_zero(constr); + + // If the rhs is a dummy, then the aggregation PVs are equal to the lhs PVs. + MemCapTarget::conditional_assert_eq( + &mut builder, + is_dummy, + public_values.mem_after.clone(), + lhs_pv.mem_after, + ); + RegistersDataTarget::conditional_assert_eq( + &mut builder, + is_dummy, + public_values.registers_after.clone(), + lhs_pv.registers_after, + ); + + // Pad to match the root circuit's degree. + while log2_ceil(builder.num_gates()) < root.circuit.common.degree_bits() { + builder.add_gate(NoopGate, vec![]); + } + + let circuit = builder.build::(); + SegmentAggregationCircuitData { + circuit, + lhs: lhs_segment, + rhs: rhs_segment, + public_values, + cyclic_vk, + } + } + + fn create_txn_aggregation_circuit( + agg: &SegmentAggregationCircuitData, + stark_config: &StarkConfig, + ) -> TxnAggregationCircuitData { + // Create a circuit for the aggregation of two transactions. + + let mut builder = CircuitBuilder::::new(agg.circuit.common.config.clone()); + let public_values = add_virtual_public_values(&mut builder); + let cyclic_vk = builder.add_verifier_data_public_inputs(); + + let lhs_txn_proof = Self::add_txn_agg_child(&mut builder, agg); + let rhs_txn_proof = Self::add_txn_agg_child(&mut builder, agg); + + let lhs_pv = lhs_txn_proof.public_values(&mut builder); + let rhs_pv = rhs_txn_proof.public_values(&mut builder); - let lhs_public_values = lhs.public_values(&mut builder); - let rhs_public_values = rhs.public_values(&mut builder); // Connect all block hash values BlockHashesTarget::connect( &mut builder, public_values.block_hashes, - lhs_public_values.block_hashes, + rhs_pv.block_hashes, ); BlockHashesTarget::connect( &mut builder, public_values.block_hashes, - rhs_public_values.block_hashes, + lhs_pv.block_hashes, ); // Connect all block metadata values. BlockMetadataTarget::connect( &mut builder, public_values.block_metadata, - lhs_public_values.block_metadata, + rhs_pv.block_metadata, ); BlockMetadataTarget::connect( &mut builder, public_values.block_metadata, - rhs_public_values.block_metadata, - ); - // Connect aggregation `trie_roots_before` with lhs `trie_roots_before`. - TrieRootsTarget::connect( - &mut builder, - public_values.trie_roots_before, - lhs_public_values.trie_roots_before, + lhs_pv.block_metadata, ); // Connect aggregation `trie_roots_after` with rhs `trie_roots_after`. TrieRootsTarget::connect( &mut builder, public_values.trie_roots_after, - rhs_public_values.trie_roots_after, + rhs_pv.trie_roots_after, ); // Connect lhs `trie_roots_after` with rhs `trie_roots_before`. TrieRootsTarget::connect( &mut builder, - lhs_public_values.trie_roots_after, - rhs_public_values.trie_roots_before, + lhs_pv.trie_roots_after, + rhs_pv.trie_roots_before, + ); + // Connect lhs `trie_roots_before` with public values `trie_roots_before`. + TrieRootsTarget::connect( + &mut builder, + public_values.trie_roots_before, + lhs_pv.trie_roots_before, ); - Self::connect_extra_public_values( &mut builder, &public_values.extra_block_data, - &lhs_public_values.extra_block_data, - &rhs_public_values.extra_block_data, + &lhs_pv.extra_block_data, + &rhs_pv.extra_block_data, ); - // Pad to match the root circuit's degree. - while log2_ceil(builder.num_gates()) < root.circuit.common.degree_bits() { + // We check the registers before and after for the current aggregation. + RegistersDataTarget::connect( + &mut builder, + public_values.registers_after.clone(), + rhs_pv.registers_after.clone(), + ); + + RegistersDataTarget::connect( + &mut builder, + public_values.registers_before.clone(), + lhs_pv.registers_before.clone(), + ); + + // Check the initial and final register values. + Self::connect_initial_final_segment(&mut builder, &rhs_pv); + Self::connect_initial_final_segment(&mut builder, &lhs_pv); + + // Check the initial `MemBefore` `MerkleCap` value. + Self::check_init_merkle_cap(&mut builder, &rhs_pv, stark_config); + Self::check_init_merkle_cap(&mut builder, &lhs_pv, stark_config); + + while log2_ceil(builder.num_gates()) < agg.circuit.common.degree_bits() { builder.add_gate(NoopGate, vec![]); } let circuit = builder.build::(); - AggregationCircuitData { + TxnAggregationCircuitData { circuit, - lhs, - rhs, + lhs: lhs_txn_proof, + rhs: rhs_txn_proof, public_values, cyclic_vk, } } - fn connect_extra_public_values( - builder: &mut CircuitBuilder, - pvs: &ExtraBlockDataTarget, - lhs: &ExtraBlockDataTarget, - rhs: &ExtraBlockDataTarget, - ) { - // Connect checkpoint state root values. - for (&limb0, &limb1) in pvs - .checkpoint_state_trie_root - .iter() - .zip(&rhs.checkpoint_state_trie_root) - { - builder.connect(limb0, limb1); - } - for (&limb0, &limb1) in pvs - .checkpoint_state_trie_root - .iter() - .zip(&lhs.checkpoint_state_trie_root) - { - builder.connect(limb0, limb1); - } - - // Connect the transaction number in public values to the lhs and rhs values - // correctly. - builder.connect(pvs.txn_number_before, lhs.txn_number_before); - builder.connect(pvs.txn_number_after, rhs.txn_number_after); - - // Connect lhs `txn_number_after` with rhs `txn_number_before`. - builder.connect(lhs.txn_number_after, rhs.txn_number_before); - - // Connect the gas used in public values to the lhs and rhs values correctly. - builder.connect(pvs.gas_used_before, lhs.gas_used_before); - builder.connect(pvs.gas_used_after, rhs.gas_used_after); - - // Connect lhs `gas_used_after` with rhs `gas_used_before`. - builder.connect(lhs.gas_used_after, rhs.gas_used_before); - } - /// Extend a circuit to verify one of two proofs. /// /// # Arguments @@ -880,10 +1176,53 @@ where } } - fn create_block_circuit(agg: &AggregationCircuitData) -> BlockCircuitData { + fn check_init_merkle_cap( + builder: &mut CircuitBuilder, + x: &PublicValuesTarget, + stark_config: &StarkConfig, + ) where + F: RichField + Extendable, + { + let cap = initial_memory_merkle_cap::( + stark_config.fri_config.rate_bits, + stark_config.fri_config.cap_height, + ); + + let init_cap_target = MemCapTarget { + mem_cap: MerkleCapTarget( + cap.0 + .iter() + .map(|&h| builder.constant_hash(h)) + .collect::>(), + ), + }; + + MemCapTarget::connect(builder, x.mem_before.clone(), init_cap_target); + } + + fn connect_initial_final_segment(builder: &mut CircuitBuilder, x: &PublicValuesTarget) + where + F: RichField + Extendable, + { + builder.assert_zero(x.registers_before.stack_len); + builder.assert_zero(x.registers_after.stack_len); + builder.assert_zero(x.registers_before.context); + builder.assert_zero(x.registers_after.context); + builder.assert_zero(x.registers_before.gas_used); + builder.assert_one(x.registers_before.is_kernel); + builder.assert_one(x.registers_after.is_kernel); + + let halt_label = builder.constant(F::from_canonical_usize(KERNEL.global_labels["halt"])); + builder.connect(x.registers_after.program_counter, halt_label); + + let main_label = builder.constant(F::from_canonical_usize(KERNEL.global_labels["main"])); + builder.connect(x.registers_before.program_counter, main_label); + } + + fn create_block_circuit(agg: &TxnAggregationCircuitData) -> BlockCircuitData { + // Here, we have two block proofs and we aggregate them together. // The block circuit is similar to the agg circuit; both verify two inner - // proofs. We need to adjust a few things, but it's easier than making a - // new CommonCircuitData. + // proofs. let expected_common_data = CommonCircuitData { fri_params: FriParams { degree_bits: 14, @@ -959,6 +1298,125 @@ where } } + fn connect_extra_public_values( + builder: &mut CircuitBuilder, + pvs: &ExtraBlockDataTarget, + lhs: &ExtraBlockDataTarget, + rhs: &ExtraBlockDataTarget, + ) { + // Connect checkpoint state root values. + for (&limb0, &limb1) in pvs + .checkpoint_state_trie_root + .iter() + .zip(&rhs.checkpoint_state_trie_root) + { + builder.connect(limb0, limb1); + } + for (&limb0, &limb1) in pvs + .checkpoint_state_trie_root + .iter() + .zip(&lhs.checkpoint_state_trie_root) + { + builder.connect(limb0, limb1); + } + + // Connect the transaction number in public values to the lhs and rhs values + // correctly. + builder.connect(pvs.txn_number_before, lhs.txn_number_before); + builder.connect(pvs.txn_number_after, rhs.txn_number_after); + + // Connect lhs `txn_number_after` with rhs `txn_number_before`. + builder.connect(lhs.txn_number_after, rhs.txn_number_before); + + // Connect the gas used in public values to the lhs and rhs values correctly. + builder.connect(pvs.gas_used_before, lhs.gas_used_before); + builder.connect(pvs.gas_used_after, rhs.gas_used_after); + + // Connect lhs `gas_used_after` with rhs `gas_used_before`. + builder.connect(lhs.gas_used_after, rhs.gas_used_before); + } + + fn add_segment_agg_child( + builder: &mut CircuitBuilder, + root: &RootCircuitData, + ) -> AggregationChildTarget { + let common = &root.circuit.common; + let root_vk = builder.constant_verifier_data(&root.circuit.verifier_only); + let is_agg = builder.add_virtual_bool_target_safe(); + let agg_proof = builder.add_virtual_proof_with_pis(common); + let base_proof = builder.add_virtual_proof_with_pis(common); + builder + .conditionally_verify_cyclic_proof::( + is_agg, + &agg_proof, + &base_proof, + &root_vk, + common, + ) + .expect("Failed to build cyclic recursion circuit"); + AggregationChildTarget { + is_agg, + agg_proof, + base_proof, + } + } + + fn add_segment_agg_child_with_dummy( + builder: &mut CircuitBuilder, + root: &RootCircuitData, + dummy_proof: ProofWithPublicInputsTarget, + ) -> AggregationChildWithDummyTarget { + let common = &root.circuit.common; + let root_vk = builder.constant_verifier_data(&root.circuit.verifier_only); + let is_agg = builder.add_virtual_bool_target_safe(); + let agg_proof = builder.add_virtual_proof_with_pis(common); + let is_dummy = builder.add_virtual_bool_target_safe(); + let real_proof = builder.add_virtual_proof_with_pis(common); + + let segment_proof = builder.select_proof_with_pis(is_dummy, &dummy_proof, &real_proof); + builder + .conditionally_verify_cyclic_proof::( + is_agg, + &agg_proof, + &segment_proof, + &root_vk, + common, + ) + .expect("Failed to build cyclic recursion circuit"); + AggregationChildWithDummyTarget { + is_agg, + is_dummy, + agg_proof, + real_proof, + } + } + + fn add_txn_agg_child( + builder: &mut CircuitBuilder, + segment_agg: &SegmentAggregationCircuitData, + ) -> AggregationChildTarget { + let common = &segment_agg.circuit.common; + let inner_segment_agg_vk = + builder.constant_verifier_data(&segment_agg.circuit.verifier_only); + let is_agg = builder.add_virtual_bool_target_safe(); + let agg_proof = builder.add_virtual_proof_with_pis(common); + let base_proof = builder.add_virtual_proof_with_pis(common); + builder + .conditionally_verify_cyclic_proof::( + is_agg, + &agg_proof, + &base_proof, + &inner_segment_agg_vk, + common, + ) + .expect("Failed to build cyclic recursion circuit"); + AggregationChildTarget { + is_agg, + agg_proof, + base_proof, + } + } + /// Create two-to-one block aggregation circuit. /// /// # Arguments @@ -1185,18 +1643,20 @@ where /// for a verifier to assert correctness of the computation, /// but the public values are output for the prover convenience, as these /// are necessary during proof aggregation. - pub fn prove_root( + pub fn prove_segment( &self, all_stark: &AllStark, config: &StarkConfig, - generation_inputs: GenerationInputs, + generation_inputs: TrimmedGenerationInputs, + segment_data: &mut GenerationSegmentData, timing: &mut TimingTree, abort_signal: Option>, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + ) -> anyhow::Result> { let all_proof = prove::( all_stark, config, generation_inputs, + segment_data, timing, abort_signal.clone(), )?; @@ -1233,7 +1693,7 @@ where root_inputs.set_verifier_data_target( &self.root.cyclic_vk, - &self.aggregation.circuit.verifier_only, + &self.segment_aggregation.circuit.verifier_only, ); set_public_value_targets( @@ -1247,7 +1707,51 @@ where let root_proof = self.root.circuit.prove(root_inputs)?; - Ok((root_proof, all_proof.public_values)) + Ok(ProverOutputData { + is_dummy: false, + proof_with_pis: root_proof, + public_values: all_proof.public_values, + }) + } + + /// Returns a proof for each segment that is part of a full transaction + /// proof. + pub fn prove_all_segments( + &self, + all_stark: &AllStark, + config: &StarkConfig, + generation_inputs: GenerationInputs, + max_cpu_len_log: usize, + timing: &mut TimingTree, + abort_signal: Option>, + ) -> anyhow::Result>> { + let segment_iterator = + SegmentDataIterator::::new(&generation_inputs, Some(max_cpu_len_log)); + + let mut proofs = vec![]; + + for segment_run in segment_iterator { + let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?; + let proof = self.prove_segment( + all_stark, + config, + generation_inputs.trim(), + &mut next_data, + timing, + abort_signal.clone(), + )?; + proofs.push(proof); + } + + // Since aggregations require at least two segment proofs, add a dummy proof if + // there is only one proof. + if proofs.len() == 1 { + let mut first_proof = proofs[0].clone(); + first_proof.is_dummy = true; + proofs.push(first_proof); + } + + Ok(proofs) } /// From an initial set of STARK proofs passed with their associated @@ -1291,7 +1795,7 @@ where /// let table_circuits = { ... }; /// /// // Finally shrink the STARK proof. - /// let (proof, public_values) = prove_root_after_initial_stark( + /// let (proof, public_values) = prove_segment_after_initial_stark( /// &all_stark, /// &config, /// &stark_proof, @@ -1300,7 +1804,7 @@ where /// abort_signal, /// ).unwrap(); /// ``` - pub fn prove_root_after_initial_stark( + pub fn prove_segment_after_initial_stark( &self, all_proof: AllProof, table_circuits: &[(RecursiveCircuitsForTableSize, u8); NUM_TABLES], @@ -1326,7 +1830,7 @@ where root_inputs.set_verifier_data_target( &self.root.cyclic_vk, - &self.aggregation.circuit.verifier_only, + &self.segment_aggregation.circuit.verifier_only, ); set_public_value_targets( @@ -1348,101 +1852,242 @@ where } /// Create an aggregation proof, combining two contiguous proofs into a - /// single one. The combined proofs can either be transaction (aka root) - /// proofs, or other aggregation proofs, as long as their states are - /// contiguous, meaning that the final state of the left child proof is the - /// initial state of the right child proof. + /// single one. The combined proofs are segment proofs: they are proofs + /// of some parts of one execution. /// - /// While regular transaction proofs can only assert validity of a single - /// transaction, aggregation proofs can cover an arbitrary range, up to - /// an entire block with all its transactions. + /// While regular root proofs can only assert validity of a + /// single segment of a transaction, segment aggregation proofs + /// can cover an arbitrary range, up to an entire transaction. /// /// # Arguments /// /// - `lhs_is_agg`: a boolean indicating whether the left child proof is an - /// aggregation proof or a regular transaction proof. - /// - `lhs_proof`: the left child proof. - /// - `lhs_public_values`: the public values associated to the right child - /// proof. + /// aggregation proof or a regular segment proof. + /// - `lhs_proof`: the left child prover output data. /// - `rhs_is_agg`: a boolean indicating whether the right child proof is an /// aggregation proof or a regular transaction proof. - /// - `rhs_proof`: the right child proof. - /// - `rhs_public_values`: the public values associated to the right child - /// proof. + /// - `rhs_proof`: the right child prover output data. /// /// # Outputs /// - /// This method outputs a tuple of [`ProofWithPublicInputs`] and - /// its [`PublicValues`]. Only the proof with public inputs is necessary - /// for a verifier to assert correctness of the computation, - /// but the public values are output for the prover convenience, as these - /// are necessary during proof aggregation. - pub fn prove_aggregation( + /// This method outputs a [`ProverOutputData`]. Only the proof with + /// public inputs is necessary for a verifier to assert correctness of + /// the computation, but the public values and `is_dummy` are output for the + /// prover convenience, as these are necessary during proof aggregation. + pub fn prove_segment_aggregation( &self, lhs_is_agg: bool, - lhs_proof: &ProofWithPublicInputs, - lhs_public_values: PublicValues, + lhs_prover_output: &ProverOutputData, rhs_is_agg: bool, - rhs_proof: &ProofWithPublicInputs, - rhs_public_values: PublicValues, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + rhs_prover_output: &ProverOutputData, + ) -> anyhow::Result> { let mut agg_inputs = PartialWitness::new(); - agg_inputs.set_bool_target(self.aggregation.lhs.is_agg, lhs_is_agg); - agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.agg_proof, lhs_proof); - agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.base_proof, lhs_proof); + let lhs_proof = &lhs_prover_output.proof_with_pis; + let rhs_proof = &rhs_prover_output.proof_with_pis; + let rhs_is_dummy = rhs_prover_output.is_dummy; + Self::set_dummy_if_necessary( + &self.segment_aggregation.lhs, + lhs_is_agg, + &self.segment_aggregation.circuit, + &mut agg_inputs, + lhs_proof, + ); - agg_inputs.set_bool_target(self.aggregation.rhs.is_agg, rhs_is_agg); - agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.agg_proof, rhs_proof); - agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.base_proof, rhs_proof); + // If rhs is dummy, the rhs proof is also set to be the lhs. + let real_rhs_proof = if rhs_is_dummy { lhs_proof } else { rhs_proof }; + + Self::set_dummy_if_necessary_with_dummy( + &self.segment_aggregation.rhs, + rhs_is_agg, + rhs_is_dummy, + &self.segment_aggregation.circuit, + &mut agg_inputs, + real_rhs_proof, + ); agg_inputs.set_verifier_data_target( - &self.aggregation.cyclic_vk, - &self.aggregation.circuit.verifier_only, + &self.segment_aggregation.cyclic_vk, + &self.segment_aggregation.circuit.verifier_only, ); // Aggregates both `PublicValues` from the provided proofs into a single one. + + let lhs_public_values = &lhs_prover_output.public_values; + let rhs_public_values = &rhs_prover_output.public_values; + + let real_public_values = if rhs_is_dummy { + lhs_public_values.clone() + } else { + rhs_public_values.clone() + }; + let agg_public_values = PublicValues { - trie_roots_before: lhs_public_values.trie_roots_before, - trie_roots_after: rhs_public_values.trie_roots_after, + trie_roots_before: lhs_public_values.trie_roots_before.clone(), + trie_roots_after: real_public_values.trie_roots_after, extra_block_data: ExtraBlockData { checkpoint_state_trie_root: lhs_public_values .extra_block_data .checkpoint_state_trie_root, txn_number_before: lhs_public_values.extra_block_data.txn_number_before, - txn_number_after: rhs_public_values.extra_block_data.txn_number_after, + txn_number_after: real_public_values.extra_block_data.txn_number_after, gas_used_before: lhs_public_values.extra_block_data.gas_used_before, - gas_used_after: rhs_public_values.extra_block_data.gas_used_after, + gas_used_after: real_public_values.extra_block_data.gas_used_after, }, - block_metadata: rhs_public_values.block_metadata, - block_hashes: rhs_public_values.block_hashes, + block_metadata: real_public_values.block_metadata, + block_hashes: real_public_values.block_hashes, + registers_before: lhs_public_values.registers_before.clone(), + registers_after: real_public_values.registers_after, + mem_before: lhs_public_values.mem_before.clone(), + mem_after: real_public_values.mem_after, }; set_public_value_targets( &mut agg_inputs, - &self.aggregation.public_values, + &self.segment_aggregation.public_values, &agg_public_values, ) .map_err(|_| { anyhow::Error::msg("Invalid conversion when setting public values targets.") })?; - let aggregation_proof = self.aggregation.circuit.prove(agg_inputs)?; - Ok((aggregation_proof, agg_public_values)) + let aggregation_proof = self.segment_aggregation.circuit.prove(agg_inputs)?; + let agg_output = ProverOutputData { + is_dummy: false, + proof_with_pis: aggregation_proof, + public_values: agg_public_values, + }; + Ok(agg_output) } - pub fn verify_aggregation( + pub fn verify_segment_aggregation( &self, agg_proof: &ProofWithPublicInputs, ) -> anyhow::Result<()> { - self.aggregation.circuit.verify(agg_proof.clone())?; + self.segment_aggregation.circuit.verify(agg_proof.clone())?; check_cyclic_proof_verifier_data( agg_proof, - &self.aggregation.circuit.verifier_only, - &self.aggregation.circuit.common, + &self.segment_aggregation.circuit.verifier_only, + &self.segment_aggregation.circuit.common, ) } + /// Creates a final transaction proof, once all segments of a given + /// transaction have been combined into a single aggregation proof. + /// + /// Transaction proofs can either be generated as a standalone, or combined + /// with a previous transaction proof to assert validity of a range of + /// transactions. + /// + /// # Arguments + /// + /// - `opt_parent_txn_proof`: an optional parent transaction proof. Passing + /// one will generate a proof of validity for both the transaction range + /// covered by the previous proof and the current transaction. + /// - `agg_proof`: the final aggregation proof containing all segments + /// within the current transaction. + /// - `public_values`: the public values associated to the aggregation + /// proof. + /// + /// # Outputs + /// + /// This method outputs a tuple of [`ProofWithPublicInputs`] and + /// its [`PublicValues`]. Only the proof with public inputs is necessary + /// for a verifier to assert correctness of the computation. + pub fn prove_transaction_aggregation( + &self, + lhs_is_agg: bool, + lhs_proof: &ProofWithPublicInputs, + lhs_public_values: PublicValues, + rhs_is_agg: bool, + rhs_proof: &ProofWithPublicInputs, + rhs_public_values: PublicValues, + ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + let mut txn_inputs = PartialWitness::new(); + + Self::set_dummy_if_necessary( + &self.txn_aggregation.lhs, + lhs_is_agg, + &self.txn_aggregation.circuit, + &mut txn_inputs, + lhs_proof, + ); + + Self::set_dummy_if_necessary( + &self.txn_aggregation.rhs, + rhs_is_agg, + &self.txn_aggregation.circuit, + &mut txn_inputs, + rhs_proof, + ); + + txn_inputs.set_verifier_data_target( + &self.txn_aggregation.cyclic_vk, + &self.txn_aggregation.circuit.verifier_only, + ); + + let txn_public_values = PublicValues { + trie_roots_before: lhs_public_values.trie_roots_before, + extra_block_data: ExtraBlockData { + txn_number_before: lhs_public_values.extra_block_data.txn_number_before, + gas_used_before: lhs_public_values.extra_block_data.gas_used_before, + ..rhs_public_values.extra_block_data + }, + ..rhs_public_values + }; + + set_public_value_targets( + &mut txn_inputs, + &self.txn_aggregation.public_values, + &txn_public_values, + ) + .map_err(|_| { + anyhow::Error::msg("Invalid conversion when setting public values targets.") + })?; + + let txn_proof = self.txn_aggregation.circuit.prove(txn_inputs)?; + Ok((txn_proof, txn_public_values)) + } + + pub fn verify_txn_aggregation( + &self, + txn_proof: &ProofWithPublicInputs, + ) -> anyhow::Result<()> { + self.txn_aggregation.circuit.verify(txn_proof.clone())?; + check_cyclic_proof_verifier_data( + txn_proof, + &self.txn_aggregation.circuit.verifier_only, + &self.txn_aggregation.circuit.common, + ) + } + + /// If the proof is not an aggregation, we set the cyclic vk to a dummy + /// value, so that it corresponds to the aggregation cyclic vk. If the proof + /// is dummy, we set `is_dummy` to `true`. Note that only the rhs can be + /// dummy. + fn set_dummy_if_necessary_with_dummy( + agg_child: &AggregationChildWithDummyTarget, + is_agg: bool, + is_dummy: bool, + circuit: &CircuitData, + agg_inputs: &mut PartialWitness, + proof: &ProofWithPublicInputs, + ) { + agg_inputs.set_bool_target(agg_child.is_agg, is_agg); + agg_inputs.set_bool_target(agg_child.is_dummy, is_dummy); + if is_agg { + agg_inputs.set_proof_with_pis_target(&agg_child.agg_proof, proof); + } else { + Self::set_dummy_proof_with_cyclic_vk_pis( + circuit, + agg_inputs, + &agg_child.agg_proof, + proof, + ); + } + agg_inputs.set_proof_with_pis_target(&agg_child.real_proof, proof); + } + /// Create a final block proof, once all transactions of a given block have /// been combined into a single aggregation proof. /// @@ -1469,7 +2114,7 @@ where opt_parent_block_proof: Option<&ProofWithPublicInputs>, agg_root_proof: &ProofWithPublicInputs, public_values: PublicValues, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + ) -> anyhow::Result<(ProofWithPublicInputs, FinalPublicValues)> { let mut block_inputs = PartialWitness::new(); block_inputs.set_bool_target( @@ -1495,21 +2140,19 @@ where let mut nonzero_pis = HashMap::new(); // Initialize the checkpoint block roots before, and state root after. - let state_trie_root_before_keys = 0..TrieRootsTarget::HASH_SIZE; + let state_trie_root_before_keys = 0..TARGET_HASH_SIZE; for (key, &value) in state_trie_root_before_keys .zip_eq(&h256_limbs::(public_values.trie_roots_before.state_root)) { nonzero_pis.insert(key, value); } - let txn_trie_root_before_keys = - TrieRootsTarget::HASH_SIZE..TrieRootsTarget::HASH_SIZE * 2; + let txn_trie_root_before_keys = TARGET_HASH_SIZE..TARGET_HASH_SIZE * 2; for (key, &value) in txn_trie_root_before_keys.clone().zip_eq(&h256_limbs::( public_values.trie_roots_before.transactions_root, )) { nonzero_pis.insert(key, value); } - let receipts_trie_root_before_keys = - TrieRootsTarget::HASH_SIZE * 2..TrieRootsTarget::HASH_SIZE * 3; + let receipts_trie_root_before_keys = TARGET_HASH_SIZE * 2..TARGET_HASH_SIZE * 3; for (key, &value) in receipts_trie_root_before_keys .clone() .zip_eq(&h256_limbs::( @@ -1519,7 +2162,7 @@ where nonzero_pis.insert(key, value); } let state_trie_root_after_keys = - TrieRootsTarget::SIZE..TrieRootsTarget::SIZE + TrieRootsTarget::HASH_SIZE; + TrieRootsTarget::SIZE..TrieRootsTarget::SIZE + TARGET_HASH_SIZE; for (key, &value) in state_trie_root_after_keys .zip_eq(&h256_limbs::(public_values.trie_roots_before.state_root)) { @@ -1601,7 +2244,7 @@ where })?; let block_proof = self.block.circuit.prove(block_inputs)?; - Ok((block_proof, block_public_values)) + Ok((block_proof, block_public_values.into())) } pub fn verify_block(&self, block_proof: &ProofWithPublicInputs) -> anyhow::Result<()> { diff --git a/evm_arithmetization/src/generation/linked_list.rs b/evm_arithmetization/src/generation/linked_list.rs new file mode 100644 index 000000000..a89e49657 --- /dev/null +++ b/evm_arithmetization/src/generation/linked_list.rs @@ -0,0 +1,87 @@ +use std::fmt; + +use anyhow::Result; +use ethereum_types::U256; + +use crate::memory::segments::Segment; +use crate::util::u256_to_usize; +use crate::witness::errors::ProgramError; +use crate::witness::errors::ProverInputError::InvalidInput; + +// A linked list implemented using a vector `access_list_mem`. +// In this representation, the values of nodes are stored in the range +// `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size - +// 1]` holds the address of the next node, where i = node_size * j. +#[derive(Clone)] +pub(crate) struct LinkedList<'a, const N: usize> { + mem: &'a [Option], + offset: usize, + pos: usize, +} + +pub(crate) fn empty_list_mem(segment: Segment) -> [Option; N] { + std::array::from_fn(|i| { + if i == 0 { + Some(U256::MAX) + } else if i == N - 1 { + Some((segment as usize).into()) + } else { + Some(U256::zero()) + } + }) +} + +impl<'a, const N: usize> LinkedList<'a, N> { + pub fn from_mem_and_segment( + mem: &'a [Option], + segment: Segment, + ) -> Result { + Self::from_mem_len_and_segment(mem, segment) + } + + pub fn from_mem_len_and_segment( + mem: &'a [Option], + segment: Segment, + ) -> Result { + if mem.is_empty() { + return Err(ProgramError::ProverInputError(InvalidInput)); + } + Ok(Self { + mem, + offset: segment as usize, + pos: 0, + }) + } +} + +impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Linked List {{")?; + let cloned_list = self.clone(); + for node in cloned_list { + if node[0] == U256::MAX { + writeln!(f, "{:?}", node)?; + break; + } + writeln!(f, "{:?} ->", node)?; + } + write!(f, "}}") + } +} + +impl<'a, const N: usize> Iterator for LinkedList<'a, N> { + type Item = [U256; N]; + + fn next(&mut self) -> Option { + // The first node is always the special node, so we skip it in the first + // iteration. + if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) { + self.pos = new_pos - self.offset; + Some(std::array::from_fn(|i| { + self.mem[self.pos + i].unwrap_or_default() + })) + } else { + None + } + } +} diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 5940e8be3..161ceda4c 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::anyhow; use ethereum_types::{Address, BigEndianHash, H256, U256}; +use keccak_hash::keccak; use log::log_enabled; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; use plonky2::field::extension::Extendable; @@ -21,22 +22,35 @@ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::generation::state::GenerationState; +use crate::generation::state::{GenerationState, State}; use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie}; -use crate::memory::segments::Segment; -use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; +use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES}; +use crate::proof::{ + BlockHashes, BlockMetadata, ExtraBlockData, MemCap, PublicValues, RegistersData, TrieRoots, +}; +use crate::prover::GenerationSegmentData; use crate::util::{h2u, u256_to_usize}; -use crate::witness::memory::{MemoryAddress, MemoryChannel}; +use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryState}; +use crate::witness::state::RegistersState; +pub(crate) mod linked_list; pub mod mpt; pub(crate) mod prover_input; pub(crate) mod rlp; pub(crate) mod state; -mod trie_extractor; +pub(crate) mod trie_extractor; -use self::state::State; use crate::witness::util::mem_write_log; +/// Number of cycles to go after having reached the halting state. It is +/// equal to the number of cycles in `exc_stop` + 1. +pub const NUM_EXTRA_CYCLES_AFTER: usize = 81; +/// Number of cycles to go before starting the execution: it is the number of +/// cycles in `init`. +pub const NUM_EXTRA_CYCLES_BEFORE: usize = 64; +/// Memory values used to initialize `MemBefore`. +pub type MemBeforeValues = Vec<(MemoryAddress, U256)>; + /// Inputs needed for trace generation. #[derive(Clone, Debug, Deserialize, Serialize, Default)] pub struct GenerationInputs { @@ -52,7 +66,7 @@ pub struct GenerationInputs { /// A None would yield an empty proof, otherwise this contains the encoding /// of a transaction. - pub signed_txn: Option>, + pub signed_txns: Vec>, /// Withdrawal pairs `(addr, amount)`. At the end of the txs, `amount` is /// added to `addr`'s balance. See EIP-4895. pub withdrawals: Vec<(Address, U256)>, @@ -80,6 +94,48 @@ pub struct GenerationInputs { pub block_hashes: BlockHashes, } +/// A lighter version of [`GenerationInputs`], which have been trimmed +/// post pre-initialization processing. +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct TrimmedGenerationInputs { + pub trimmed_tries: TrimmedTrieInputs, + /// The index of the first transaction in this payload being proven within + /// its block. + pub txn_number_before: U256, + /// The cumulative gas used through the execution of all transactions prior + /// the current ones. + pub gas_used_before: U256, + /// The cumulative gas used after the execution of the current batch of + /// transactions. The exact gas used by the current batch of transactions + /// is `gas_used_after` - `gas_used_before`. + pub gas_used_after: U256, + + /// The list of txn hashes contained in this batch. + pub txn_hashes: Vec, + + /// Expected trie roots before these transactions are executed. + pub trie_roots_before: TrieRoots, + /// Expected trie roots after these transactions are executed. + pub trie_roots_after: TrieRoots, + + /// State trie root of the checkpoint block. + /// This could always be the genesis block of the chain, but it allows a + /// prover to continue proving blocks from certain checkpoint heights + /// without requiring proofs for blocks past this checkpoint. + pub checkpoint_state_trie_root: H256, + + /// Mapping between smart contract code hashes and the contract byte code. + /// All account smart contracts that are invoked will have an entry present. + pub contract_code: HashMap>, + + /// Information contained in the block header. + pub block_metadata: BlockMetadata, + + /// The hash of the current block, and a list of the 256 previous block + /// hashes. + pub block_hashes: BlockHashes, +} + #[derive(Clone, Debug, Deserialize, Serialize, Default)] pub struct TrieInputs { /// A partial version of the state trie prior to these transactions. It @@ -103,12 +159,64 @@ pub struct TrieInputs { pub storage_tries: Vec<(H256, HashedPartialTrie)>, } +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct TrimmedTrieInputs { + /// A partial version of the state trie prior to these transactions. It + /// should include all nodes that will be accessed by these + /// transactions. + pub state_trie: HashedPartialTrie, + /// A partial version of each storage trie prior to these transactions. It + /// should include all storage tries, and nodes therein, that will be + /// accessed by these transactions. + pub storage_tries: Vec<(H256, HashedPartialTrie)>, +} + +impl TrieInputs { + pub(crate) fn trim(&self) -> TrimmedTrieInputs { + TrimmedTrieInputs { + state_trie: self.state_trie.clone(), + storage_tries: self.storage_tries.clone(), + } + } +} +impl GenerationInputs { + /// Outputs a trimmed version of the `GenerationInputs`, that do not contain + /// the fields that have already been processed during pre-initialization, + /// namely: the input tries, the signed transaction, and the withdrawals. + pub(crate) fn trim(&self) -> TrimmedGenerationInputs { + let txn_hashes = self + .signed_txns + .iter() + .map(|tx_bytes| keccak(&tx_bytes[..])) + .collect(); + + TrimmedGenerationInputs { + trimmed_tries: self.tries.trim(), + txn_number_before: self.txn_number_before, + gas_used_before: self.gas_used_before, + gas_used_after: self.gas_used_after, + txn_hashes, + trie_roots_before: TrieRoots { + state_root: self.tries.state_trie.hash(), + transactions_root: self.tries.transactions_trie.hash(), + receipts_root: self.tries.receipts_trie.hash(), + }, + trie_roots_after: self.trie_roots_after.clone(), + checkpoint_state_trie_root: self.checkpoint_state_trie_root, + contract_code: self.contract_code.clone(), + block_metadata: self.block_metadata.clone(), + block_hashes: self.block_hashes.clone(), + } + } +} + fn apply_metadata_and_tries_memops, const D: usize>( state: &mut GenerationState, - inputs: &GenerationInputs, + inputs: &TrimmedGenerationInputs, + registers_before: &RegistersData, + registers_after: &RegistersData, ) { let metadata = &inputs.block_metadata; - let tries = &inputs.tries; let trie_roots_after = &inputs.trie_roots_after; let fields = [ ( @@ -147,19 +255,19 @@ fn apply_metadata_and_tries_memops, const D: usize> (GlobalMetadata::TxnNumberBefore, inputs.txn_number_before), ( GlobalMetadata::TxnNumberAfter, - inputs.txn_number_before + if inputs.signed_txn.is_some() { 1 } else { 0 }, + inputs.txn_number_before + inputs.txn_hashes.len(), ), ( GlobalMetadata::StateTrieRootDigestBefore, - h2u(tries.state_trie.hash()), + h2u(inputs.trie_roots_before.state_root), ), ( GlobalMetadata::TransactionTrieRootDigestBefore, - h2u(tries.transactions_trie.hash()), + h2u(inputs.trie_roots_before.transactions_root), ), ( GlobalMetadata::ReceiptTrieRootDigestBefore, - h2u(tries.receipts_trie.hash()), + h2u(inputs.trie_roots_before.receipts_root), ), ( GlobalMetadata::StateTrieRootDigestAfter, @@ -214,12 +322,50 @@ fn apply_metadata_and_tries_memops, const D: usize> .collect::>(), ); + // Write initial registers. + let registers_before = [ + registers_before.program_counter, + registers_before.is_kernel, + registers_before.stack_len, + registers_before.stack_top, + registers_before.context, + registers_before.gas_used, + ]; + ops.extend((0..registers_before.len()).map(|i| { + mem_write_log( + channel, + MemoryAddress::new(0, Segment::RegistersStates, i), + state, + registers_before[i], + ) + })); + + let length = registers_before.len(); + + // Write final registers. + let registers_after = [ + registers_after.program_counter, + registers_after.is_kernel, + registers_after.stack_len, + registers_after.stack_top, + registers_after.context, + registers_after.gas_used, + ]; + ops.extend((0..registers_before.len()).map(|i| { + mem_write_log( + channel, + MemoryAddress::new(0, Segment::RegistersStates, length + i), + state, + registers_after[i], + ) + })); + state.memory.apply_ops(&ops); state.traces.memory_ops.extend(ops); } pub(crate) fn debug_inputs(inputs: &GenerationInputs) { - log::debug!("Input signed_txn: {:?}", &inputs.signed_txn); + log::debug!("Input signed_txns: {:?}", &inputs.signed_txns); log::debug!("Input state_trie: {:?}", &inputs.tries.state_trie); log::debug!( "Input transactions_trie: {:?}", @@ -230,29 +376,93 @@ pub(crate) fn debug_inputs(inputs: &GenerationInputs) { log::debug!("Input contract_code: {:?}", &inputs.contract_code); } +fn initialize_kernel_code_and_shift_table(memory: &mut MemoryState) { + let mut code_addr = MemoryAddress::new(0, Segment::Code, 0); + for &byte in &KERNEL.code { + memory.set(code_addr, U256::from(byte)); + code_addr.increment(); + } + + let mut shift_addr = MemoryAddress::new(0, Segment::ShiftTable, 0); + let mut shift_val = U256::one(); + for _ in 0..256 { + memory.set(shift_addr, shift_val); + shift_addr.increment(); + shift_val <<= 1; + } +} + +/// Returns the memory addresses and values that should comprise the state at +/// the start of the segment's execution. +/// Ignores zero values in non-preinitialized segments. +fn get_all_memory_address_and_values(memory_before: &MemoryState) -> Vec<(MemoryAddress, U256)> { + let mut res = vec![]; + for (ctx_idx, ctx) in memory_before.contexts.iter().enumerate() { + for (segment_idx, segment) in ctx.segments.iter().enumerate() { + for (virt, value) in segment.content.iter().enumerate() { + if let &Some(val) = value { + // We skip zero values in non-preinitialized segments. + if !val.is_zero() || PREINITIALIZED_SEGMENTS_INDICES.contains(&segment_idx) { + res.push(( + MemoryAddress { + context: ctx_idx, + segment: segment_idx, + virt, + }, + val, + )); + } + } + } + } + } + res +} + +type TablesWithPVsAndFinalMem = ([Vec>; NUM_TABLES], PublicValues); pub fn generate_traces, const D: usize>( all_stark: &AllStark, - inputs: GenerationInputs, + inputs: &TrimmedGenerationInputs, config: &StarkConfig, + segment_data: &mut GenerationSegmentData, timing: &mut TimingTree, -) -> anyhow::Result<([Vec>; NUM_TABLES], PublicValues)> { - debug_inputs(&inputs); - let mut state = GenerationState::::new(inputs.clone(), &KERNEL.code) +) -> anyhow::Result> { + let mut state = GenerationState::::new_with_segment_data(inputs, segment_data) .map_err(|err| anyhow!("Failed to parse all the initial prover inputs: {:?}", err))?; - apply_metadata_and_tries_memops(&mut state, &inputs); + initialize_kernel_code_and_shift_table(&mut segment_data.memory); - let cpu_res = timed!(timing, "simulate CPU", simulate_cpu(&mut state)); - if cpu_res.is_err() { - let _ = output_debug_tries(&state); + // Retrieve initial memory addresses and values. + let actual_mem_before = get_all_memory_address_and_values(&segment_data.memory); - cpu_res?; + // Initialize the state with the one at the end of the + // previous segment execution, if any. + let GenerationSegmentData { + max_cpu_len_log, + registers_before, + registers_after, + .. + } = segment_data; + + for &(address, val) in &actual_mem_before { + state.memory.set(address, val); } - log::info!( - "Trace lengths (before padding): {:?}", - state.traces.get_lengths() + let registers_before: RegistersData = RegistersData::from(*registers_before); + let registers_after: RegistersData = RegistersData::from(*registers_after); + apply_metadata_and_tries_memops(&mut state, inputs, ®isters_before, ®isters_after); + + let cpu_res = timed!( + timing, + "simulate CPU", + simulate_cpu(&mut state, *max_cpu_len_log) ); + if cpu_res.is_err() { + output_debug_tries(&state)?; + cpu_res?; + }; + + let trace_lengths = state.traces.get_lengths(); let read_metadata = |field| state.memory.read_global_metadata(field); let trie_roots_before = TrieRoots { @@ -277,29 +487,46 @@ pub fn generate_traces, const D: usize>( gas_used_after, }; + // `mem_before` and `mem_after` are initialized with an empty cap. + // They will be set to the caps of `MemBefore` and `MemAfter` + // respectively, while proving. let public_values = PublicValues { trie_roots_before, trie_roots_after, - block_metadata: inputs.block_metadata, - block_hashes: inputs.block_hashes, + block_metadata: inputs.block_metadata.clone(), + block_hashes: inputs.block_hashes.clone(), extra_block_data, + registers_before, + registers_after, + mem_before: MemCap::default(), + mem_after: MemCap::default(), }; let tables = timed!( timing, "convert trace data to tables", - state.traces.into_tables(all_stark, config, timing) + state.traces.into_tables( + all_stark, + &actual_mem_before, + state.stale_contexts, + trace_lengths, + config, + timing + ) ); Ok((tables, public_values)) } -fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { - state.run_cpu()?; +fn simulate_cpu( + state: &mut GenerationState, + max_cpu_len_log: Option, +) -> anyhow::Result<(RegistersState, Option)> { + let (final_registers, mem_after) = state.run_cpu(max_cpu_len_log)?; let pc = state.registers.program_counter; // Setting the values of padding rows. let mut row = CpuColumnsView::::default(); - row.clock = F::from_canonical_usize(state.traces.clock()); + row.clock = F::from_canonical_usize(state.traces.clock() + 1); row.context = F::from_canonical_usize(state.registers.context); row.program_counter = F::from_canonical_usize(pc); row.is_kernel_mode = F::ONE; @@ -317,7 +544,7 @@ fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> log::info!("CPU trace padded to {} cycles", state.traces.clock()); - Ok(()) + Ok((final_registers, mem_after)) } /// Outputs the tries that have been obtained post transaction execution, as diff --git a/evm_arithmetization/src/generation/mpt.rs b/evm_arithmetization/src/generation/mpt.rs index dd96ff08f..36f520d9f 100644 --- a/evm_arithmetization/src/generation/mpt.rs +++ b/evm_arithmetization/src/generation/mpt.rs @@ -8,9 +8,14 @@ use mpt_trie::nibbles::{Nibbles, NibblesIntern}; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; use rlp::{Decodable, DecoderError, Encodable, PayloadInfo, Rlp, RlpStream}; use rlp_derive::{RlpDecodable, RlpEncodable}; +use serde::{Deserialize, Serialize}; +use super::linked_list::empty_list_mem; +use super::prover_input::{ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE}; +use super::TrimmedTrieInputs; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; +use crate::memory::segments::Segment; use crate::util::h2u; use crate::witness::errors::{ProgramError, ProverInputError}; use crate::Node; @@ -23,9 +28,9 @@ pub struct AccountRlp { pub code_hash: H256, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct TrieRootPtrs { - pub state_root_ptr: usize, + pub state_root_ptr: Option, pub txn_root_ptr: usize, pub receipt_root_ptr: usize, } @@ -128,6 +133,10 @@ fn parse_storage_value(value_rlp: &[u8]) -> Result, ProgramError> { Ok(vec![value]) } +fn parse_storage_value_no_return(_value_rlp: &[u8]) -> Result, ProgramError> { + Ok(vec![]) +} + const fn empty_nibbles() -> Nibbles { Nibbles { count: 0, @@ -137,7 +146,7 @@ const fn empty_nibbles() -> Nibbles { fn load_mpt( trie: &HashedPartialTrie, - trie_data: &mut Vec, + trie_data: &mut Vec>, parse_value: &F, ) -> Result where @@ -146,66 +155,65 @@ where let node_ptr = trie_data.len(); let type_of_trie = PartialTrieType::of(trie) as u32; if type_of_trie > 0 { - trie_data.push(type_of_trie.into()); + trie_data.push(Some(type_of_trie.into())); } match trie.deref() { Node::Empty => Ok(0), Node::Hash(h) => { - trie_data.push(h2u(*h)); - + trie_data.push(Some(h2u(*h))); Ok(node_ptr) } Node::Branch { children, value } => { // First, set children pointers to 0. let first_child_ptr = trie_data.len(); - trie_data.extend(vec![U256::zero(); 16]); + trie_data.extend(vec![Some(U256::zero()); 16]); // Then, set value. if value.is_empty() { - trie_data.push(U256::zero()); + trie_data.push(Some(U256::zero())); } else { - let parsed_value = parse_value(value)?; - trie_data.push((trie_data.len() + 1).into()); + let parsed_value = parse_value(value)?.into_iter().map(Some); + trie_data.push(Some((trie_data.len() + 1).into())); trie_data.extend(parsed_value); } // Now, load all children and update their pointers. for (i, child) in children.iter().enumerate() { let child_ptr = load_mpt(child, trie_data, parse_value)?; - trie_data[first_child_ptr + i] = child_ptr.into(); + trie_data[first_child_ptr + i] = Some(child_ptr.into()); } Ok(node_ptr) } Node::Extension { nibbles, child } => { - trie_data.push(nibbles.count.into()); - trie_data.push( + trie_data.push(Some(nibbles.count.into())); + trie_data.push(Some( nibbles .try_into() .map_err(|_| ProgramError::IntegerTooLarge)?, - ); - trie_data.push((trie_data.len() + 1).into()); + )); + trie_data.push(Some((trie_data.len() + 1).into())); let child_ptr = load_mpt(child, trie_data, parse_value)?; if child_ptr == 0 { - trie_data.push(0.into()); + trie_data.push(Some(0.into())); } Ok(node_ptr) } Node::Leaf { nibbles, value } => { - trie_data.push(nibbles.count.into()); - trie_data.push( + trie_data.push(Some(nibbles.count.into())); + trie_data.push(Some( nibbles .try_into() .map_err(|_| ProgramError::IntegerTooLarge)?, - ); + )); // Set `value_ptr_ptr`. - trie_data.push((trie_data.len() + 1).into()); + trie_data.push(Some((trie_data.len() + 1).into())); - let leaf = parse_value(value)?; + let leaf = parse_value(value)?.into_iter().map(Some); trie_data.extend(leaf); Ok(node_ptr) @@ -216,19 +224,19 @@ where fn load_state_trie( trie: &HashedPartialTrie, key: Nibbles, - trie_data: &mut Vec, + trie_data: &mut Vec>, + storage_tries_by_state_key: &HashMap, ) -> Result { let node_ptr = trie_data.len(); let type_of_trie = PartialTrieType::of(trie) as u32; if type_of_trie > 0 { - trie_data.push(type_of_trie.into()); + trie_data.push(Some(type_of_trie.into())); } match trie.deref() { Node::Empty => Ok(0), Node::Hash(h) => { - trie_data.push(h2u(*h)); - + trie_data.push(Some(h2u(*h))); Ok(node_ptr) } Node::Branch { children, value } => { @@ -239,9 +247,9 @@ fn load_state_trie( } // First, set children pointers to 0. let first_child_ptr = trie_data.len(); - trie_data.extend(vec![U256::zero(); 16]); + trie_data.extend(vec![Some(U256::zero()); 16]); // Then, set value pointer to 0. - trie_data.push(U256::zero()); + trie_data.push(Some(U256::zero())); // Now, load all children and update their pointers. for (i, child) in children.iter().enumerate() { @@ -252,25 +260,25 @@ fn load_state_trie( let child_ptr = load_state_trie(child, extended_key, trie_data, storage_tries_by_state_key)?; - trie_data[first_child_ptr + i] = child_ptr.into(); + trie_data[first_child_ptr + i] = Some(child_ptr.into()); } Ok(node_ptr) } Node::Extension { nibbles, child } => { - trie_data.push(nibbles.count.into()); - trie_data.push( + trie_data.push(Some(nibbles.count.into())); + trie_data.push(Some( nibbles .try_into() .map_err(|_| ProgramError::IntegerTooLarge)?, - ); + )); // Set `value_ptr_ptr`. - trie_data.push((trie_data.len() + 1).into()); + trie_data.push(Some((trie_data.len() + 1).into())); let extended_key = key.merge_nibbles(nibbles); let child_ptr = load_state_trie(child, extended_key, trie_data, storage_tries_by_state_key)?; if child_ptr == 0 { - trie_data.push(0.into()); + trie_data.push(Some(0.into())); } Ok(node_ptr) @@ -294,24 +302,26 @@ fn load_state_trie( assert_eq!(storage_trie.hash(), storage_root, "In TrieInputs, an account's storage_root didn't match the associated storage trie hash"); - trie_data.push(nibbles.count.into()); - trie_data.push( + trie_data.push(Some(nibbles.count.into())); + trie_data.push(Some( nibbles .try_into() .map_err(|_| ProgramError::IntegerTooLarge)?, - ); + )); // Set `value_ptr_ptr`. - trie_data.push((trie_data.len() + 1).into()); + trie_data.push(Some((trie_data.len() + 1).into())); - trie_data.push(nonce); - trie_data.push(balance); + trie_data.push(Some(nonce)); + trie_data.push(Some(balance)); // Storage trie ptr. let storage_ptr_ptr = trie_data.len(); - trie_data.push((trie_data.len() + 2).into()); - trie_data.push(code_hash.into_uint()); - let storage_ptr = load_mpt(storage_trie, trie_data, &parse_storage_value)?; + trie_data.push(Some((trie_data.len() + 2).into())); + trie_data.push(Some(code_hash.into_uint())); + // We don't need to store the slot values, as they will be overwritten in + // `mpt_set_payload`. + let storage_ptr = load_mpt(storage_trie, trie_data, &parse_storage_value_no_return)?; if storage_ptr == 0 { - trie_data[storage_ptr_ptr] = 0.into(); + trie_data[storage_ptr_ptr] = Some(0.into()); } Ok(node_ptr) @@ -319,10 +329,199 @@ fn load_state_trie( } } -pub(crate) fn load_all_mpts( +fn get_state_and_storage_leaves( + trie: &HashedPartialTrie, + key: Nibbles, + state_leaves: &mut Vec>, + storage_leaves: &mut Vec>, + trie_data: &mut Vec>, + storage_tries_by_state_key: &HashMap, +) -> Result<(), ProgramError> { + match trie.deref() { + Node::Branch { children, value } => { + if !value.is_empty() { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidMptInput, + )); + } + + for (i, child) in children.iter().enumerate() { + let extended_key = key.merge_nibbles(&Nibbles { + count: 1, + packed: i.into(), + }); + + get_state_and_storage_leaves( + child, + extended_key, + state_leaves, + storage_leaves, + trie_data, + storage_tries_by_state_key, + )?; + } + + Ok(()) + } + Node::Extension { nibbles, child } => { + let extended_key = key.merge_nibbles(nibbles); + get_state_and_storage_leaves( + child, + extended_key, + state_leaves, + storage_leaves, + trie_data, + storage_tries_by_state_key, + )?; + + Ok(()) + } + Node::Leaf { nibbles, value } => { + let account: AccountRlp = rlp::decode(value).map_err(|_| ProgramError::InvalidRlp)?; + let AccountRlp { + nonce, + balance, + storage_root, + code_hash, + } = account; + + let storage_hash_only = HashedPartialTrie::new(Node::Hash(storage_root)); + let merged_key = key.merge_nibbles(nibbles); + let storage_trie: &HashedPartialTrie = storage_tries_by_state_key + .get(&merged_key) + .copied() + .unwrap_or(&storage_hash_only); + + assert_eq!( + storage_trie.hash(), + storage_root, + "In TrieInputs, an account's storage_root didn't match the associated storage trie hash" + ); + + // The last leaf must point to the new one. + let len = state_leaves.len(); + state_leaves[len - 1] = Some(U256::from( + Segment::AccountsLinkedList as usize + state_leaves.len(), + )); + // The nibbles are the address. + let address = merged_key + .try_into() + .map_err(|_| ProgramError::IntegerTooLarge)?; + state_leaves.push(Some(address)); + // Set `value_ptr_ptr`. + state_leaves.push(Some(trie_data.len().into())); + // Set counter. + state_leaves.push(Some(0.into())); + // Set the next node as the initial node. + state_leaves.push(Some((Segment::AccountsLinkedList as usize).into())); + + // Push the payload in the trie data. + trie_data.push(Some(nonce)); + trie_data.push(Some(balance)); + // The Storage pointer is only written in the trie. + trie_data.push(Some(0.into())); + trie_data.push(Some(code_hash.into_uint())); + get_storage_leaves( + address, + empty_nibbles(), + storage_trie, + storage_leaves, + &parse_storage_value, + )?; + + Ok(()) + } + _ => Ok(()), + } +} + +pub(crate) fn get_storage_leaves( + address: U256, + key: Nibbles, + trie: &HashedPartialTrie, + storage_leaves: &mut Vec>, + parse_value: &F, +) -> Result<(), ProgramError> +where + F: Fn(&[u8]) -> Result, ProgramError>, +{ + match trie.deref() { + Node::Branch { children, value: _ } => { + // Now, load all children and update their pointers. + for (i, child) in children.iter().enumerate() { + let extended_key = key.merge_nibbles(&Nibbles { + count: 1, + packed: i.into(), + }); + get_storage_leaves(address, extended_key, child, storage_leaves, parse_value)?; + } + + Ok(()) + } + + Node::Extension { nibbles, child } => { + let extended_key = key.merge_nibbles(nibbles); + get_storage_leaves(address, extended_key, child, storage_leaves, parse_value)?; + + Ok(()) + } + Node::Leaf { nibbles, value } => { + // The last leaf must point to the new one. + let len = storage_leaves.len(); + let merged_key = key.merge_nibbles(nibbles); + storage_leaves[len - 1] = Some(U256::from( + Segment::StorageLinkedList as usize + storage_leaves.len(), + )); + // Write the address. + storage_leaves.push(Some(address)); + // Write the key. + storage_leaves.push(Some( + merged_key + .try_into() + .map_err(|_| ProgramError::IntegerTooLarge)?, + )); + // Write `value_ptr_ptr`. + let leaves = parse_value(value)? + .into_iter() + .map(Some) + .collect::>(); + let leaf = match leaves.len() { + 1 => leaves[0], + _ => panic!("Slot can only store exactly one value."), + }; + storage_leaves.push(leaf); + // Write the counter. + storage_leaves.push(Some(0.into())); + // Set the next node as the initial node. + storage_leaves.push(Some((Segment::StorageLinkedList as usize).into())); + + Ok(()) + } + _ => Ok(()), + } +} + +/// A type alias used to gather: +/// - the trie root pointers for all tries +/// - the vector of state trie leaves +/// - the vector of storage trie leaves +/// - the `TrieData` segment's memory content +type TriePtrsLinkedLists = ( + TrieRootPtrs, + Vec>, + Vec>, + Vec>, +); + +pub(crate) fn load_linked_lists_and_txn_and_receipt_mpts( trie_inputs: &TrieInputs, -) -> Result<(TrieRootPtrs, Vec), ProgramError> { - let mut trie_data = vec![U256::zero()]; +) -> Result { + let mut state_leaves = + empty_list_mem::(Segment::AccountsLinkedList).to_vec(); + let mut storage_leaves = + empty_list_mem::(Segment::StorageLinkedList).to_vec(); + let mut trie_data = vec![Some(U256::zero())]; + let storage_tries_by_state_key = trie_inputs .storage_tries .iter() @@ -333,13 +532,6 @@ pub(crate) fn load_all_mpts( }) .collect(); - let state_root_ptr = load_state_trie( - &trie_inputs.state_trie, - empty_nibbles(), - &mut trie_data, - &storage_tries_by_state_key, - )?; - let txn_root_ptr = load_mpt(&trie_inputs.transactions_trie, &mut trie_data, &|rlp| { let mut parsed_txn = vec![U256::from(rlp.len())]; parsed_txn.extend(rlp.iter().copied().map(U256::from)); @@ -348,13 +540,47 @@ pub(crate) fn load_all_mpts( let receipt_root_ptr = load_mpt(&trie_inputs.receipts_trie, &mut trie_data, &parse_receipts)?; - let trie_root_ptrs = TrieRootPtrs { - state_root_ptr, - txn_root_ptr, - receipt_root_ptr, - }; + get_state_and_storage_leaves( + &trie_inputs.state_trie, + empty_nibbles(), + &mut state_leaves, + &mut storage_leaves, + &mut trie_data, + &storage_tries_by_state_key, + )?; + + Ok(( + TrieRootPtrs { + state_root_ptr: None, + txn_root_ptr, + receipt_root_ptr, + }, + state_leaves, + storage_leaves, + trie_data, + )) +} - Ok((trie_root_ptrs, trie_data)) +pub(crate) fn load_state_mpt( + trie_inputs: &TrimmedTrieInputs, + trie_data: &mut Vec>, +) -> Result { + let storage_tries_by_state_key = trie_inputs + .storage_tries + .iter() + .map(|(hashed_address, storage_trie)| { + let key = Nibbles::from_bytes_be(hashed_address.as_bytes()) + .expect("An H256 is 32 bytes long"); + (key, storage_trie) + }) + .collect(); + + load_state_trie( + &trie_inputs.state_trie, + empty_nibbles(), + trie_data, + &storage_tries_by_state_key, + ) } pub mod transaction_testing { diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 29186d364..601e1c525 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -10,13 +10,14 @@ use num_bigint::BigUint; use plonky2::field::types::Field; use serde::{Deserialize, Serialize}; +use super::linked_list::LinkedList; +use super::mpt::load_state_mpt; use crate::cpu::kernel::cancun_constants::KZG_VERSIONED_HASH; use crate::cpu::kernel::constants::cancun_constants::{ BLOB_BASE_FEE_UPDATE_FRACTION, G2_TRUSTED_SETUP_POINT, MIN_BASE_FEE_PER_BLOB_GAS, POINT_EVALUATION_PRECOMPILE_RETURN_VALUE, }; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; -use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::simulate_cpu_and_get_user_jumps; use crate::curve_pairings::{bls381, CurveAff, CyclicGroup}; use crate::extension_tower::{FieldExt, Fp12, Fp2, BLS381, BLS_BASE, BLS_SCALAR, BN254, BN_BASE}; @@ -40,6 +41,11 @@ use crate::witness::util::{current_context_peek, stack_peek}; #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] pub struct ProverInputFn(Vec); +pub const ADDRESSES_ACCESS_LIST_LEN: usize = 2; +pub const STORAGE_KEYS_ACCESS_LIST_LEN: usize = 4; +pub const ACCOUNTS_LINKED_LIST_NODE_SIZE: usize = 4; +pub const STORAGE_LINKED_LIST_NODE_SIZE: usize = 5; + impl From> for ProverInputFn { fn from(v: Vec) -> Self { Self(v) @@ -49,7 +55,7 @@ impl From> for ProverInputFn { impl GenerationState { pub(crate) fn prover_input(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[0].as_str() { - "no_txn" => self.no_txn(), + "end_of_txns" => self.run_end_of_txns(), "trie_ptr" => self.run_trie_ptr(input_fn), "ff" => self.run_ff(input_fn), "sf" => self.run_sf(input_fn), @@ -63,6 +69,7 @@ impl GenerationState { "num_bits" => self.run_num_bits(), "jumpdest_table" => self.run_jumpdest_table(input_fn), "access_lists" => self.run_access_lists(input_fn), + "linked_list" => self.run_linked_list(input_fn), "ger" => self.run_global_exit_roots(), "kzg_point_eval" => self.run_kzg_point_eval(), "kzg_point_eval_2" => self.run_kzg_point_eval_2(), @@ -70,16 +77,54 @@ impl GenerationState { } } - fn no_txn(&mut self) -> Result { - Ok(U256::from(self.inputs.signed_txn.is_none() as u8)) + fn run_end_of_txns(&mut self) -> Result { + // Reset the jumpdest table before the next transaction. + self.jumpdest_table = None; + let end = self.next_txn_index == self.inputs.txn_hashes.len(); + if end { + Ok(U256::one()) + } else { + self.next_txn_index += 1; + Ok(U256::zero()) + } } fn run_trie_ptr(&mut self, input_fn: &ProverInputFn) -> Result { let trie = input_fn.0[1].as_str(); match trie { - "state" => Ok(U256::from(self.trie_root_ptrs.state_root_ptr)), + "state" => match self.trie_root_ptrs.state_root_ptr { + Some(state_root_ptr) => Ok(state_root_ptr), + None => { + let mut new_content = self.memory.get_preinit_memory(Segment::TrieData); + + let n = load_state_mpt(&self.inputs.trimmed_tries, &mut new_content)?; + + self.memory.insert_preinitialized_segment( + Segment::TrieData, + crate::witness::memory::MemorySegmentState { + content: new_content, + }, + ); + Ok(n) + } + } + .map(U256::from), "txn" => Ok(U256::from(self.trie_root_ptrs.txn_root_ptr)), "receipt" => Ok(U256::from(self.trie_root_ptrs.receipt_root_ptr)), + "trie_data_size" => Ok(self + .memory + .preinitialized_segments + .get(&Segment::TrieData) + .unwrap_or(&crate::witness::memory::MemorySegmentState { content: vec![] }) + .content + .len() + .max( + self.memory.contexts[0].segments[Segment::TrieData.unscale()] + .content + .len(), + ) + .into()), + _ => Err(ProgramError::ProverInputError(InvalidInput)), } } @@ -282,6 +327,51 @@ impl GenerationState { } } + /// Generates either the next used jump address or the proof for the last + /// jump address. + fn run_linked_list(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "insert_account" => self.run_next_insert_account(), + "remove_account" => self.run_next_remove_account(), + "insert_slot" => self.run_next_insert_slot(), + "remove_slot" => self.run_next_remove_slot(), + "remove_address_slots" => self.run_next_remove_address_slots(), + "accounts_linked_list_len" => { + let len = self + .memory + .preinitialized_segments + .get(&Segment::AccountsLinkedList) + .unwrap_or(&crate::witness::memory::MemorySegmentState { content: vec![] }) + .content + .len() + .max( + self.memory.contexts[0].segments[Segment::AccountsLinkedList.unscale()] + .content + .len(), + ); + + Ok((Segment::AccountsLinkedList as usize + len).into()) + } + "storage_linked_list_len" => { + let len = self + .memory + .preinitialized_segments + .get(&Segment::StorageLinkedList) + .unwrap_or(&crate::witness::memory::MemorySegmentState { content: vec![] }) + .content + .len() + .max( + self.memory.contexts[0].segments[Segment::StorageLinkedList.unscale()] + .content + .len(), + ); + + Ok((Segment::StorageLinkedList as usize + len).into()) + } + _ => Err(ProgramError::ProverInputError(InvalidInput)), + } + } + fn run_global_exit_roots(&mut self) -> Result { self.ger_prover_inputs .pop() @@ -335,7 +425,7 @@ impl GenerationState { /// Returns a non-jumpdest proof for the address on the top of the stack. A /// non-jumpdest proof is the closest address to the address on the top of /// the stack, if the closest address is >= 32, or zero otherwise. - fn run_next_non_jumpdest_proof(&mut self) -> Result { + fn run_next_non_jumpdest_proof(&self) -> Result { let code = self.get_current_code()?; let address = u256_to_usize(stack_peek(self, 0)?)?; let closest_opcode_addr = get_closest_opcode_address(&code, address); @@ -348,57 +438,207 @@ impl GenerationState { /// Returns a pointer to an element in the list whose value is such that /// `value <= addr < next_value` and `addr` is the top of the stack. - fn run_next_addresses_insert(&mut self) -> Result { + fn run_next_addresses_insert(&self) -> Result { let addr = stack_peek(self, 0)?; - for (curr_ptr, next_addr, _) in self.get_addresses_access_list()? { - if next_addr > addr { - // In order to avoid pointers to the next ptr, we use the fact - // that valid pointers and Segment::AccessedAddresses are always even - return Ok(((Segment::AccessedAddresses as usize + curr_ptr) / 2usize).into()); - } + if let Some((([_, ptr], _), _)) = self + .get_addresses_access_list()? + .zip(self.get_addresses_access_list()?.skip(1)) + .zip(self.get_addresses_access_list()?.skip(2)) + .find(|&((_, [prev_addr, _]), [next_addr, _])| { + (prev_addr <= addr || prev_addr == U256::MAX) && addr < next_addr + }) + { + Ok(ptr / U256::from(2)) + } else { + Ok((Segment::AccessedAddresses as usize).into()) } - Ok((Segment::AccessedAddresses as usize).into()) } /// Returns a pointer to an element in the list whose value is such that /// `value < addr == next_value` and addr is the top of the stack. - /// If the element is not in the list returns loops forever - fn run_next_addresses_remove(&mut self) -> Result { + /// If the element is not in the list, it loops forever + fn run_next_addresses_remove(&self) -> Result { let addr = stack_peek(self, 0)?; - for (curr_ptr, next_addr, _) in self.get_addresses_access_list()? { - if next_addr == addr { - return Ok(((Segment::AccessedAddresses as usize + curr_ptr) / 2usize).into()); - } + if let Some(([_, ptr], _)) = self + .get_addresses_access_list()? + .zip(self.get_addresses_access_list()?.skip(2)) + .find(|&(_, [next_addr, _])| next_addr == addr) + { + Ok(ptr / U256::from(2)) + } else { + Ok((Segment::AccessedAddresses as usize).into()) } - Ok((Segment::AccessedAddresses as usize).into()) } /// Returns a pointer to the predecessor of the top of the stack in the /// accessed storage keys list. - fn run_next_storage_insert(&mut self) -> Result { + fn run_next_storage_insert(&self) -> Result { let addr = stack_peek(self, 0)?; let key = stack_peek(self, 1)?; - for (curr_ptr, next_addr, next_key) in self.get_storage_keys_access_list()? { - if next_addr > addr || (next_addr == addr && next_key > key) { - // In order to avoid pointers to the key, value or next ptr, we use the fact - // that valid pointers and Segment::AccessedAddresses are always multiples of 4 - return Ok(((Segment::AccessedStorageKeys as usize + curr_ptr) / 4usize).into()); - } + if let Some((([.., ptr], _), _)) = self + .get_storage_keys_access_list()? + .zip(self.get_storage_keys_access_list()?.skip(1)) + .zip(self.get_storage_keys_access_list()?.skip(2)) + .find( + |&((_, [prev_addr, prev_key, ..]), [next_addr, next_key, ..])| { + let prev_is_less_or_equal = (prev_addr < addr || prev_addr == U256::MAX) + || (prev_addr == addr && prev_key <= key); + let next_is_strictly_larger = + next_addr > addr || (next_addr == addr && next_key > key); + prev_is_less_or_equal && next_is_strictly_larger + }, + ) + { + Ok(ptr / U256::from(4)) + } else { + Ok((Segment::AccessedStorageKeys as usize).into()) } - Ok((Segment::AccessedAddresses as usize).into()) } /// Returns a pointer to the predecessor of the top of the stack in the /// accessed storage keys list. - fn run_next_storage_remove(&mut self) -> Result { + fn run_next_storage_remove(&self) -> Result { let addr = stack_peek(self, 0)?; let key = stack_peek(self, 1)?; - for (curr_ptr, next_addr, next_key) in self.get_storage_keys_access_list()? { - if (next_addr == addr && next_key == key) || next_addr == U256::MAX { - return Ok(((Segment::AccessedStorageKeys as usize + curr_ptr) / 4usize).into()); - } + if let Some(([.., ptr], _)) = self + .get_storage_keys_access_list()? + .zip(self.get_storage_keys_access_list()?.skip(2)) + .find(|&(_, [next_addr, next_key, ..])| (next_addr == addr && next_key == key)) + { + Ok(ptr / U256::from(4)) + } else { + Ok((Segment::AccessedStorageKeys as usize).into()) + } + } + + /// Returns a pointer to a node in the list such that + /// `node[0] <= addr < next_node[0]` and `addr` is the top of the stack. + fn run_next_insert_account(&self) -> Result { + let addr = stack_peek(self, 0)?; + let accounts_mem = self.memory.get_preinit_memory(Segment::AccountsLinkedList); + let accounts_linked_list = + LinkedList::::from_mem_and_segment( + &accounts_mem, + Segment::AccountsLinkedList, + )?; + + if let Some(([.., pred_ptr], [_, ..], _)) = + accounts_linked_list + .tuple_windows() + .find(|&(_, [prev_addr, ..], [next_addr, ..])| { + (prev_addr <= addr || prev_addr == U256::MAX) && addr < next_addr + }) + { + Ok(pred_ptr / U256::from(ACCOUNTS_LINKED_LIST_NODE_SIZE)) + } else { + Ok((Segment::AccountsLinkedList as usize).into()) + } + } + + /// Returns an unscaled pointer to an element in the list such that + /// `node[0] <= addr < next_node[0]`, or node[0] == addr and `node[1] <= + /// key < next_node[1]`, where `addr` and `key` are the elements at the top + /// of the stack. + fn run_next_insert_slot(&self) -> Result { + let addr = stack_peek(self, 0)?; + let key = stack_peek(self, 1)?; + let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); + let storage_linked_list = + LinkedList::::from_mem_and_segment( + &storage_mem, + Segment::StorageLinkedList, + )?; + + if let Some(([.., pred_ptr], _, _)) = storage_linked_list.tuple_windows().find( + |&(_, [prev_addr, prev_key, ..], [next_addr, next_key, ..])| { + let prev_is_less_or_equal = (prev_addr < addr || prev_addr == U256::MAX) + || (prev_addr == addr && prev_key <= key); + let next_is_strictly_larger = + next_addr > addr || (next_addr == addr && next_key > key); + prev_is_less_or_equal && next_is_strictly_larger + }, + ) { + Ok((pred_ptr - U256::from(Segment::StorageLinkedList as usize)) + / U256::from(STORAGE_LINKED_LIST_NODE_SIZE)) + } else { + Ok(U256::zero()) + } + } + + /// Returns a pointer `ptr` to a node of the form [next_addr, ..] in the + /// list such that `next_addr = addr` and `addr` is the top of the stack. + /// If the element is not in the list, loops forever. + fn run_next_remove_account(&self) -> Result { + let addr = stack_peek(self, 0)?; + let accounts_mem = self.memory.get_preinit_memory(Segment::AccountsLinkedList); + let accounts_linked_list = + LinkedList::::from_mem_and_segment( + &accounts_mem, + Segment::AccountsLinkedList, + )?; + + if let Some(([.., ptr], _, _)) = accounts_linked_list + .tuple_windows() + .find(|&(_, _, [next_node_addr, ..])| next_node_addr == addr) + { + Ok(ptr / ACCOUNTS_LINKED_LIST_NODE_SIZE) + } else { + Ok((Segment::AccountsLinkedList as usize).into()) + } + } + + /// Returns a pointer `ptr` to a node = `[next_addr, next_key]` in the list + /// such that `next_addr == addr` and `next_key == key`, + /// and `addr, key` are the elements at the top of the stack. + /// If the element is not in the list, loops forever. + fn run_next_remove_slot(&self) -> Result { + let addr = stack_peek(self, 0)?; + let key = stack_peek(self, 1)?; + let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); + let storage_linked_list = + LinkedList::::from_mem_and_segment( + &storage_mem, + Segment::StorageLinkedList, + )?; + + if let Some(([.., ptr], _, _)) = storage_linked_list + .tuple_windows() + .find(|&(_, _, [next_addr, next_key, ..])| next_addr == addr && next_key == key) + { + Ok((ptr - U256::from(Segment::StorageLinkedList as usize)) + / U256::from(STORAGE_LINKED_LIST_NODE_SIZE)) + } else { + Ok((Segment::StorageLinkedList as usize).into()) + } + } + + /// Returns a pointer `ptr` to a storage node in the storage linked list. + /// The node's next element = `[next_addr, next_key]` is such that + /// `next_addr = addr`, if such an element exists, or such that + /// `next_addr = @U256_MAX`. This is used to determine the first storage + /// node for the account at `addr`. `addr` is the element at the top of the + /// stack. + fn run_next_remove_address_slots(&self) -> Result { + let addr = stack_peek(self, 0)?; + let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); + let storage_linked_list = + LinkedList::::from_mem_and_segment( + &storage_mem, + Segment::StorageLinkedList, + )?; + + if let Some(([.., pred_ptr], _, _)) = storage_linked_list.tuple_windows().find( + |&(_, [prev_addr, _, ..], [next_addr, _, ..])| { + let prev_is_less = prev_addr < addr || prev_addr == U256::MAX; + let next_is_larger_or_equal = next_addr >= addr; + prev_is_less && next_is_larger_or_equal + }, + ) { + Ok((pred_ptr - U256::from(Segment::StorageLinkedList as usize)) + / U256::from(STORAGE_LINKED_LIST_NODE_SIZE)) + } else { + Ok((Segment::StorageLinkedList as usize).into()) } - Ok((Segment::AccessedStorageKeys as usize).into()) } /// Returns the first part of the KZG precompile output. @@ -620,42 +860,26 @@ impl GenerationState { } } - pub(crate) fn get_addresses_access_list(&self) -> Result { + pub(crate) fn get_addresses_access_list( + &self, + ) -> Result, ProgramError> { // `GlobalMetadata::AccessedAddressesLen` stores the value of the next available // virtual address in the segment. In order to get the length we need - // to subtract `Segment::AccessedAddresses` as usize. - let acc_addr_len = - u256_to_usize(self.get_global_metadata(GlobalMetadata::AccessedAddressesLen))? - - Segment::AccessedAddresses as usize; - AccList::from_mem_and_segment( - &self.memory.contexts[0].segments[Segment::AccessedAddresses.unscale()].content - [..acc_addr_len], + // to substract `Segment::AccessedAddresses` as usize. + LinkedList::from_mem_and_segment( + &self.memory.contexts[0].segments[Segment::AccessedAddresses.unscale()].content, Segment::AccessedAddresses, ) } - fn get_global_metadata(&self, data: GlobalMetadata) -> U256 { - self.memory.get_with_init(MemoryAddress::new( - 0, - Segment::GlobalMetadata, - data.unscale(), - )) - } - - pub(crate) fn get_storage_keys_access_list(&self) -> Result { + pub(crate) fn get_storage_keys_access_list( + &self, + ) -> Result, ProgramError> { // GlobalMetadata::AccessedStorageKeysLen stores the value of the next available // virtual address in the segment. In order to get the length we need - // to subtract Segment::AccessedStorageKeys as usize - let acc_storage_len = u256_to_usize( - self.memory.get_with_init(MemoryAddress::new( - 0, - Segment::GlobalMetadata, - GlobalMetadata::AccessedStorageKeysLen.unscale(), - )) - Segment::AccessedStorageKeys as usize, - )?; - AccList::from_mem_and_segment( - &self.memory.contexts[0].segments[Segment::AccessedStorageKeys.unscale()].content - [..acc_storage_len], + // to substract `Segment::AccessedStorageKeys` as usize. + LinkedList::from_mem_and_segment( + &self.memory.contexts[0].segments[Segment::AccessedStorageKeys.unscale()].content, Segment::AccessedStorageKeys, ) } @@ -753,68 +977,6 @@ impl<'a> Iterator for CodeIterator<'a> { } } -// Iterates over a linked list implemented using a vector `access_list_mem`. -// In this representation, the values of nodes are stored in the range -// `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size - -// 1]` holds the address of the next node, where i = node_size * j. -pub(crate) struct AccList<'a> { - access_list_mem: &'a [Option], - node_size: usize, - offset: usize, - pos: usize, -} - -impl<'a> AccList<'a> { - const fn from_mem_and_segment( - access_list_mem: &'a [Option], - segment: Segment, - ) -> Result { - if access_list_mem.is_empty() { - return Err(ProgramError::ProverInputError(InvalidInput)); - } - Ok(Self { - access_list_mem, - node_size: match segment { - Segment::AccessedAddresses => 2, - Segment::AccessedStorageKeys => 4, - _ => return Err(ProgramError::ProverInputError(InvalidInput)), - }, - offset: segment as usize, - pos: 0, - }) - } -} - -impl<'a> Iterator for AccList<'a> { - type Item = (usize, U256, U256); - - fn next(&mut self) -> Option { - if let Ok(new_pos) = - u256_to_usize(self.access_list_mem[self.pos + self.node_size - 1].unwrap_or_default()) - { - let old_pos = self.pos; - self.pos = new_pos - self.offset; - if self.node_size == 2 { - // addresses - Some(( - old_pos, - self.access_list_mem[self.pos].unwrap_or_default(), - U256::zero(), - )) - } else { - // storage_keys - Some(( - old_pos, - self.access_list_mem[self.pos].unwrap_or_default(), - self.access_list_mem[self.pos + 1].unwrap_or_default(), - )) - } - } else { - None - } - } -} - enum EvmField { Bls381Base, Bls381Scalar, diff --git a/evm_arithmetization/src/generation/rlp.rs b/evm_arithmetization/src/generation/rlp.rs index ffc302fd5..c1dfa10d8 100644 --- a/evm_arithmetization/src/generation/rlp.rs +++ b/evm_arithmetization/src/generation/rlp.rs @@ -1,22 +1,25 @@ use ethereum_types::U256; -pub(crate) fn all_rlp_prover_inputs_reversed(signed_txn: &[u8]) -> Vec { - let mut inputs = all_rlp_prover_inputs(signed_txn); +pub(crate) fn all_rlp_prover_inputs_reversed(signed_txns: &[Vec]) -> Vec { + let mut inputs = all_rlp_prover_inputs(signed_txns); inputs.reverse(); inputs } -fn all_rlp_prover_inputs(signed_txn: &[u8]) -> Vec { +fn all_rlp_prover_inputs(signed_txns: &[Vec]) -> Vec { let mut prover_inputs = vec![]; - prover_inputs.push(signed_txn.len().into()); - let mut chunks = signed_txn.chunks_exact(32); - for bytes in chunks.by_ref() { - prover_inputs.push(U256::from_big_endian(bytes)); - } - let mut last_chunk = chunks.remainder().to_vec(); - if !last_chunk.is_empty() { - last_chunk.extend_from_slice(&vec![0u8; 32 - last_chunk.len()]); - prover_inputs.push(U256::from_big_endian(&last_chunk)); + for txn in signed_txns { + prover_inputs.push(txn.len().into()); + let mut chunks = txn.chunks_exact(32); + for bytes in chunks.by_ref() { + prover_inputs.push(U256::from_big_endian(bytes)); + } + let mut last_chunk = chunks.remainder().to_vec(); + if !last_chunk.is_empty() { + last_chunk.extend_from_slice(&vec![0u8; 32 - last_chunk.len()]); + prover_inputs.push(U256::from_big_endian(&last_chunk)); + } } + prover_inputs } diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 2f1258a81..b5defe364 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -8,18 +8,20 @@ use keccak_hash::keccak; use log::Level; use plonky2::field::types::Field; -use super::mpt::{load_all_mpts, TrieRootPtrs}; -use super::TrieInputs; +use super::mpt::TrieRootPtrs; +use super::{TrieInputs, TrimmedGenerationInputs, NUM_EXTRA_CYCLES_AFTER}; use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::stack::MAX_USER_STACK_SIZE; +use crate::generation::mpt::load_linked_lists_and_txn_and_receipt_mpts; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::CpuColumnsView; use crate::generation::GenerationInputs; use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::memory::segments::Segment; +use crate::prover::GenerationSegmentData; use crate::util::u256_to_usize; use crate::witness::errors::ProgramError; use crate::witness::memory::MemoryChannel::GeneralPurpose; @@ -75,6 +77,22 @@ pub(crate) trait State { /// Returns the current context. fn get_context(&self) -> usize; + /// Checks whether we have reached the maximal cpu length. + fn at_end_segment(&self, opt_cycle_limit: Option) -> bool { + if let Some(cycle_limit) = opt_cycle_limit { + self.get_clock() == cycle_limit + } else { + false + } + } + + /// Checks whether we have reached the `halt` label in kernel mode. + fn at_halt(&self) -> bool { + let halt = KERNEL.global_labels["halt"]; + let registers = self.get_registers(); + registers.is_kernel && (registers.program_counter == halt) + } + /// Returns the context in which the jumpdest analysis should end. fn get_halt_context(&self) -> Option { None @@ -136,35 +154,72 @@ pub(crate) trait State { /// Applies a `State`'s operations since a checkpoint. fn apply_ops(&mut self, checkpoint: GenerationStateCheckpoint); - /// Return the offsets at which execution must halt + /// Returns the offsets at which execution must halt fn get_halt_offsets(&self) -> Vec; + fn update_interpreter_final_registers(&mut self, _final_registers: RegistersState) {} + + /// Returns all the memory from non-stale contexts. + /// This is only necessary during segment data generation, hence the blanket + /// impl returns a dummy value. + fn get_active_memory(&self) -> Option { + None + } + /// Simulates the CPU. It only generates the traces if the `State` is a /// `GenerationState`. - fn run_cpu(&mut self) -> anyhow::Result<()> + fn run_cpu( + &mut self, + max_cpu_len_log: Option, + ) -> anyhow::Result<(RegistersState, Option)> where Self: Transition, - Self: Sized, { let halt_offsets = self.get_halt_offsets(); + let cycle_limit = + max_cpu_len_log.map(|max_len_log| (1 << max_len_log) - NUM_EXTRA_CYCLES_AFTER); + + let mut final_registers = RegistersState::default(); + let mut running = true; + let mut final_clock = 0; loop { let registers = self.get_registers(); let pc = registers.program_counter; - let halt = registers.is_kernel && halt_offsets.contains(&pc); + let halt_final = registers.is_kernel && halt_offsets.contains(&pc); + if running && (self.at_halt() || self.at_end_segment(cycle_limit)) { + running = false; + final_registers = registers; - // If we've reached the kernel's halt routine, halt. - if halt { - if let Some(halt_context) = self.get_halt_context() { - if registers.context == halt_context { - // Only happens during jumpdest analysis. - return Ok(()); + // If `stack_len` is 0, `stack_top` still contains a residual value. + if final_registers.stack_len == 0 { + final_registers.stack_top = 0.into(); + } + // If we are in the interpreter, we need to set the final register values. + self.update_interpreter_final_registers(final_registers); + final_clock = self.get_clock(); + self.final_exception()?; + } + + let opt_halt_context = self.get_halt_context(); + if registers.is_kernel && halt_final { + if let Some(halt_context) = opt_halt_context { + if self.get_context() == halt_context { + // Only happens during jumpdest analysis, we don't care about the output. + return Ok((final_registers, None)); } } else { + if !running { + debug_assert!(self.get_clock() - final_clock == NUM_EXTRA_CYCLES_AFTER - 1); + } + let final_mem = self.get_active_memory(); #[cfg(not(test))] - log::info!("CPU halted after {} cycles", self.get_clock()); - return Ok(()); + self.log( + Level::Info, + format!("CPU halted after {} cycles", self.get_clock()), + ); + return Ok((final_registers, final_mem)); } } @@ -213,6 +268,7 @@ pub(crate) trait State { if might_overflow_op(op) { self.get_mut_registers().check_overflow = true; } + Ok(()) } Err(e) => { @@ -244,7 +300,7 @@ pub(crate) trait State { fn base_row(&mut self) -> (CpuColumnsView, u8) { let generation_state = self.get_mut_generation_state(); let mut row: CpuColumnsView = CpuColumnsView::default(); - row.clock = F::from_canonical_usize(generation_state.traces.clock()); + row.clock = F::from_canonical_usize(generation_state.traces.clock() + 1); row.context = F::from_canonical_usize(generation_state.registers.context); row.program_counter = F::from_canonical_usize(generation_state.registers.program_counter); row.is_kernel_mode = F::from_bool(generation_state.registers.is_kernel); @@ -269,13 +325,19 @@ pub(crate) trait State { } } -#[derive(Debug)] -pub(crate) struct GenerationState { - pub(crate) inputs: GenerationInputs, +#[derive(Debug, Default)] +pub struct GenerationState { + pub(crate) inputs: TrimmedGenerationInputs, pub(crate) registers: RegistersState, pub(crate) memory: MemoryState, pub(crate) traces: Traces, + pub(crate) next_txn_index: usize, + + /// Memory used by stale contexts can be pruned so proving segments can be + /// smaller. + pub(crate) stale_contexts: Vec, + /// Prover inputs containing RLP data, in reverse order so that the next /// input can be obtained via `pop()`. pub(crate) rlp_prover_inputs: Vec, @@ -307,45 +369,82 @@ pub(crate) struct GenerationState { } impl GenerationState { - fn preinitialize_mpts(&mut self, trie_inputs: &TrieInputs) -> TrieRootPtrs { - let (trie_roots_ptrs, trie_data) = - load_all_mpts(trie_inputs).expect("Invalid MPT data for preinitialization"); - - self.memory.contexts[0].segments[Segment::TrieData.unscale()].content = - trie_data.iter().map(|&val| Some(val)).collect(); + fn preinitialize_linked_lists_and_txn_and_receipt_mpts( + &mut self, + trie_inputs: &TrieInputs, + ) -> TrieRootPtrs { + let (trie_roots_ptrs, state_leaves, storage_leaves, trie_data) = + load_linked_lists_and_txn_and_receipt_mpts(trie_inputs) + .expect("Invalid MPT data for preinitialization"); + + self.memory.insert_preinitialized_segment( + Segment::AccountsLinkedList, + crate::witness::memory::MemorySegmentState { + content: state_leaves, + }, + ); + self.memory.insert_preinitialized_segment( + Segment::StorageLinkedList, + crate::witness::memory::MemorySegmentState { + content: storage_leaves, + }, + ); + self.memory.insert_preinitialized_segment( + Segment::TrieData, + crate::witness::memory::MemorySegmentState { content: trie_data }, + ); trie_roots_ptrs } - pub(crate) fn new(inputs: GenerationInputs, kernel_code: &[u8]) -> Result { - let rlp_prover_inputs = - all_rlp_prover_inputs_reversed(inputs.clone().signed_txn.as_ref().unwrap_or(&vec![])); + + pub(crate) fn new(inputs: &GenerationInputs, kernel_code: &[u8]) -> Result { + let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); let withdrawal_prover_inputs = all_withdrawals_prover_inputs_reversed(&inputs.withdrawals); let ger_prover_inputs = all_ger_prover_inputs_reversed(&inputs.global_exit_roots); let bignum_modmul_result_limbs = Vec::new(); let mut state = Self { - inputs: inputs.clone(), + inputs: inputs.trim(), registers: Default::default(), memory: MemoryState::new(kernel_code), traces: Traces::default(), + next_txn_index: 0, + stale_contexts: Vec::new(), rlp_prover_inputs, withdrawal_prover_inputs, ger_prover_inputs, state_key_to_address: HashMap::new(), bignum_modmul_result_limbs, trie_root_ptrs: TrieRootPtrs { - state_root_ptr: 0, + state_root_ptr: Some(0), txn_root_ptr: 0, receipt_root_ptr: 0, }, jumpdest_table: None, }; - let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); + let trie_root_ptrs = + state.preinitialize_linked_lists_and_txn_and_receipt_mpts(&inputs.tries); state.trie_root_ptrs = trie_root_ptrs; Ok(state) } + pub(crate) fn new_with_segment_data( + trimmed_inputs: &TrimmedGenerationInputs, + segment_data: &GenerationSegmentData, + ) -> Result { + let mut state = Self { + inputs: trimmed_inputs.clone(), + ..Default::default() + }; + + state.memory.preinitialized_segments = segment_data.memory.preinitialized_segments.clone(); + + state.set_segment_data(segment_data); + + Ok(state) + } + /// Updates `program_counter`, and potentially adds some extra handling if /// we're jumping to a special location. pub(crate) fn jump_to(&mut self, dst: usize) -> Result<(), ProgramError> { @@ -416,23 +515,48 @@ impl GenerationState { /// Clones everything but the traces. pub(crate) fn soft_clone(&self) -> GenerationState { Self { - inputs: self.inputs.clone(), + inputs: self.inputs.clone(), // inputs have already been trimmed here registers: self.registers, memory: self.memory.clone(), traces: Traces::default(), + next_txn_index: 0, + stale_contexts: Vec::new(), rlp_prover_inputs: self.rlp_prover_inputs.clone(), state_key_to_address: self.state_key_to_address.clone(), bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(), withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(), ger_prover_inputs: self.ger_prover_inputs.clone(), trie_root_ptrs: TrieRootPtrs { - state_root_ptr: 0, + state_root_ptr: Some(0), txn_root_ptr: 0, receipt_root_ptr: 0, }, jumpdest_table: None, } } + + pub(crate) fn set_segment_data(&mut self, segment_data: &GenerationSegmentData) { + self.bignum_modmul_result_limbs + .clone_from(&segment_data.extra_data.bignum_modmul_result_limbs); + self.rlp_prover_inputs + .clone_from(&segment_data.extra_data.rlp_prover_inputs); + self.withdrawal_prover_inputs + .clone_from(&segment_data.extra_data.withdrawal_prover_inputs); + self.ger_prover_inputs + .clone_from(&segment_data.extra_data.ger_prover_inputs); + self.trie_root_ptrs + .clone_from(&segment_data.extra_data.trie_root_ptrs); + self.jumpdest_table + .clone_from(&segment_data.extra_data.jumpdest_table); + self.next_txn_index = segment_data.extra_data.next_txn_index; + self.registers = RegistersState { + program_counter: self.registers.program_counter, + is_kernel: self.registers.is_kernel, + is_stack_top_read: false, + check_overflow: false, + ..segment_data.registers_before + }; + } } impl State for GenerationState { @@ -500,7 +624,7 @@ impl State for GenerationState { } fn get_halt_offsets(&self) -> Vec { - vec![KERNEL.global_labels["halt"]] + vec![KERNEL.global_labels["halt_final"]] } fn try_perform_instruction(&mut self) -> Result { diff --git a/evm_arithmetization/src/keccak/round_flags.rs b/evm_arithmetization/src/keccak/round_flags.rs index 8bb7d4f5d..e970f1df6 100644 --- a/evm_arithmetization/src/keccak/round_flags.rs +++ b/evm_arithmetization/src/keccak/round_flags.rs @@ -25,9 +25,13 @@ pub(crate) fn eval_round_flags>( } // Initially, the first step flag should be 1 while the others should be 0. - yield_constr.constraint_first_row(local_values[reg_step(0)] - F::ONE); + let local_any_flag = (0..NUM_ROUNDS) + .map(|i| local_values[reg_step(i)]) + .sum::

(); + + yield_constr.constraint_first_row(local_any_flag * (local_values[reg_step(0)] - F::ONE)); for i in 1..NUM_ROUNDS { - yield_constr.constraint_first_row(local_values[reg_step(i)]); + yield_constr.constraint_first_row(local_any_flag * local_values[reg_step(i)]); } // Flags should circularly increment, or be all zero for padding rows. @@ -68,11 +72,16 @@ pub(crate) fn eval_round_flags_recursively, const D yield_constr.constraint(builder, constraint); } + // Initially, the first step flag should be 1 while the others should be 0. + let local_any_flag = + builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)])); // Initially, the first step flag should be 1 while the others should be 0. let step_0_minus_1 = builder.sub_extension(local_values[reg_step(0)], one); + let step_0_minus_1 = builder.mul_extension(local_any_flag, step_0_minus_1); yield_constr.constraint_first_row(builder, step_0_minus_1); for i in 1..NUM_ROUNDS { - yield_constr.constraint_first_row(builder, local_values[reg_step(i)]); + let constr = builder.mul_extension(local_any_flag, local_values[reg_step(i)]); + yield_constr.constraint_first_row(builder, constr); } // Flags should circularly increment, or be all zero for padding rows. diff --git a/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs b/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs index 398dab337..b517117a3 100644 --- a/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs @@ -13,7 +13,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::timed; use plonky2::util::timing::TimingTree; use plonky2::util::transpose; -use plonky2_util::ceil_div_usize; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::evaluation_frame::StarkEvaluationFrame; use starky::lookup::{Column, Filter, Lookup}; @@ -137,7 +136,7 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { /// Returns the number of `KeccakSponge` tables looking into the `LogicStark`. pub(crate) const fn num_logic_ctls() -> usize { const U8S_PER_CTL: usize = 32; - ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL) + KECCAK_RATE_BYTES.div_ceil(U8S_PER_CTL) } /// Creates the vector of `Columns` required to perform the `i`th logic CTL. @@ -279,17 +278,18 @@ impl, const D: usize> KeccakSpongeStark { operations: Vec, min_rows: usize, ) -> Vec<[F; NUM_KECCAK_SPONGE_COLUMNS]> { + let min_num_rows = min_rows.max(BYTE_RANGE_MAX); let base_len: usize = operations .iter() .map(|op| op.input.len() / KECCAK_RATE_BYTES + 1) .sum(); - let mut rows = Vec::with_capacity(base_len.max(min_rows).next_power_of_two()); + let mut rows = Vec::with_capacity(base_len.max(min_num_rows).next_power_of_two()); // Generate active rows. for op in operations { rows.extend(self.generate_rows_for_op(op)); } // Pad the trace. - let padded_rows = rows.len().max(min_rows).next_power_of_two(); + let padded_rows = rows.len().max(min_num_rows).next_power_of_two(); for _ in rows.len()..padded_rows { rows.push(self.generate_padding_row()); } diff --git a/evm_arithmetization/src/lib.rs b/evm_arithmetization/src/lib.rs index 0ee30fe08..5af08fbfe 100644 --- a/evm_arithmetization/src/lib.rs +++ b/evm_arithmetization/src/lib.rs @@ -191,6 +191,7 @@ pub mod keccak; pub mod keccak_sponge; pub mod logic; pub mod memory; +pub mod memory_continuation; // Proving system components pub mod all_stark; @@ -211,6 +212,7 @@ pub mod extension_tower; pub mod testing_utils; pub mod util; +use generation::TrimmedGenerationInputs; use mpt_trie::partial_trie::HashedPartialTrie; // Public definitions and re-exports @@ -222,4 +224,9 @@ pub type BlockHeight = u64; pub use all_stark::AllStark; pub use fixed_recursive_verifier::AllRecursiveCircuits; pub use generation::GenerationInputs; +use prover::{GenerationSegmentData, SegmentError}; pub use starky::config::StarkConfig; + +/// Returned type from a `SegmentDataIterator`, needed to prove all segments in +/// a transaction batch. +pub type AllData = Result<(TrimmedGenerationInputs, GenerationSegmentData), SegmentError>; diff --git a/evm_arithmetization/src/logic.rs b/evm_arithmetization/src/logic.rs index d411a4482..eae7c972b 100644 --- a/evm_arithmetization/src/logic.rs +++ b/evm_arithmetization/src/logic.rs @@ -11,7 +11,6 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::timed; use plonky2::util::timing::TimingTree; -use plonky2_util::ceil_div_usize; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::evaluation_frame::StarkEvaluationFrame; use starky::lookup::{Column, Filter}; @@ -29,7 +28,7 @@ const VAL_BITS: usize = 256; pub(crate) const PACKED_LIMB_BITS: usize = 32; /// Number of field elements needed to store each input/output at the specified /// packing. -const PACKED_LEN: usize = ceil_div_usize(VAL_BITS, PACKED_LIMB_BITS); +const PACKED_LEN: usize = VAL_BITS.div_ceil(PACKED_LIMB_BITS); /// `LogicStark` columns. pub(crate) mod columns { diff --git a/evm_arithmetization/src/memory/columns.rs b/evm_arithmetization/src/memory/columns.rs index 8d9bcf33e..10fb75b4a 100644 --- a/evm_arithmetization/src/memory/columns.rs +++ b/evm_arithmetization/src/memory/columns.rs @@ -19,6 +19,8 @@ pub(crate) struct MemoryColumnsView { /// of general memory channels, and `i` is the index of the memory /// channel at which the memory operation is performed. pub timestamp: T, + /// Contains the inverse of `timestamp`. Used to check if `timestamp = 0`. + pub timestamp_inv: T, /// 1 if this is a read operation, 0 if it is a write one. pub is_read: T, /// The execution context of this address. @@ -40,11 +42,45 @@ pub(crate) struct MemoryColumnsView { pub segment_first_change: T, pub virtual_first_change: T, - // Used to lower the degree of the zero-initializing constraints. - // Contains `next_segment * addr_changed * next_is_read`. + /// Used to lower the degree of the zero-initializing constraints. + /// Contains `preinitialized_segments * addr_changed * next_is_read`. pub initialize_aux: T, - // We use a range check to enforce the ordering. + /// Used to allow pre-initialization of some segments. + /// Contains `(next_segment - Segment::Code) * (next_segment - + /// Segment::TrieData) + /// * preinitialized_segments_aux`. + pub preinitialized_segments: T, + + /// Used to allow pre-initialization of more segments. + /// Contains `(next_segment - Segment::AccountsLinkedList) * (next_segment - + /// Segment::StorageLinkedList)`. + pub preinitialized_segments_aux: T, + + /// Contains `row_index` + 1 if and only if context `row_index` is stale, + /// and zero if not. + pub stale_contexts: T, + + /// Flag indicating whether the current context needs to be pruned. It is + /// set to 1 when the value in `state_contexts` is non-zero. + pub is_pruned: T, + + /// Used for the context pruning lookup. + pub stale_context_frequencies: T, + + /// Flag indicating whether the row should be pruned, i.e. whether its + /// `addr_context` + 1 is in `state_contexts`. + pub is_stale: T, + + /// Flag indicating that a value can potentially be propagated. + /// Contains `filter * address_changed * is_not_stale`. + pub maybe_in_mem_after: T, + + /// Filter for the `MemAfter` CTL. Is equal to `MAYBE_IN_MEM_AFTER` if + /// segment is preinitialized or the value is non-zero, is 0 otherwise. + pub mem_after_filter: T, + + /// We use a range check to enforce the ordering. pub range_check: T, /// The counter column (used for the range check) starts from 0 and /// increments. diff --git a/evm_arithmetization/src/memory/memory_stark.rs b/evm_arithmetization/src/memory/memory_stark.rs index ba4f1255f..e58a85a26 100644 --- a/evm_arithmetization/src/memory/memory_stark.rs +++ b/evm_arithmetization/src/memory/memory_stark.rs @@ -14,16 +14,17 @@ use plonky2::util::timing::TimingTree; use plonky2::util::transpose; use plonky2_maybe_rayon::*; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::cross_table_lookup::TableWithColumns; use starky::evaluation_frame::StarkEvaluationFrame; use starky::lookup::{Column, Filter, Lookup}; use starky::stark::Stark; use super::columns::{MemoryColumnsView, MEMORY_COL_MAP}; -use super::segments::Segment; -use crate::all_stark::EvmStarkFrame; +use super::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES}; +use crate::all_stark::{EvmStarkFrame, Table}; use crate::memory::columns::NUM_COLUMNS; use crate::memory::VALUE_LIMBS; -use crate::witness::memory::MemoryOpKind::Read; +use crate::witness::memory::MemoryOpKind::{self, Read}; use crate::witness::memory::{MemoryAddress, MemoryOp}; /// Creates the vector of `Columns` corresponding to: @@ -39,9 +40,7 @@ pub(crate) fn ctl_data() -> Vec> { MEMORY_COL_MAP.addr_virtual, ]) .collect_vec(); - res.extend(Column::singles( - (0..8).map(|i| MEMORY_COL_MAP.value_limbs[i]), - )); + res.extend(Column::singles(MEMORY_COL_MAP.value_limbs)); res.push(Column::single(MEMORY_COL_MAP.timestamp)); res } @@ -51,6 +50,52 @@ pub(crate) fn ctl_filter() -> Filter { Filter::new_simple(Column::single(MEMORY_COL_MAP.filter)) } +/// Creates the vector of `Columns` corresponding to: +/// - the initialized address (context, segment, virt), +/// - the value in u32 limbs. +pub(crate) fn ctl_looking_mem() -> Vec> { + let mut res = Column::singles([ + MEMORY_COL_MAP.addr_context, + MEMORY_COL_MAP.addr_segment, + MEMORY_COL_MAP.addr_virtual, + ]) + .collect_vec(); + res.extend(Column::singles(MEMORY_COL_MAP.value_limbs)); + res +} + +/// Returns the (non-zero) stale contexts. +pub(crate) fn ctl_context_pruning_looking() -> TableWithColumns { + TableWithColumns::new( + *Table::Memory, + vec![Column::linear_combination_with_constant( + vec![(MEMORY_COL_MAP.stale_contexts, F::ONE)], + F::NEG_ONE, + )], + Filter::new(vec![], vec![Column::single(MEMORY_COL_MAP.is_pruned)]), + ) +} + +/// CTL filter for initialization writes. +/// Initialization operations have timestamp 0. +/// The filter is `1 - timestamp * timestamp_inv`. +pub(crate) fn ctl_filter_mem_before() -> Filter { + Filter::new( + vec![( + Column::single(MEMORY_COL_MAP.timestamp), + Column::linear_combination([(MEMORY_COL_MAP.timestamp_inv, -F::ONE)]), + )], + vec![Column::constant(F::ONE)], + ) +} + +/// CTL filter for final values. +/// Final values are the last row with a given address. +/// The filter is `address_changed`. +pub(crate) fn ctl_filter_mem_after() -> Filter { + Filter::new_simple(Column::single(MEMORY_COL_MAP.mem_after_filter)) +} + #[derive(Copy, Clone, Default)] pub(crate) struct MemoryStark { pub(crate) f: PhantomData, @@ -59,13 +104,14 @@ pub(crate) struct MemoryStark { impl MemoryOp { /// Generate a row for a given memory operation. Note that this does not /// generate columns which depend on the next operation, such as - /// `CONTEXT_FIRST_CHANGE`; those are generated later. It also does not - /// generate columns such as `COUNTER`, which are generated later, after the + /// `context_first_change`; those are generated later. It also does not + /// generate columns such as `counter`, which are generated later, after the /// trace has been transposed into column-major form. fn into_row(self) -> MemoryColumnsView { let mut row = MemoryColumnsView::default(); row.filter = F::from_bool(self.filter); row.timestamp = F::from_canonical_usize(self.timestamp); + row.timestamp_inv = row.timestamp.try_inverse().unwrap_or_default(); row.is_read = F::from_bool(self.kind == Read); let MemoryAddress { context, @@ -83,15 +129,19 @@ impl MemoryOp { } } -/// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the +/// Generates the `*_first_change` columns and the `range_check` column in the /// trace. pub(crate) fn generate_first_change_flags_and_rc( trace_rows: &mut [MemoryColumnsView], ) { let num_ops = trace_rows.len(); - for idx in 0..num_ops - 1 { + for idx in 0..num_ops { let row = &trace_rows[idx]; - let next_row = &trace_rows[idx + 1]; + let next_row = if idx == num_ops - 1 { + &trace_rows[0] + } else { + &trace_rows[idx + 1] + }; let context = row.addr_context; let segment = row.addr_segment; @@ -117,7 +167,9 @@ pub(crate) fn generate_first_change_flags_and_rc( row.segment_first_change = F::from_bool(segment_first_change); row.virtual_first_change = F::from_bool(virtual_first_change); - row.range_check = if context_first_change { + row.range_check = if idx == num_ops - 1 { + F::ZERO + } else if context_first_change { next_context - context - F::ONE } else if segment_first_change { next_segment - segment - F::ONE @@ -133,20 +185,36 @@ pub(crate) fn generate_first_change_flags_and_rc( row.range_check ); + row.preinitialized_segments_aux = (next_segment + - F::from_canonical_usize(Segment::AccountsLinkedList.unscale())) + * (next_segment - F::from_canonical_usize(Segment::StorageLinkedList.unscale())); + + row.preinitialized_segments = (next_segment + - F::from_canonical_usize(Segment::Code.unscale())) + * (next_segment - F::from_canonical_usize(Segment::TrieData.unscale())) + * row.preinitialized_segments_aux; + let address_changed = row.context_first_change + row.segment_first_change + row.virtual_first_change; - row.initialize_aux = next_segment * address_changed * next_is_read; + row.initialize_aux = row.preinitialized_segments * address_changed * next_is_read; } } impl, const D: usize> MemoryStark { - /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, + /// Generate most of the trace rows. Excludes a few columns like `counter`, /// which are generated later, after transposing to column-major form. - fn generate_trace_row_major(&self, mut memory_ops: Vec) -> Vec> { + fn generate_trace_row_major( + &self, + mut memory_ops: Vec, + ) -> (Vec>, usize) { // fill_gaps expects an ordered list of operations. memory_ops.sort_by_key(MemoryOp::sorting_key); Self::fill_gaps(&mut memory_ops); + let unpadded_length = memory_ops.len(); + + memory_ops.sort_by_key(MemoryOp::sorting_key); + Self::pad_memory_ops(&mut memory_ops); // fill_gaps may have added operations at the end which break the order, so sort @@ -158,11 +226,13 @@ impl, const D: usize> MemoryStark { .map(|op| op.into_row()) .collect::>(); generate_first_change_flags_and_rc(&mut trace_rows); - trace_rows + (trace_rows, unpadded_length) } - /// Generates the `COUNTER`, `RANGE_CHECK` and `FREQUENCIES` columns, given + /// Generates the `counter`, `range_check` and `frequencies` columns, given /// a trace in column-major form. + /// Also generates the `state_contexts`, `state_contexts_frequencies`, + /// `maybe_in_mem_after` and `mem_after_filter` columns. fn generate_trace_col_major(trace_col_vecs: &mut [Vec]) { let height = trace_col_vecs[0].len(); trace_col_vecs[MEMORY_COL_MAP.counter] = @@ -174,11 +244,38 @@ impl, const D: usize> MemoryStark { if (trace_col_vecs[MEMORY_COL_MAP.context_first_change][i] == F::ONE) || (trace_col_vecs[MEMORY_COL_MAP.segment_first_change][i] == F::ONE) { - // CONTEXT_FIRST_CHANGE and SEGMENT_FIRST_CHANGE should be 0 at the last row, so - // the index should never be out of bounds. - let x_fo = - trace_col_vecs[MEMORY_COL_MAP.addr_virtual][i + 1].to_canonical_u64() as usize; - trace_col_vecs[MEMORY_COL_MAP.frequencies][x_fo] += F::ONE; + if i < trace_col_vecs[MEMORY_COL_MAP.addr_virtual].len() - 1 { + let x_val = trace_col_vecs[MEMORY_COL_MAP.addr_virtual][i + 1] + .to_canonical_u64() as usize; + trace_col_vecs[MEMORY_COL_MAP.frequencies][x_val] += F::ONE; + } else { + trace_col_vecs[MEMORY_COL_MAP.frequencies][0] += F::ONE; + } + } + + let addr_ctx = trace_col_vecs[MEMORY_COL_MAP.addr_context][i]; + let addr_ctx_usize = addr_ctx.to_canonical_u64() as usize; + if addr_ctx + F::ONE == trace_col_vecs[MEMORY_COL_MAP.stale_contexts][addr_ctx_usize] { + trace_col_vecs[MEMORY_COL_MAP.is_stale][i] = F::ONE; + trace_col_vecs[MEMORY_COL_MAP.stale_context_frequencies][addr_ctx_usize] += F::ONE; + } else if trace_col_vecs[MEMORY_COL_MAP.filter][i].is_one() + && (trace_col_vecs[MEMORY_COL_MAP.context_first_change][i].is_one() + || trace_col_vecs[MEMORY_COL_MAP.segment_first_change][i].is_one() + || trace_col_vecs[MEMORY_COL_MAP.virtual_first_change][i].is_one()) + { + // `maybe_in_mem_after = filter * address_changed * (1 - is_stale)` + trace_col_vecs[MEMORY_COL_MAP.maybe_in_mem_after][i] = F::ONE; + + let addr_segment = trace_col_vecs[MEMORY_COL_MAP.addr_segment][i]; + let is_non_zero_value = (0..VALUE_LIMBS) + .any(|limb| trace_col_vecs[MEMORY_COL_MAP.value_limbs[limb]][i].is_nonzero()); + // We filter out zero values in non-preinitialized segments. + if is_non_zero_value + || PREINITIALIZED_SEGMENTS_INDICES + .contains(&(addr_segment.to_canonical_u64() as usize)) + { + trace_col_vecs[MEMORY_COL_MAP.mem_after_filter][i] = F::ONE; + } } } } @@ -197,6 +294,25 @@ impl, const D: usize> MemoryStark { /// range check, so this method would add two dummy reads to the same /// address, say at timestamps 50 and 80. fn fill_gaps(memory_ops: &mut Vec) { + // First, insert padding row at address (0, 0, 0) if the first row doesn't + // have a first virtual address at 0. + if memory_ops[0].address.virt != 0 { + let dummy_addr = MemoryAddress { + context: 0, + segment: 0, + virt: 0, + }; + memory_ops.insert( + 0, + MemoryOp { + filter: false, + timestamp: 1, + address: dummy_addr, + kind: MemoryOpKind::Read, + value: 0.into(), + }, + ); + } let max_rc = memory_ops.len().next_power_of_two() - 1; for (mut curr, mut next) in memory_ops.clone().into_iter().tuple_windows() { if curr.address.context != next.address.context @@ -214,7 +330,8 @@ impl, const D: usize> MemoryStark { while next.address.virt > max_rc { let mut dummy_address = next.address; dummy_address.virt -= max_rc; - let dummy_read = MemoryOp::new_dummy_read(dummy_address, 0, U256::zero()); + let dummy_read = + MemoryOp::new_dummy_read(dummy_address, curr.timestamp + 1, U256::zero()); memory_ops.push(dummy_read); next = dummy_read; } @@ -222,7 +339,8 @@ impl, const D: usize> MemoryStark { while next.address.virt - curr.address.virt - 1 > max_rc { let mut dummy_address = curr.address; dummy_address.virt += max_rc + 1; - let dummy_read = MemoryOp::new_dummy_read(dummy_address, 0, U256::zero()); + let dummy_read = + MemoryOp::new_dummy_read(dummy_address, curr.timestamp + 1, U256::zero()); memory_ops.push(dummy_read); curr = dummy_read; } @@ -244,30 +362,74 @@ impl, const D: usize> MemoryStark { // desired size, with a few changes: // - We change its filter to 0 to indicate that this is a dummy operation. // - We make sure it's a read, since dummy operations must be reads. + // - We change the address so that the last operation can still be included in + // `MemAfterStark`. + let padding_addr = MemoryAddress { + virt: last_op.address.virt + 1, + ..last_op.address + }; let padding_op = MemoryOp { filter: false, kind: Read, - ..last_op + address: padding_addr, + timestamp: last_op.timestamp + 1, + value: U256::zero(), }; - let num_ops = memory_ops.len(); - let num_ops_padded = num_ops.next_power_of_two(); + // We want at least one padding row, so that the last real operation can have + // its flags set correctly. + let num_ops_padded = (num_ops + 1).next_power_of_two(); for _ in num_ops..num_ops_padded { memory_ops.push(padding_op); } } + fn insert_stale_contexts(trace_rows: &mut [MemoryColumnsView], stale_contexts: Vec) { + debug_assert!( + { + let mut dedup_vec = stale_contexts.clone(); + dedup_vec.sort(); + dedup_vec.dedup(); + dedup_vec.len() == stale_contexts.len() + }, + "Stale contexts are not unique.", + ); + + for ctx in stale_contexts { + let ctx_field = F::from_canonical_usize(ctx); + // We store `ctx_field+1` so that 0 can be the default value for non-stale + // context. + trace_rows[ctx].stale_contexts = ctx_field + F::ONE; + trace_rows[ctx].is_pruned = F::ONE; + } + } + pub(crate) fn generate_trace( &self, - memory_ops: Vec, + mut memory_ops: Vec, + mem_before_values: &[(MemoryAddress, U256)], + stale_contexts: Vec, timing: &mut TimingTree, - ) -> Vec> { + ) -> (Vec>, Vec>, usize) { + // First, push `mem_before` operations. + for &(address, value) in mem_before_values { + memory_ops.push(MemoryOp { + filter: true, + timestamp: 0, + address, + kind: crate::witness::memory::MemoryOpKind::Write, + value, + }); + } // Generate most of the trace in row-major form. - let trace_rows = timed!( + let (mut trace_rows, unpadded_length) = timed!( timing, "generate trace rows", self.generate_trace_row_major(memory_ops) ); + + Self::insert_stale_contexts(&mut trace_rows, stale_contexts.clone()); + let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect(); // Transpose to column-major form. @@ -276,10 +438,27 @@ impl, const D: usize> MemoryStark { // A few final generation steps, which work better in column-major form. Self::generate_trace_col_major(&mut trace_col_vecs); - trace_col_vecs - .into_iter() - .map(|column| PolynomialValues::new(column)) - .collect() + let final_rows = transpose(&trace_col_vecs); + + // Extract `MemoryAfterStark` values. + let mut mem_after_values = Vec::>::new(); + for row in final_rows { + if row[MEMORY_COL_MAP.mem_after_filter].is_one() { + let mut addr_val = vec![F::ONE]; + addr_val + .extend(&row[MEMORY_COL_MAP.addr_context..MEMORY_COL_MAP.context_first_change]); + mem_after_values.push(addr_val); + } + } + + ( + trace_col_vecs + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect(), + mem_after_values, + unpadded_length, + ) } } @@ -310,14 +489,21 @@ impl, const D: usize> Stark for MemoryStark = (0..8).map(|i| lv.value_limbs[i]).collect(); + let value_limbs = lv.value_limbs; + let timestamp_inv = lv.timestamp_inv; + let is_stale = lv.is_stale; + let maybe_in_mem_after = lv.maybe_in_mem_after; + let mem_after_filter = lv.mem_after_filter; + let initialize_aux = lv.initialize_aux; + let preinitialized_segments = lv.preinitialized_segments; + let preinitialized_segments_aux = lv.preinitialized_segments_aux; let next_timestamp = nv.timestamp; let next_is_read = nv.is_read; let next_addr_context = nv.addr_context; let next_addr_segment = nv.addr_segment; let next_addr_virtual = nv.addr_virtual; - let next_values_limbs: Vec<_> = (0..8).map(|i| nv.value_limbs[i]).collect(); + let next_values_limbs = nv.value_limbs; // The filter must be 0 or 1. let filter = lv.filter; @@ -373,14 +559,30 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark = (0..8).map(|i| lv.value_limbs[i]).collect(); + let value_limbs = lv.value_limbs; let timestamp = lv.timestamp; + let timestamp_inv = lv.timestamp_inv; + let is_stale = lv.is_stale; + let maybe_in_mem_after = lv.maybe_in_mem_after; + let mem_after_filter = lv.mem_after_filter; + let initialize_aux = lv.initialize_aux; + let preinitialized_segments = lv.preinitialized_segments; + let preinitialized_segments_aux = lv.preinitialized_segments_aux; let next_addr_context = nv.addr_context; let next_addr_segment = nv.addr_segment; let next_addr_virtual = nv.addr_virtual; - let next_values_limbs: Vec<_> = (0..8).map(|i| nv.value_limbs[i]).collect(); + let next_values_limbs = nv.value_limbs; let next_is_read = nv.is_read; let next_timestamp = nv.timestamp; @@ -534,17 +754,48 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark Vec> { - vec![Lookup { - columns: vec![ - Column::single(MEMORY_COL_MAP.range_check), - Column::single_next_row(MEMORY_COL_MAP.addr_virtual), - ], - table_column: Column::single(MEMORY_COL_MAP.counter), - frequencies_column: Column::single(MEMORY_COL_MAP.frequencies), - filter_columns: vec![ - Default::default(), - Filter::new_simple(Column::sum([ - MEMORY_COL_MAP.context_first_change, - MEMORY_COL_MAP.segment_first_change, - ])), - ], - }] + vec![ + Lookup { + columns: vec![ + Column::single(MEMORY_COL_MAP.range_check), + Column::single_next_row(MEMORY_COL_MAP.addr_virtual), + ], + table_column: Column::single(MEMORY_COL_MAP.counter), + frequencies_column: Column::single(MEMORY_COL_MAP.frequencies), + filter_columns: vec![ + Default::default(), + Filter::new_simple(Column::sum([ + MEMORY_COL_MAP.context_first_change, + MEMORY_COL_MAP.segment_first_change, + ])), + ], + }, + Lookup { + columns: vec![Column::linear_combination_with_constant( + vec![(MEMORY_COL_MAP.addr_context, F::ONE)], + F::ONE, + )], + table_column: Column::single(MEMORY_COL_MAP.stale_contexts), + frequencies_column: Column::single(MEMORY_COL_MAP.stale_context_frequencies), + filter_columns: vec![Filter::new_simple(Column::single(MEMORY_COL_MAP.is_stale))], + }, + ] } fn requires_ctls(&self) -> bool { diff --git a/evm_arithmetization/src/memory/segments.rs b/evm_arithmetization/src/memory/segments.rs index 8c687ea93..e1b6678f6 100644 --- a/evm_arithmetization/src/memory/segments.rs +++ b/evm_arithmetization/src/memory/segments.rs @@ -1,10 +1,12 @@ +use serde::{Deserialize, Serialize}; + pub(crate) const SEGMENT_SCALING_FACTOR: usize = 32; /// This contains all the existing memory segments. The values in the enum are /// shifted by 32 bits to allow for convenient address components (context / /// segment / virtual) bundling in the kernel. #[allow(clippy::enum_clike_unportable_variant)] -#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Debug)] +#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Debug, Serialize, Deserialize)] pub(crate) enum Segment { /// Contains EVM bytecode. // The Kernel has optimizations relying on the Code segment being 0. @@ -70,17 +72,32 @@ pub(crate) enum Segment { ContextCheckpoints = 31 << SEGMENT_SCALING_FACTOR, /// List of 256 previous block hashes. BlockHashes = 32 << SEGMENT_SCALING_FACTOR, + /// Segment storing the registers before/after the current execution, + /// as well as `exit_kernel` for the `registers_before`, in that order. + RegistersStates = 33 << SEGMENT_SCALING_FACTOR, + /// List of accounts in the state trie, + AccountsLinkedList = 34 << SEGMENT_SCALING_FACTOR, + /// List of storage slots of all the accounts in state trie, + StorageLinkedList = 35 << SEGMENT_SCALING_FACTOR, // The transient storage of the current transaction. - TransientStorage = 33 << SEGMENT_SCALING_FACTOR, + TransientStorage = 36 << SEGMENT_SCALING_FACTOR, /// List of contracts which have been created during the current /// transaction. - CreatedContracts = 34 << SEGMENT_SCALING_FACTOR, + CreatedContracts = 37 << SEGMENT_SCALING_FACTOR, /// Blob versioned hashes specified in a type-3 transaction. - TxnBlobVersionedHashes = 35 << SEGMENT_SCALING_FACTOR, + TxnBlobVersionedHashes = 38 << SEGMENT_SCALING_FACTOR, } +// These segments are not zero-initialized. +pub(crate) const PREINITIALIZED_SEGMENTS_INDICES: [usize; 4] = [ + Segment::Code.unscale(), + Segment::TrieData.unscale(), + Segment::AccountsLinkedList.unscale(), + Segment::StorageLinkedList.unscale(), +]; + impl Segment { - pub(crate) const COUNT: usize = 36; + pub(crate) const COUNT: usize = 39; /// Unscales this segment by `SEGMENT_SCALING_FACTOR`. pub(crate) const fn unscale(&self) -> usize { @@ -122,6 +139,9 @@ impl Segment { Self::TouchedAddresses, Self::ContextCheckpoints, Self::BlockHashes, + Self::RegistersStates, + Self::AccountsLinkedList, + Self::StorageLinkedList, Self::TransientStorage, Self::CreatedContracts, Self::TxnBlobVersionedHashes, @@ -164,6 +184,9 @@ impl Segment { Segment::TouchedAddresses => "SEGMENT_TOUCHED_ADDRESSES", Segment::ContextCheckpoints => "SEGMENT_CONTEXT_CHECKPOINTS", Segment::BlockHashes => "SEGMENT_BLOCK_HASHES", + Segment::RegistersStates => "SEGMENT_REGISTERS_STATES", + Segment::AccountsLinkedList => "SEGMENT_ACCOUNTS_LINKED_LIST", + Segment::StorageLinkedList => "SEGMENT_STORAGE_LINKED_LIST", Segment::TransientStorage => "SEGMENT_TRANSIENT_STORAGE", Segment::CreatedContracts => "SEGMENT_CREATED_CONTRACTS", Segment::TxnBlobVersionedHashes => "SEGMENT_TXN_BLOB_VERSIONED_HASHES", @@ -205,6 +228,9 @@ impl Segment { Segment::TouchedAddresses => 256, Segment::ContextCheckpoints => 256, Segment::BlockHashes => 256, + Segment::RegistersStates => 256, + Segment::AccountsLinkedList => 256, + Segment::StorageLinkedList => 256, Segment::TransientStorage => 256, Segment::CreatedContracts => 256, Segment::TxnBlobVersionedHashes => 256, diff --git a/evm_arithmetization/src/memory_continuation/columns.rs b/evm_arithmetization/src/memory_continuation/columns.rs new file mode 100644 index 000000000..9cff29fcd --- /dev/null +++ b/evm_arithmetization/src/memory_continuation/columns.rs @@ -0,0 +1,23 @@ +//! Columns for the initial or final memory, ordered by address. +//! It contains (addr, value) pairs. Note that non-padding addresses must be +//! unique. +use crate::memory::VALUE_LIMBS; + +/// 1 if an actual value or 0 if it's a padding row. +pub(crate) const FILTER: usize = 0; +/// The execution context of the address. +pub(crate) const ADDR_CONTEXT: usize = FILTER + 1; +/// The segment section of this address. +pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; +/// The virtual address within the given context and segment. +pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; + +// Eight 32-bit limbs hold a total of 256 bits. +// If a value represents an integer, it is little-endian encoded. +const VALUE_START: usize = ADDR_VIRTUAL + 1; +pub(crate) const fn value_limb(i: usize) -> usize { + debug_assert!(i < VALUE_LIMBS); + VALUE_START + i +} + +pub(crate) const NUM_COLUMNS: usize = VALUE_START + VALUE_LIMBS; diff --git a/evm_arithmetization/src/memory_continuation/memory_continuation_stark.rs b/evm_arithmetization/src/memory_continuation/memory_continuation_stark.rs new file mode 100644 index 000000000..72c869b61 --- /dev/null +++ b/evm_arithmetization/src/memory_continuation/memory_continuation_stark.rs @@ -0,0 +1,179 @@ +//! `ContinuationMemoryStark` is used to store the initial or the final values +//! in memory. It is checked against `MemoryStark` through a CTL. +//! This is used to ensure a continuation of the memory when proving +//! multiple segments of a single full transaction proof. +//! As such, `ContinuationMemoryStark` doesn't have any constraints. +use std::cmp::max; +use std::marker::PhantomData; + +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::util::transpose; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::evaluation_frame::StarkEvaluationFrame; +use starky::lookup::{Column, Filter, Lookup}; +use starky::stark::Stark; + +use super::columns::{value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, FILTER, NUM_COLUMNS}; +use crate::all_stark::EvmStarkFrame; +use crate::generation::MemBeforeValues; +use crate::memory::VALUE_LIMBS; + +/// Creates the vector of `Columns` corresponding to: +/// - the propagated address (context, segment, virt), +/// - the value in u32 limbs. +pub(crate) fn ctl_data() -> Vec> { + let mut res = Column::singles([ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); + res.extend(Column::singles((0..8).map(value_limb))); + res +} + +/// CTL filter for memory operations. +pub(crate) fn ctl_filter() -> Filter { + Filter::new_simple(Column::single(FILTER)) +} + +/// Creates the vector of `Columns` corresponding to: +/// - the initialized address (context, segment, virt), +/// - the value in u32 limbs. +pub(crate) fn ctl_data_memory() -> Vec> { + let mut res = vec![Column::constant(F::ZERO)]; // IS_READ + res.extend(Column::singles([ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec()); + res.extend(Column::singles((0..8).map(value_limb))); + res.push(Column::constant(F::ZERO)); // TIMESTAMP + res +} + +/// Convert `mem_before_values` to a vector of memory trace rows +pub(crate) fn mem_before_values_to_rows( + mem_before_values: &MemBeforeValues, +) -> Vec> { + mem_before_values + .iter() + .map(|mem_data| { + let mut row = vec![F::ZERO; NUM_COLUMNS]; + row[FILTER] = F::ONE; + row[ADDR_CONTEXT] = F::from_canonical_usize(mem_data.0.context); + row[ADDR_SEGMENT] = F::from_canonical_usize(mem_data.0.segment); + row[ADDR_VIRTUAL] = F::from_canonical_usize(mem_data.0.virt); + for j in 0..VALUE_LIMBS { + row[j + 4] = F::from_canonical_u32((mem_data.1 >> (j * 32)).low_u32()); + } + row + }) + .collect() +} + +/// Structure representing the `ContinuationMemory` STARK. +#[derive(Copy, Clone, Default)] +pub(crate) struct MemoryContinuationStark { + f: PhantomData, +} + +impl, const D: usize> MemoryContinuationStark { + pub(crate) fn generate_trace( + &self, + propagated_values: Vec>, + ) -> Vec> { + // Set the trace to the `propagated_values` provided either by `MemoryStark` + // (for final values) or the previous segment (for initial values). + let mut rows = propagated_values; + + let num_rows = rows.len(); + let num_rows_padded = max(128, num_rows.next_power_of_two()); + for _ in num_rows..num_rows_padded { + rows.push(vec![F::ZERO; NUM_COLUMNS]); + } + + let cols = transpose(&rows); + + cols.into_iter() + .map(|column| PolynomialValues::new(column)) + .collect() + } +} + +impl, const D: usize> Stark for MemoryContinuationStark { + type EvaluationFrame = EvmStarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = EvmStarkFrame, ExtensionTarget, NUM_COLUMNS>; + + fn eval_packed_generic( + &self, + vars: &Self::EvaluationFrame, + yield_constr: &mut ConstraintConsumer

, + ) where + FE: FieldExtension, + P: PackedField, + { + let local_values = vars.get_local_values(); + // The filter must be binary. + let filter = local_values[FILTER]; + yield_constr.constraint(filter * (filter - P::ONES)); + } + + fn eval_ext_circuit( + &self, + builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + vars: &Self::EvaluationFrameTarget, + yield_constr: &mut RecursiveConstraintConsumer, + ) { + let local_values = vars.get_local_values(); + // The filter must be binary. + let filter = local_values[FILTER]; + let constr = builder.add_const_extension(filter, F::NEG_ONE); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + } + + fn constraint_degree(&self) -> usize { + 3 + } + + fn requires_ctls(&self) -> bool { + true + } + + fn lookups(&self) -> Vec> { + vec![] + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use starky::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + + use crate::memory_continuation::memory_continuation_stark::MemoryContinuationStark; + + #[test] + fn test_stark_degree() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = MemoryContinuationStark; + + let stark = S::default(); + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = MemoryContinuationStark; + + let stark = S::default(); + test_stark_circuit_constraints::(stark) + } +} diff --git a/evm_arithmetization/src/memory_continuation/mod.rs b/evm_arithmetization/src/memory_continuation/mod.rs new file mode 100644 index 000000000..6c5c0d01f --- /dev/null +++ b/evm_arithmetization/src/memory_continuation/mod.rs @@ -0,0 +1,6 @@ +//! The MemAfter STARK is used to store the memory state at the end of the +//! execution. It connects to the memory STARK to read the final values of all +//! touched addresses. + +pub mod columns; +pub mod memory_continuation_stark; diff --git a/evm_arithmetization/src/proof.rs b/evm_arithmetization/src/proof.rs index 442e1a455..34ab31d4b 100644 --- a/evm_arithmetization/src/proof.rs +++ b/evm_arithmetization/src/proof.rs @@ -1,6 +1,6 @@ use ethereum_types::{Address, H256, U256}; use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::RichField; +use plonky2::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField, NUM_HASH_OUT_ELTS}; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::GenericConfig; @@ -11,7 +11,13 @@ use starky::lookup::GrandProductChallengeSet; use starky::proof::{MultiProof, StarkProofChallenges}; use crate::all_stark::NUM_TABLES; -use crate::util::{get_h160, get_h256, h2u}; +use crate::util::{get_h160, get_h256, get_u256, h2u}; +use crate::witness::state::RegistersState; + +/// The default cap height used for our zkEVM STARK proofs. +pub(crate) const DEFAULT_CAP_HEIGHT: usize = 4; +/// Number of elements contained in a Merkle cap with default height. +pub(crate) const DEFAULT_CAP_LEN: usize = 1 << DEFAULT_CAP_HEIGHT; /// A STARK proof for each table, plus some metadata used to create recursive /// wrapper proofs. @@ -52,33 +58,104 @@ pub struct PublicValues { pub block_hashes: BlockHashes, /// Extra block data that is specific to the current proof. pub extra_block_data: ExtraBlockData, + /// Registers to initialize the current proof. + pub registers_before: RegistersData, + /// Registers at the end of the current proof. + pub registers_after: RegistersData, + + pub mem_before: MemCap, + pub mem_after: MemCap, } impl PublicValues { /// Extracts public values from the given public inputs of a proof. /// Public values are always the first public inputs added to the circuit, /// so we can start extracting at index 0. + /// `len_mem_cap` is the length of the `MemBefore` and `MemAfter` caps. pub fn from_public_inputs(pis: &[F]) -> Self { - assert!(PublicValuesTarget::SIZE <= pis.len()); + assert!(pis.len() >= PublicValuesTarget::SIZE); - let trie_roots_before = TrieRoots::from_public_inputs(&pis[0..TrieRootsTarget::SIZE]); + let mut offset = 0; + let trie_roots_before = + TrieRoots::from_public_inputs(&pis[offset..offset + TrieRootsTarget::SIZE]); + offset += TrieRootsTarget::SIZE; let trie_roots_after = - TrieRoots::from_public_inputs(&pis[TrieRootsTarget::SIZE..TrieRootsTarget::SIZE * 2]); - let block_metadata = BlockMetadata::from_public_inputs( - &pis[TrieRootsTarget::SIZE * 2..TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE], - ); - let block_hashes = BlockHashes::from_public_inputs( - &pis[TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE - ..TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE + BlockHashesTarget::SIZE], - ); - let extra_block_data = ExtraBlockData::from_public_inputs( - &pis[TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE + BlockHashesTarget::SIZE - ..TrieRootsTarget::SIZE * 2 - + BlockMetadataTarget::SIZE - + BlockHashesTarget::SIZE - + ExtraBlockDataTarget::SIZE], + TrieRoots::from_public_inputs(&pis[offset..offset + TrieRootsTarget::SIZE]); + offset += TrieRootsTarget::SIZE; + let block_metadata = + BlockMetadata::from_public_inputs(&pis[offset..offset + BlockMetadataTarget::SIZE]); + offset += BlockMetadataTarget::SIZE; + let block_hashes = + BlockHashes::from_public_inputs(&pis[offset..offset + BlockHashesTarget::SIZE]); + offset += BlockHashesTarget::SIZE; + let extra_block_data = + ExtraBlockData::from_public_inputs(&pis[offset..offset + ExtraBlockDataTarget::SIZE]); + offset += ExtraBlockDataTarget::SIZE; + let registers_before = + RegistersData::from_public_inputs(&pis[offset..offset + RegistersDataTarget::SIZE]); + offset += RegistersDataTarget::SIZE; + let registers_after = + RegistersData::from_public_inputs(&pis[offset..offset + RegistersDataTarget::SIZE]); + offset += RegistersDataTarget::SIZE; + let mem_before = MemCap::from_public_inputs(&pis[offset..offset + MemCapTarget::SIZE]); + offset += MemCapTarget::SIZE; + let mem_after = MemCap::from_public_inputs(&pis[offset..offset + MemCapTarget::SIZE]); + + Self { + trie_roots_before, + trie_roots_after, + block_metadata, + block_hashes, + extra_block_data, + registers_before, + registers_after, + mem_before, + mem_after, + } + } +} + +/// Memory values which are public. +#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] +pub struct FinalPublicValues { + /// Trie hashes before the execution of the local state transition + pub trie_roots_before: TrieRoots, + /// Trie hashes after the execution of the local state transition. + pub trie_roots_after: TrieRoots, + /// Block metadata: it remains unchanged within a block. + pub block_metadata: BlockMetadata, + /// 256 previous block hashes and current block's hash. + pub block_hashes: BlockHashes, + /// Extra block data that is specific to the current proof. + pub extra_block_data: ExtraBlockData, +} + +impl FinalPublicValues { + /// Extracts final public values from the given public inputs of a proof. + /// Public values are always the first public inputs added to the circuit, + /// so we can start extracting at index 0. + pub fn from_public_inputs(pis: &[F]) -> Self { + assert!( + PublicValuesTarget::SIZE - 2 * RegistersDataTarget::SIZE - 2 * MemCapTarget::SIZE + <= pis.len() ); + let mut offset = 0; + let trie_roots_before = + TrieRoots::from_public_inputs(&pis[offset..offset + TrieRootsTarget::SIZE]); + offset += TrieRootsTarget::SIZE; + let trie_roots_after = + TrieRoots::from_public_inputs(&pis[offset..offset + TrieRootsTarget::SIZE]); + offset += TrieRootsTarget::SIZE; + let block_metadata = + BlockMetadata::from_public_inputs(&pis[offset..offset + BlockMetadataTarget::SIZE]); + offset += BlockMetadataTarget::SIZE; + let block_hashes = + BlockHashes::from_public_inputs(&pis[offset..offset + BlockHashesTarget::SIZE]); + offset += BlockHashesTarget::SIZE; + let extra_block_data = + ExtraBlockData::from_public_inputs(&pis[offset..offset + ExtraBlockDataTarget::SIZE]); + Self { trie_roots_before, trie_roots_after, @@ -89,6 +166,18 @@ impl PublicValues { } } +impl From for FinalPublicValues { + fn from(value: PublicValues) -> Self { + Self { + trie_roots_before: value.trie_roots_before, + trie_roots_after: value.trie_roots_after, + block_metadata: value.block_metadata, + block_hashes: value.block_hashes, + extra_block_data: value.extra_block_data, + } + } +} + /// Trie hashes. #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct TrieRoots { @@ -104,9 +193,9 @@ impl TrieRoots { pub fn from_public_inputs(pis: &[F]) -> Self { assert!(pis.len() == TrieRootsTarget::SIZE); - let state_root = get_h256(&pis[0..8]); - let transactions_root = get_h256(&pis[8..16]); - let receipts_root = get_h256(&pis[16..24]); + let state_root = get_h256(&pis[0..TARGET_HASH_SIZE]); + let transactions_root = get_h256(&pis[TARGET_HASH_SIZE..2 * TARGET_HASH_SIZE]); + let receipts_root = get_h256(&pis[2 * TARGET_HASH_SIZE..3 * TARGET_HASH_SIZE]); Self { state_root, @@ -147,7 +236,9 @@ impl BlockHashes { pub fn from_public_inputs(pis: &[F]) -> Self { assert!(pis.len() == BlockHashesTarget::SIZE); - let prev_hashes: [H256; 256] = core::array::from_fn(|i| get_h256(&pis[8 * i..8 + 8 * i])); + let prev_hashes: [H256; 256] = core::array::from_fn(|i| { + get_h256(&pis[TARGET_HASH_SIZE * i..TARGET_HASH_SIZE * (i + 1)]) + }); let cur_hash = get_h256(&pis[2048..2056]); Self { @@ -269,6 +360,79 @@ impl ExtraBlockData { } } +/// Registers data used to preinitialize the registers and check the final +/// registers of the current proof. +#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] +pub struct RegistersData { + /// Program counter. + pub program_counter: U256, + /// Indicates whether we are in kernel mode. + pub is_kernel: U256, + /// Stack length. + pub stack_len: U256, + /// Top of the stack. + pub stack_top: U256, + /// Context. + pub context: U256, + /// Gas used so far. + pub gas_used: U256, +} + +impl RegistersData { + pub fn from_public_inputs(pis: &[F]) -> Self { + assert!(pis.len() == RegistersDataTarget::SIZE); + + let program_counter = pis[0].to_canonical_u64().into(); + let is_kernel = pis[1].to_canonical_u64().into(); + let stack_len = pis[2].to_canonical_u64().into(); + let stack_top = get_u256(&pis[3..11].try_into().unwrap()); + let context = pis[11].to_canonical_u64().into(); + let gas_used = pis[12].to_canonical_u64().into(); + + Self { + program_counter, + is_kernel, + stack_len, + stack_top, + context, + gas_used, + } + } +} + +impl From for RegistersData { + fn from(registers: RegistersState) -> Self { + RegistersData { + program_counter: registers.program_counter.into(), + is_kernel: (registers.is_kernel as u64).into(), + stack_len: registers.stack_len.into(), + stack_top: registers.stack_top, + context: registers.context.into(), + gas_used: registers.gas_used.into(), + } + } +} + +/// Structure for a Merkle cap. It is used for `MemBefore` and `MemAfter`. +#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] +pub struct MemCap { + /// STARK cap. + pub mem_cap: Vec<[U256; NUM_HASH_OUT_ELTS]>, +} +impl MemCap { + pub fn from_public_inputs(pis: &[F]) -> Self { + let mem_cap = (0..DEFAULT_CAP_LEN) + .map(|i| { + core::array::from_fn(|j| { + U256::from(pis[pis.len() - 4 * (DEFAULT_CAP_LEN - i) + j].to_canonical_u64()) + }) + }) + .collect(); + + Self { mem_cap } + } +} + /// Memory values which are public. /// Note: All the larger integers are encoded with 32-bit limbs in little-endian /// order. @@ -284,13 +448,22 @@ pub struct PublicValuesTarget { pub block_hashes: BlockHashesTarget, /// Extra block data that is specific to the current proof. pub extra_block_data: ExtraBlockDataTarget, + /// Registers to initialize the current proof. + pub registers_before: RegistersDataTarget, + /// Registers at the end of the current proof. + pub registers_after: RegistersDataTarget, + /// Memory before. + pub mem_before: MemCapTarget, + /// Memory after. + pub mem_after: MemCapTarget, } impl PublicValuesTarget { pub(crate) const SIZE: usize = TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE + BlockHashesTarget::SIZE - + ExtraBlockDataTarget::SIZE; + + ExtraBlockDataTarget::SIZE + + DEFAULT_CAP_HEIGHT * NUM_HASH_OUT_ELTS * 2; /// Serializes public value targets. pub(crate) fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { let TrieRootsTarget { @@ -362,6 +535,37 @@ impl PublicValuesTarget { buffer.write_target(txn_number_after)?; buffer.write_target(gas_used_before)?; buffer.write_target(gas_used_after)?; + let RegistersDataTarget { + program_counter: program_counter_before, + is_kernel: is_kernel_before, + stack_len: stack_len_before, + stack_top: stack_top_before, + context: context_before, + gas_used: gas_used_before, + } = self.registers_before; + buffer.write_target(program_counter_before)?; + buffer.write_target(is_kernel_before)?; + buffer.write_target(stack_len_before)?; + buffer.write_target_array(&stack_top_before)?; + buffer.write_target(context_before)?; + buffer.write_target(gas_used_before)?; + let RegistersDataTarget { + program_counter: program_counter_after, + is_kernel: is_kernel_after, + stack_len: stack_len_after, + stack_top: stack_top_after, + context: context_after, + gas_used: gas_used_after, + } = self.registers_after; + buffer.write_target(program_counter_after)?; + buffer.write_target(is_kernel_after)?; + buffer.write_target(stack_len_after)?; + buffer.write_target_array(&stack_top_after)?; + buffer.write_target(context_after)?; + buffer.write_target(gas_used_after)?; + + buffer.write_target_merkle_cap(&self.mem_before.mem_cap)?; + buffer.write_target_merkle_cap(&self.mem_after.mem_cap)?; Ok(()) } @@ -409,12 +613,40 @@ impl PublicValuesTarget { gas_used_after: buffer.read_target()?, }; + let registers_before = RegistersDataTarget { + program_counter: buffer.read_target()?, + is_kernel: buffer.read_target()?, + stack_len: buffer.read_target()?, + stack_top: buffer.read_target_array()?, + context: buffer.read_target()?, + gas_used: buffer.read_target()?, + }; + let registers_after = RegistersDataTarget { + program_counter: buffer.read_target()?, + is_kernel: buffer.read_target()?, + stack_len: buffer.read_target()?, + stack_top: buffer.read_target_array()?, + context: buffer.read_target()?, + gas_used: buffer.read_target()?, + }; + + let mem_before = MemCapTarget { + mem_cap: buffer.read_target_merkle_cap()?, + }; + let mem_after = MemCapTarget { + mem_cap: buffer.read_target_merkle_cap()?, + }; + Ok(Self { trie_roots_before, trie_roots_after, block_metadata, block_hashes, extra_block_data, + registers_before, + registers_after, + mem_before, + mem_after, }) } @@ -422,37 +654,49 @@ impl PublicValuesTarget { /// Public values are always the first public inputs added to the circuit, /// so we can start extracting at index 0. pub(crate) fn from_public_inputs(pis: &[Target]) -> Self { - assert!( - pis.len() - > TrieRootsTarget::SIZE * 2 - + BlockMetadataTarget::SIZE - + BlockHashesTarget::SIZE - + ExtraBlockDataTarget::SIZE - - 1 + assert!(pis.len() >= Self::SIZE); + + let mut offset = 0; + let trie_roots_before = + TrieRootsTarget::from_public_inputs(&pis[offset..offset + TrieRootsTarget::SIZE]); + offset += TrieRootsTarget::SIZE; + let trie_roots_after = + TrieRootsTarget::from_public_inputs(&pis[offset..offset + TrieRootsTarget::SIZE]); + offset += TrieRootsTarget::SIZE; + let block_metadata = BlockMetadataTarget::from_public_inputs( + &pis[offset..offset + BlockMetadataTarget::SIZE], + ); + offset += BlockMetadataTarget::SIZE; + let block_hashes = + BlockHashesTarget::from_public_inputs(&pis[offset..offset + BlockHashesTarget::SIZE]); + offset += BlockHashesTarget::SIZE; + let extra_block_data = ExtraBlockDataTarget::from_public_inputs( + &pis[offset..offset + ExtraBlockDataTarget::SIZE], + ); + offset += ExtraBlockDataTarget::SIZE; + let registers_before = RegistersDataTarget::from_public_inputs( + &pis[offset..offset + RegistersDataTarget::SIZE], + ); + offset += RegistersDataTarget::SIZE; + let registers_after = RegistersDataTarget::from_public_inputs( + &pis[offset..offset + RegistersDataTarget::SIZE], ); + offset += RegistersDataTarget::SIZE; + let mem_before = + MemCapTarget::from_public_inputs(&pis[offset..offset + MemCapTarget::SIZE]); + offset += MemCapTarget::SIZE; + let mem_after = MemCapTarget::from_public_inputs(&pis[offset..offset + MemCapTarget::SIZE]); Self { - trie_roots_before: TrieRootsTarget::from_public_inputs(&pis[0..TrieRootsTarget::SIZE]), - trie_roots_after: TrieRootsTarget::from_public_inputs( - &pis[TrieRootsTarget::SIZE..TrieRootsTarget::SIZE * 2], - ), - block_metadata: BlockMetadataTarget::from_public_inputs( - &pis[TrieRootsTarget::SIZE * 2 - ..TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE], - ), - block_hashes: BlockHashesTarget::from_public_inputs( - &pis[TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE - ..TrieRootsTarget::SIZE * 2 - + BlockMetadataTarget::SIZE - + BlockHashesTarget::SIZE], - ), - extra_block_data: ExtraBlockDataTarget::from_public_inputs( - &pis[TrieRootsTarget::SIZE * 2 + BlockMetadataTarget::SIZE + BlockHashesTarget::SIZE - ..TrieRootsTarget::SIZE * 2 - + BlockMetadataTarget::SIZE - + BlockHashesTarget::SIZE - + ExtraBlockDataTarget::SIZE], - ), + trie_roots_before, + trie_roots_after, + block_metadata, + block_hashes, + extra_block_data, + registers_before, + registers_after, + mem_before, + mem_after, } } @@ -494,6 +738,21 @@ impl PublicValuesTarget { pv0.extra_block_data, pv1.extra_block_data, ), + registers_before: RegistersDataTarget::select( + builder, + condition, + pv0.registers_before, + pv1.registers_before, + ), + registers_after: RegistersDataTarget::select( + builder, + condition, + pv0.registers_after, + pv1.registers_after, + ), + mem_before: MemCapTarget::select(builder, condition, pv0.mem_before, pv1.mem_before), + + mem_after: MemCapTarget::select(builder, condition, pv0.mem_after, pv1.mem_after), } } } @@ -504,25 +763,30 @@ impl PublicValuesTarget { #[derive(Eq, PartialEq, Debug, Copy, Clone)] pub struct TrieRootsTarget { /// Targets for the state trie hash. - pub(crate) state_root: [Target; 8], + pub(crate) state_root: [Target; TARGET_HASH_SIZE], /// Targets for the transactions trie hash. - pub(crate) transactions_root: [Target; 8], + pub(crate) transactions_root: [Target; TARGET_HASH_SIZE], /// Targets for the receipts trie hash. - pub(crate) receipts_root: [Target; 8], + pub(crate) receipts_root: [Target; TARGET_HASH_SIZE], } +/// Number of `Target`s required for hashes. +pub(crate) const TARGET_HASH_SIZE: usize = 8; + impl TrieRootsTarget { - /// Number of `Target`s required for all trie hashes. - pub(crate) const HASH_SIZE: usize = 8; - pub(crate) const SIZE: usize = Self::HASH_SIZE * 3; + pub(crate) const SIZE: usize = TARGET_HASH_SIZE * 3; /// Extracts trie hash `Target`s for all tries from the provided public /// input `Target`s. The provided `pis` should start with the trie /// hashes. pub(crate) fn from_public_inputs(pis: &[Target]) -> Self { - let state_root = pis[0..8].try_into().unwrap(); - let transactions_root = pis[8..16].try_into().unwrap(); - let receipts_root = pis[16..24].try_into().unwrap(); + let state_root = pis[0..TARGET_HASH_SIZE].try_into().unwrap(); + let transactions_root = pis[TARGET_HASH_SIZE..2 * TARGET_HASH_SIZE] + .try_into() + .unwrap(); + let receipts_root = pis[2 * TARGET_HASH_SIZE..3 * TARGET_HASH_SIZE] + .try_into() + .unwrap(); Self { state_root, @@ -568,6 +832,28 @@ impl TrieRootsTarget { builder.connect(tr0.receipts_root[i], tr1.receipts_root[i]); } } + + /// If `condition`, asserts that `tr0 == tr1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + tr0: Self, + tr1: Self, + ) { + for i in 0..8 { + builder.conditional_assert_eq(condition.target, tr0.state_root[i], tr1.state_root[i]); + builder.conditional_assert_eq( + condition.target, + tr0.transactions_root[i], + tr1.transactions_root[i], + ); + builder.conditional_assert_eq( + condition.target, + tr0.receipts_root[i], + tr1.receipts_root[i], + ); + } + } } /// Circuit version of `BlockMetadata`. @@ -733,6 +1019,45 @@ impl BlockMetadataTarget { builder.connect(bm0.block_bloom[i], bm1.block_bloom[i]) } } + + /// If `condition`, asserts that `bm0 == bm1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + bm0: Self, + bm1: Self, + ) { + for i in 0..5 { + builder.conditional_assert_eq( + condition.target, + bm0.block_beneficiary[i], + bm1.block_beneficiary[i], + ); + } + builder.conditional_assert_eq(condition.target, bm0.block_timestamp, bm1.block_timestamp); + builder.conditional_assert_eq(condition.target, bm0.block_number, bm1.block_number); + builder.conditional_assert_eq(condition.target, bm0.block_difficulty, bm1.block_difficulty); + for i in 0..8 { + builder.conditional_assert_eq( + condition.target, + bm0.block_random[i], + bm1.block_random[i], + ); + } + builder.conditional_assert_eq(condition.target, bm0.block_gaslimit, bm1.block_gaslimit); + builder.conditional_assert_eq(condition.target, bm0.block_chain_id, bm1.block_chain_id); + for i in 0..2 { + builder.conditional_assert_eq( + condition.target, + bm0.block_base_fee[i], + bm1.block_base_fee[i], + ) + } + builder.conditional_assert_eq(condition.target, bm0.block_gas_used, bm1.block_gas_used); + for i in 0..64 { + builder.conditional_assert_eq(condition.target, bm0.block_bloom[i], bm1.block_bloom[i]) + } + } } /// Circuit version of `BlockHashes`. @@ -798,6 +1123,21 @@ impl BlockHashesTarget { builder.connect(bm0.cur_hash[i], bm1.cur_hash[i]); } } + + /// If `condition`, asserts that `bm0 == bm1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + bm0: Self, + bm1: Self, + ) { + for i in 0..2048 { + builder.conditional_assert_eq(condition.target, bm0.prev_hashes[i], bm1.prev_hashes[i]); + } + for i in 0..8 { + builder.conditional_assert_eq(condition.target, bm0.cur_hash[i], bm1.cur_hash[i]); + } + } } /// Circuit version of `ExtraBlockData`. @@ -889,4 +1229,213 @@ impl ExtraBlockDataTarget { builder.connect(ed0.gas_used_before, ed1.gas_used_before); builder.connect(ed0.gas_used_after, ed1.gas_used_after); } + + /// If `condition`, asserts that `ed0 == ed1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + ed0: Self, + ed1: Self, + ) { + for i in 0..8 { + builder.conditional_assert_eq( + condition.target, + ed0.checkpoint_state_trie_root[i], + ed1.checkpoint_state_trie_root[i], + ); + } + builder.conditional_assert_eq( + condition.target, + ed0.txn_number_before, + ed1.txn_number_before, + ); + builder.conditional_assert_eq(condition.target, ed0.txn_number_after, ed1.txn_number_after); + builder.conditional_assert_eq(condition.target, ed0.gas_used_before, ed1.gas_used_before); + builder.conditional_assert_eq(condition.target, ed0.gas_used_after, ed1.gas_used_after); + } +} + +/// Circuit version of `RegistersData`. +/// Registers data used to preinitialize the registers and check the final +/// registers of the current proof. +#[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] +pub struct RegistersDataTarget { + /// Program counter. + pub program_counter: Target, + /// Indicates whether we are in kernel mode. + pub is_kernel: Target, + /// Stack length. + pub stack_len: Target, + /// Top of the stack. + pub stack_top: [Target; 8], + /// Context. + pub context: Target, + /// Gas used so far. + pub gas_used: Target, +} + +impl RegistersDataTarget { + /// Number of `Target`s required for the extra block data. + pub const SIZE: usize = 13; + + /// Extracts the extra block data `Target`s from the public input `Target`s. + /// The provided `pis` should start with the extra vblock data. + pub(crate) fn from_public_inputs(pis: &[Target]) -> Self { + let program_counter = pis[0]; + let is_kernel = pis[1]; + let stack_len = pis[2]; + let stack_top = pis[3..11].try_into().unwrap(); + let context = pis[11]; + let gas_used = pis[12]; + + Self { + program_counter, + is_kernel, + stack_len, + stack_top, + context, + gas_used, + } + } + + /// If `condition`, returns the extra block data in `ed0`, + /// otherwise returns the extra block data in `ed1`. + pub(crate) fn select, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + rd0: Self, + rd1: Self, + ) -> Self { + Self { + program_counter: builder.select(condition, rd0.program_counter, rd1.program_counter), + is_kernel: builder.select(condition, rd0.is_kernel, rd1.is_kernel), + stack_len: builder.select(condition, rd0.stack_len, rd1.stack_len), + stack_top: core::array::from_fn(|i| { + builder.select(condition, rd0.stack_top[i], rd1.stack_top[i]) + }), + context: builder.select(condition, rd0.context, rd1.context), + gas_used: builder.select(condition, rd0.gas_used, rd1.gas_used), + } + } + + /// Connects the extra block data in `ed0` with the extra block data in + /// `ed1`. + pub(crate) fn connect, const D: usize>( + builder: &mut CircuitBuilder, + rd0: Self, + rd1: Self, + ) { + builder.connect(rd0.program_counter, rd1.program_counter); + builder.connect(rd0.is_kernel, rd1.is_kernel); + builder.connect(rd0.stack_len, rd1.stack_len); + for i in 0..8 { + builder.connect(rd0.stack_top[i], rd1.stack_top[i]); + } + builder.connect(rd0.context, rd1.context); + builder.connect(rd0.gas_used, rd1.gas_used); + } + + /// If `condition`, asserts that `rd0 == rd1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + rd0: Self, + rd1: Self, + ) { + builder.conditional_assert_eq(condition.target, rd0.program_counter, rd1.program_counter); + builder.conditional_assert_eq(condition.target, rd0.is_kernel, rd1.is_kernel); + builder.conditional_assert_eq(condition.target, rd0.stack_len, rd1.stack_len); + for i in 0..8 { + builder.conditional_assert_eq(condition.target, rd0.stack_top[i], rd1.stack_top[i]); + } + builder.conditional_assert_eq(condition.target, rd0.context, rd1.context); + builder.conditional_assert_eq(condition.target, rd0.gas_used, rd1.gas_used); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MemCapTarget { + /// Merkle cap. + pub mem_cap: MerkleCapTarget, +} + +impl MemCapTarget { + pub(crate) const SIZE: usize = DEFAULT_CAP_LEN * NUM_HASH_OUT_ELTS; + + /// Extracts the exit kernel `Target`s from the public input `Target`s. + /// The provided `pis` should start with the extra vblock data. + pub(crate) fn from_public_inputs(pis: &[Target]) -> Self { + let mem_values = &pis[0..Self::SIZE]; + let mem_cap = MerkleCapTarget( + (0..DEFAULT_CAP_LEN) + .map(|i| HashOutTarget { + elements: mem_values[i * NUM_HASH_OUT_ELTS..(i + 1) * NUM_HASH_OUT_ELTS] + .try_into() + .unwrap(), + }) + .collect::>(), + ); + + Self { mem_cap } + } + + /// If `condition`, returns the exit kernel in `ek0`, + /// otherwise returns the exit kernel in `ek1`. + pub(crate) fn select, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + mc0: Self, + mc1: Self, + ) -> Self { + Self { + mem_cap: MerkleCapTarget( + (0..mc0.mem_cap.0.len()) + .map(|i| HashOutTarget { + elements: (0..NUM_HASH_OUT_ELTS) + .map(|j| { + builder.select( + condition, + mc0.mem_cap.0[i].elements[j], + mc1.mem_cap.0[i].elements[j], + ) + }) + .collect::>() + .try_into() + .unwrap(), + }) + .collect::>(), + ), + } + } + + /// Connects the exit kernel in `ek0` with the exit kernel in `ek1`. + pub(crate) fn connect, const D: usize>( + builder: &mut CircuitBuilder, + mc0: Self, + mc1: Self, + ) { + for i in 0..mc0.mem_cap.0.len() { + for j in 0..NUM_HASH_OUT_ELTS { + builder.connect(mc0.mem_cap.0[i].elements[j], mc1.mem_cap.0[i].elements[j]); + } + } + } + + /// If `condition`, asserts that `mc0 == mc1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + mc0: Self, + mc1: Self, + ) { + for i in 0..mc0.mem_cap.0.len() { + for j in 0..NUM_HASH_OUT_ELTS { + builder.conditional_assert_eq( + condition.target, + mc0.mem_cap.0[i].elements[j], + mc1.mem_cap.0[i].elements[j], + ); + } + } + } } diff --git a/evm_arithmetization/src/prover.rs b/evm_arithmetization/src/prover.rs index f6c40cf89..746e1926e 100644 --- a/evm_arithmetization/src/prover.rs +++ b/evm_arithmetization/src/prover.rs @@ -8,10 +8,12 @@ use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; use plonky2::fri::oracle::PolynomialBatch; use plonky2::hash::hash_types::RichField; +use plonky2::hash::merkle_tree::MerkleCap; use plonky2::iop::challenger::Challenger; -use plonky2::plonk::config::GenericConfig; +use plonky2::plonk::config::{GenericConfig, GenericHashOut}; use plonky2::timed; use plonky2::util::timing::TimingTree; +use serde::{Deserialize, Serialize}; use starky::config::StarkConfig; use starky::cross_table_lookup::{get_ctl_data, CtlData}; use starky::lookup::GrandProductChallengeSet; @@ -21,15 +23,46 @@ use starky::stark::Stark; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::cpu::kernel::aggregator::KERNEL; -use crate::generation::{generate_traces, GenerationInputs}; +use crate::cpu::kernel::interpreter::{set_registers_and_run, ExtraSegmentData, Interpreter}; +use crate::generation::state::State; +use crate::generation::{debug_inputs, generate_traces, GenerationInputs, TrimmedGenerationInputs}; use crate::get_challenges::observe_public_values; -use crate::proof::{AllProof, PublicValues}; +use crate::proof::{AllProof, MemCap, PublicValues, DEFAULT_CAP_LEN}; +use crate::witness::memory::MemoryState; +use crate::witness::state::RegistersState; +use crate::AllData; + +/// Structure holding the data needed to initialize a segment. +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct GenerationSegmentData { + /// Indicates the position of this segment in a sequence of + /// executions for a larger payload. + pub(crate) segment_index: usize, + /// Registers at the start of the segment execution. + pub(crate) registers_before: RegistersState, + /// Registers at the end of the segment execution. + pub(crate) registers_after: RegistersState, + /// Memory at the start of the segment execution. + pub(crate) memory: MemoryState, + /// Extra data required to initialize a segment. + pub(crate) extra_data: ExtraSegmentData, + /// Log of the maximal cpu length. + pub(crate) max_cpu_len_log: Option, +} + +impl GenerationSegmentData { + /// Retrieves the index of this segment. + pub fn segment_index(&self) -> usize { + self.segment_index + } +} /// Generate traces, then create all STARK proofs. pub fn prove( all_stark: &AllStark, config: &StarkConfig, - inputs: GenerationInputs, + inputs: TrimmedGenerationInputs, + segment_data: &mut GenerationSegmentData, timing: &mut TimingTree, abort_signal: Option>, ) -> Result> @@ -37,22 +70,28 @@ where F: RichField + Extendable, C: GenericConfig, { + // Sanity check on the provided config + assert_eq!(DEFAULT_CAP_LEN, 1 << config.fri_config.cap_height); + timed!(timing, "build kernel", Lazy::force(&KERNEL)); - let (traces, public_values) = timed!( + + let (traces, mut public_values) = timed!( timing, "generate all traces", - generate_traces(all_stark, inputs, config, timing)? + generate_traces(all_stark, &inputs, config, segment_data, timing)? ); + check_abort_signal(abort_signal.clone())?; let proof = prove_with_traces( all_stark, config, traces, - public_values, + &mut public_values, timing, abort_signal, )?; + Ok(proof) } @@ -61,7 +100,7 @@ pub(crate) fn prove_with_traces( all_stark: &AllStark, config: &StarkConfig, trace_poly_values: [Vec>; NUM_TABLES], - public_values: PublicValues, + public_values: &mut PublicValues, timing: &mut TimingTree, abort_signal: Option>, ) -> Result> @@ -107,7 +146,7 @@ where challenger.observe_cap(cap); } - observe_public_values::(&mut challenger, &public_values) + observe_public_values::(&mut challenger, public_values) .map_err(|_| anyhow::Error::msg("Invalid conversion of public values."))?; // For each STARK, compute its cross-table lookup Z polynomials and get the @@ -124,7 +163,7 @@ where ) ); - let stark_proofs = timed!( + let (stark_proofs, mem_before_cap, mem_after_cap) = timed!( timing, "compute all proofs given commitments", prove_with_commitments( @@ -139,6 +178,34 @@ where abort_signal, )? ); + public_values.mem_before = MemCap { + mem_cap: mem_before_cap + .0 + .iter() + .map(|h| { + h.to_vec() + .iter() + .map(|hi| hi.to_canonical_u64().into()) + .collect::>() + .try_into() + .unwrap() + }) + .collect::>(), + }; + public_values.mem_after = MemCap { + mem_cap: mem_after_cap + .0 + .iter() + .map(|h| { + h.to_vec() + .iter() + .map(|hi| hi.to_canonical_u64().into()) + .collect::>() + .try_into() + .unwrap() + }) + .collect::>(), + }; // This is an expensive check, hence is only run when `debug_assertions` are // enabled. @@ -152,7 +219,7 @@ where let mut extra_values = HashMap::new(); extra_values.insert( *Table::Memory, - get_memory_extra_looking_values(&public_values), + get_memory_extra_looking_values(public_values), ); check_ctls( &trace_poly_values, @@ -166,10 +233,16 @@ where stark_proofs, ctl_challenges, }, - public_values, + public_values: public_values.clone(), }) } +type ProofWithMemCaps = ( + [StarkProofWithMetadata; NUM_TABLES], + MerkleCap, + MerkleCap, +); + /// Generates a proof for each STARK. /// At this stage, we have computed the trace polynomials commitments for the /// various STARKs, and we have the cross-table lookup data for each table, @@ -189,12 +262,12 @@ fn prove_with_commitments( ctl_challenges: &GrandProductChallengeSet, timing: &mut TimingTree, abort_signal: Option>, -) -> Result<[StarkProofWithMetadata; NUM_TABLES]> +) -> Result> where F: RichField + Extendable, C: GenericConfig, { - let arithmetic_proof = timed!( + let (arithmetic_proof, _) = timed!( timing, "prove Arithmetic STARK", prove_single_table( @@ -209,7 +282,7 @@ where abort_signal.clone(), )? ); - let byte_packing_proof = timed!( + let (byte_packing_proof, _) = timed!( timing, "prove byte packing STARK", prove_single_table( @@ -224,7 +297,7 @@ where abort_signal.clone(), )? ); - let cpu_proof = timed!( + let (cpu_proof, _) = timed!( timing, "prove CPU STARK", prove_single_table( @@ -239,7 +312,7 @@ where abort_signal.clone(), )? ); - let keccak_proof = timed!( + let (keccak_proof, _) = timed!( timing, "prove Keccak STARK", prove_single_table( @@ -254,7 +327,7 @@ where abort_signal.clone(), )? ); - let keccak_sponge_proof = timed!( + let (keccak_sponge_proof, _) = timed!( timing, "prove Keccak sponge STARK", prove_single_table( @@ -269,7 +342,7 @@ where abort_signal.clone(), )? ); - let logic_proof = timed!( + let (logic_proof, _) = timed!( timing, "prove logic STARK", prove_single_table( @@ -284,7 +357,7 @@ where abort_signal.clone(), )? ); - let memory_proof = timed!( + let (memory_proof, _) = timed!( timing, "prove memory STARK", prove_single_table( @@ -296,25 +369,66 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), + )? + ); + let (mem_before_proof, mem_before_cap) = timed!( + timing, + "prove mem_before STARK", + prove_single_table( + &all_stark.mem_before_stark, + config, + &trace_poly_values[Table::MemBefore as usize], + &trace_commitments[Table::MemBefore as usize], + &ctl_data_per_table[Table::MemBefore as usize], + ctl_challenges, + challenger, + timing, + abort_signal.clone(), + )? + ); + let (mem_after_proof, mem_after_cap) = timed!( + timing, + "prove mem_after STARK", + prove_single_table( + &all_stark.mem_after_stark, + config, + &trace_poly_values[Table::MemAfter as usize], + &trace_commitments[Table::MemAfter as usize], + &ctl_data_per_table[Table::MemAfter as usize], + ctl_challenges, + challenger, + timing, abort_signal, )? ); - Ok([ - arithmetic_proof, - byte_packing_proof, - cpu_proof, - keccak_proof, - keccak_sponge_proof, - logic_proof, - memory_proof, - ]) + Ok(( + [ + arithmetic_proof, + byte_packing_proof, + cpu_proof, + keccak_proof, + keccak_sponge_proof, + logic_proof, + memory_proof, + mem_before_proof, + mem_after_proof, + ], + mem_before_cap, + mem_after_cap, + )) } +type ProofSingleWithCap = + (StarkProofWithMetadata, MerkleCap); + /// Computes a proof for a single STARK table, including: /// - the initial state of the challenger, /// - all the requires Merkle caps, /// - all the required polynomial and FRI argument openings. +/// +/// Returns the proof, along with the associated `MerkleCap`. pub(crate) fn prove_single_table( stark: &S, config: &StarkConfig, @@ -325,7 +439,7 @@ pub(crate) fn prove_single_table( challenger: &mut Challenger, timing: &mut TimingTree, abort_signal: Option>, -) -> Result> +) -> Result> where F: RichField + Extendable, C: GenericConfig, @@ -336,7 +450,7 @@ where // Clear buffered outputs. let init_challenger_state = challenger.compact(); - prove_with_commitment( + let proof = prove_with_commitment( stark, config, trace_poly_values, @@ -350,7 +464,9 @@ where .map(|proof_with_pis| StarkProofWithMetadata { proof: proof_with_pis.proof, init_challenger_state, - }) + })?; + + Ok((proof, trace_commitment.merkle_tree.cap.clone())) } /// Utility method that checks whether a kill signal has been emitted by one of @@ -366,6 +482,159 @@ pub fn check_abort_signal(abort_signal: Option>) -> Result<()> { Ok(()) } +/// Builds a new `GenerationSegmentData`. +#[allow(clippy::unwrap_or_default)] +fn build_segment_data( + segment_index: usize, + registers_before: Option, + registers_after: Option, + memory: Option, + interpreter: &Interpreter, +) -> GenerationSegmentData { + GenerationSegmentData { + segment_index, + registers_before: registers_before.unwrap_or(RegistersState::new()), + registers_after: registers_after.unwrap_or(RegistersState::new()), + memory: memory.unwrap_or(MemoryState { + preinitialized_segments: interpreter + .generation_state + .memory + .preinitialized_segments + .clone(), + ..Default::default() + }), + max_cpu_len_log: interpreter.get_max_cpu_len_log(), + extra_data: ExtraSegmentData { + bignum_modmul_result_limbs: interpreter + .generation_state + .bignum_modmul_result_limbs + .clone(), + rlp_prover_inputs: interpreter.generation_state.rlp_prover_inputs.clone(), + withdrawal_prover_inputs: interpreter + .generation_state + .withdrawal_prover_inputs + .clone(), + ger_prover_inputs: interpreter.generation_state.ger_prover_inputs.clone(), + trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(), + jumpdest_table: interpreter.generation_state.jumpdest_table.clone(), + next_txn_index: interpreter.generation_state.next_txn_index, + }, + } +} + +pub struct SegmentDataIterator { + interpreter: Interpreter, + partial_next_data: Option, +} + +pub type SegmentRunResult = Option)>>; + +#[derive(thiserror::Error, Debug, Serialize, Deserialize)] +#[error("{}", .0)] +pub struct SegmentError(pub String); + +impl SegmentDataIterator { + pub fn new(inputs: &GenerationInputs, max_cpu_len_log: Option) -> Self { + debug_inputs(inputs); + + let interpreter = Interpreter::::new_with_generation_inputs( + KERNEL.global_labels["init"], + vec![], + inputs, + max_cpu_len_log, + ); + + Self { + interpreter, + partial_next_data: None, + } + } + + /// Returns the data for the current segment, as well as the data -- except + /// registers_after -- for the next segment. + fn generate_next_segment( + &mut self, + partial_segment_data: Option, + ) -> Result { + // Get the (partial) current segment data, if it is provided. Otherwise, + // initialize it. + let mut segment_data = if let Some(partial) = partial_segment_data { + if partial.registers_after.program_counter == KERNEL.global_labels["halt"] { + return Ok(None); + } + self.interpreter + .get_mut_generation_state() + .set_segment_data(&partial); + self.interpreter.generation_state.memory = partial.memory.clone(); + partial + } else { + build_segment_data(0, None, None, None, &self.interpreter) + }; + + let segment_index = segment_data.segment_index; + + // Run the interpreter to get `registers_after` and the partial data for the + // next segment. + let run = set_registers_and_run(segment_data.registers_after, &mut self.interpreter); + if let Ok((updated_registers, mem_after)) = run { + let partial_segment_data = Some(build_segment_data( + segment_index + 1, + Some(updated_registers), + Some(updated_registers), + mem_after, + &self.interpreter, + )); + + segment_data.registers_after = updated_registers; + Ok(Some(Box::new((segment_data, partial_segment_data)))) + } else { + let inputs = &self.interpreter.get_generation_state().inputs; + let block = inputs.block_metadata.block_number; + let txn_range = match inputs.txn_hashes.len() { + 0 => "Dummy".to_string(), + 1 => format!("{:?}", inputs.txn_number_before), + _ => format!( + "{:?}_{:?}", + inputs.txn_number_before, + inputs.txn_number_before + inputs.txn_hashes.len() + ), + }; + let s = format!( + "Segment generation {:?} for block {:?} ({}) failed with error {:?}", + segment_index, + block, + txn_range, + run.unwrap_err() + ); + Err(SegmentError(s)) + } + } +} + +impl Iterator for SegmentDataIterator { + type Item = AllData; + + fn next(&mut self) -> Option { + let run = self.generate_next_segment(self.partial_next_data.clone()); + + if let Ok(segment_run) = run { + match segment_run { + // The run was valid, but didn't not consume the payload fully. + Some(boxed) => { + let (data, next_data) = *boxed; + self.partial_next_data = next_data; + Some(Ok((self.interpreter.generation_state.inputs.clone(), data))) + } + // The payload was fully consumed. + None => None, + } + } else { + // The run encountered some error. + Some(Err(run.unwrap_err())) + } + } +} + /// A utility module designed to test witness generation externally. pub mod testing { use super::*; @@ -378,14 +647,64 @@ pub mod testing { /// It does not generate any trace or proof of correct state transition. pub fn simulate_execution(inputs: GenerationInputs) -> Result<()> { let initial_stack = vec![]; - let initial_offset = KERNEL.global_labels["main"]; + let initial_offset = KERNEL.global_labels["init"]; let mut interpreter: Interpreter = - Interpreter::new_with_generation_inputs(initial_offset, initial_stack, inputs); + Interpreter::new_with_generation_inputs(initial_offset, initial_stack, &inputs, None); let result = interpreter.run(); + if result.is_err() { output_debug_tries(interpreter.get_generation_state())?; } - result + result?; + Ok(()) + } + + pub fn prove_all_segments( + all_stark: &AllStark, + config: &StarkConfig, + inputs: GenerationInputs, + max_cpu_len_log: usize, + timing: &mut TimingTree, + abort_signal: Option>, + ) -> Result>> + where + F: RichField + Extendable, + C: GenericConfig, + { + let segment_data_iterator = SegmentDataIterator::::new(&inputs, Some(max_cpu_len_log)); + let inputs = inputs.trim(); + let mut proofs = vec![]; + + for segment_run in segment_data_iterator { + let (_, mut next_data) = segment_run.map_err(|e| anyhow::format_err!(e))?; + let proof = prove( + all_stark, + config, + inputs.clone(), + &mut next_data, + timing, + abort_signal.clone(), + )?; + proofs.push(proof); + } + + Ok(proofs) + } + + pub fn simulate_execution_all_segments( + inputs: GenerationInputs, + max_cpu_len_log: usize, + ) -> Result<()> + where + F: RichField, + { + for segment in SegmentDataIterator::::new(&inputs, Some(max_cpu_len_log)) { + if let Err(e) = segment { + return Err(anyhow::format_err!(e)); + } + } + + Ok(()) } } diff --git a/evm_arithmetization/src/recursive_verifier.rs b/evm_arithmetization/src/recursive_verifier.rs index 26ee4c116..391032307 100644 --- a/evm_arithmetization/src/recursive_verifier.rs +++ b/evm_arithmetization/src/recursive_verifier.rs @@ -7,7 +7,7 @@ use plonky2::field::extension::Extendable; use plonky2::gates::exponentiation::ExponentiationGate; use plonky2::gates::gate::GateRef; use plonky2::gates::noop::NoopGate; -use plonky2::hash::hash_types::RichField; +use plonky2::hash::hash_types::{HashOut, MerkleCapTarget, RichField}; use plonky2::hash::hashing::PlonkyPermutation; use plonky2::iop::challenger::RecursiveChallenger; use plonky2::iop::target::Target; @@ -36,7 +36,8 @@ use crate::memory::segments::Segment; use crate::memory::VALUE_LIMBS; use crate::proof::{ BlockHashes, BlockHashesTarget, BlockMetadata, BlockMetadataTarget, ExtraBlockData, - ExtraBlockDataTarget, PublicValues, PublicValuesTarget, TrieRoots, TrieRootsTarget, + ExtraBlockDataTarget, MemCap, MemCapTarget, PublicValues, PublicValuesTarget, RegistersData, + RegistersDataTarget, TrieRoots, TrieRootsTarget, DEFAULT_CAP_LEN, }; use crate::util::{h256_limbs, u256_limbs, u256_to_u32, u256_to_u64}; use crate::witness::errors::ProgramError; @@ -115,6 +116,7 @@ where buffer.write_target(self.zero_target)?; self.stark_proof_target.to_buffer(buffer)?; self.ctl_challenges_target.to_buffer(buffer)?; + Ok(()) } @@ -130,6 +132,7 @@ where let zero_target = buffer.read_target()?; let stark_proof_target = StarkProofTarget::from_buffer(buffer)?; let ctl_challenges_target = GrandProductChallengeSet::from_buffer(buffer)?; + Ok(Self { circuit, stark_proof_target, @@ -523,6 +526,47 @@ pub(crate) fn get_memory_extra_looking_sum_circuit, &[kernel_len_target], ); + // Write registers. + let registers_segment = + builder.constant(F::from_canonical_usize(Segment::RegistersStates.unscale())); + let registers_before: [&[Target]; 6] = [ + &[public_values.registers_before.program_counter], + &[public_values.registers_before.is_kernel], + &[public_values.registers_before.stack_len], + &public_values.registers_before.stack_top, + &[public_values.registers_before.context], + &[public_values.registers_before.gas_used], + ]; + for i in 0..registers_before.len() { + sum = add_data_write( + builder, + challenge, + sum, + registers_segment, + i, + registers_before[i], + ); + } + + let registers_after: [&[Target]; 6] = [ + &[public_values.registers_after.program_counter], + &[public_values.registers_after.is_kernel], + &[public_values.registers_after.stack_len], + &public_values.registers_after.stack_top, + &[public_values.registers_after.context], + &[public_values.registers_after.gas_used], + ]; + for i in 0..registers_before.len() { + sum = add_data_write( + builder, + challenge, + sum, + registers_segment, + registers_before.len() + i, + registers_after[i], + ); + } + sum } @@ -558,8 +602,9 @@ fn add_data_write, const D: usize>( builder.assert_zero(row[4 + j]); } - // timestamp = 1 - builder.assert_one(row[12]); + // timestamp = 2 + let two = builder.constant(F::TWO); + builder.connect(row[12], two); let combined = challenge.combine_base_circuit(builder, &row); let inverse = builder.inverse(combined); @@ -574,12 +619,26 @@ pub(crate) fn add_virtual_public_values, const D: u let block_metadata = add_virtual_block_metadata(builder); let block_hashes = add_virtual_block_hashes(builder); let extra_block_data = add_virtual_extra_block_data(builder); + let registers_before = add_virtual_registers_data(builder); + let registers_after = add_virtual_registers_data(builder); + + let mem_before = MemCapTarget { + mem_cap: MerkleCapTarget(builder.add_virtual_hashes_public_input(DEFAULT_CAP_LEN)), + }; + let mem_after = MemCapTarget { + mem_cap: MerkleCapTarget(builder.add_virtual_hashes_public_input(DEFAULT_CAP_LEN)), + }; + PublicValuesTarget { trie_roots_before, trie_roots_after, block_metadata, block_hashes, extra_block_data, + registers_before, + registers_after, + mem_before, + mem_after, } } @@ -656,6 +715,25 @@ pub(crate) fn add_virtual_extra_block_data, const D } } +pub(crate) fn add_virtual_registers_data, const D: usize>( + builder: &mut CircuitBuilder, +) -> RegistersDataTarget { + let program_counter = builder.add_virtual_public_input(); + let is_kernel = builder.add_virtual_public_input(); + let stack_len = builder.add_virtual_public_input(); + let stack_top = builder.add_virtual_public_input_arr(); + let context = builder.add_virtual_public_input(); + let gas_used = builder.add_virtual_public_input(); + RegistersDataTarget { + program_counter, + is_kernel, + stack_len, + stack_top, + context, + gas_used, + } +} + pub(crate) fn debug_public_values(public_values: &PublicValues) { log::debug!("Public Values:"); log::debug!( @@ -704,6 +782,27 @@ where &public_values_target.extra_block_data, &public_values.extra_block_data, )?; + set_registers_target( + witness, + &public_values_target.registers_before, + &public_values.registers_before, + )?; + set_registers_target( + witness, + &public_values_target.registers_after, + &public_values.registers_after, + )?; + + set_mem_cap_target( + witness, + &public_values_target.mem_before, + &public_values.mem_before, + )?; + set_mem_cap_target( + witness, + &public_values_target.mem_after, + &public_values.mem_after, + )?; Ok(()) } @@ -888,3 +987,42 @@ where Ok(()) } + +pub(crate) fn set_registers_target( + witness: &mut W, + rd_target: &RegistersDataTarget, + rd: &RegistersData, +) -> Result<(), ProgramError> +where + F: RichField + Extendable, + W: Witness, +{ + witness.set_target(rd_target.program_counter, u256_to_u32(rd.program_counter)?); + witness.set_target(rd_target.is_kernel, u256_to_u32(rd.is_kernel)?); + witness.set_target(rd_target.stack_len, u256_to_u32(rd.stack_len)?); + witness.set_target_arr(&rd_target.stack_top, &u256_limbs(rd.stack_top)); + witness.set_target(rd_target.context, u256_to_u32(rd.context)?); + witness.set_target(rd_target.gas_used, u256_to_u32(rd.gas_used)?); + + Ok(()) +} + +pub(crate) fn set_mem_cap_target( + witness: &mut W, + mc_target: &MemCapTarget, + mc: &MemCap, +) -> Result<(), ProgramError> +where + F: RichField + Extendable, + W: Witness, +{ + for i in 0..mc.mem_cap.len() { + witness.set_hash_target( + mc_target.mem_cap.0[i], + HashOut { + elements: mc.mem_cap[i].map(|elt| F::from_canonical_u64(elt.as_u64())), + }, + ); + } + Ok(()) +} diff --git a/evm_arithmetization/src/util.rs b/evm_arithmetization/src/util.rs index f8463258b..a9aae2d2b 100644 --- a/evm_arithmetization/src/util.rs +++ b/evm_arithmetization/src/util.rs @@ -236,6 +236,18 @@ pub(crate) fn get_h256(slice: &[F]) -> H256 { ) } +pub(crate) fn get_u256(slice: &[F; 8]) -> U256 { + U256( + (0..4) + .map(|i| { + slice[2 * i].to_canonical_u64() + (slice[2 * i + 1].to_noncanonical_u64() << 32) + }) + .collect::>() + .try_into() + .unwrap(), + ) +} + /// Standard Sha2 implementation. pub(crate) fn sha2(input: Vec) -> U256 { let mut hasher = Sha256::new(); diff --git a/evm_arithmetization/src/verifier.rs b/evm_arithmetization/src/verifier.rs index 90a287f1e..329289c39 100644 --- a/evm_arithmetization/src/verifier.rs +++ b/evm_arithmetization/src/verifier.rs @@ -1,9 +1,14 @@ -use anyhow::Result; +use anyhow::{ensure, Result}; use ethereum_types::{BigEndianHash, U256}; use itertools::Itertools; use plonky2::field::extension::Extendable; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::fri::oracle::PolynomialBatch; use plonky2::hash::hash_types::RichField; -use plonky2::plonk::config::GenericConfig; +use plonky2::hash::merkle_tree::MerkleCap; +use plonky2::plonk::config::{GenericConfig, GenericHashOut}; +use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; use starky::config::StarkConfig; use starky::cross_table_lookup::{get_ctl_vars_from_proofs, verify_cross_table_lookups}; use starky::lookup::GrandProductChallenge; @@ -18,13 +23,105 @@ use crate::memory::VALUE_LIMBS; use crate::proof::{AllProof, AllProofChallenges, PublicValues}; use crate::util::h2u; -pub fn verify_proof, C: GenericConfig, const D: usize>( +pub(crate) fn initial_memory_merkle_cap< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + rate_bits: usize, + cap_height: usize, +) -> MerkleCap { + // At the start of a transaction proof, `MemBefore` only contains the kernel + // `Code` segment and the `ShiftTable`. + let mut trace = Vec::with_capacity((KERNEL.code.len() + 256).next_power_of_two()); + + // Push kernel code. + for (i, &byte) in KERNEL.code.iter().enumerate() { + let mut row = vec![F::ZERO; crate::memory_continuation::columns::NUM_COLUMNS]; + row[crate::memory_continuation::columns::FILTER] = F::ONE; + row[crate::memory_continuation::columns::ADDR_CONTEXT] = F::ZERO; + row[crate::memory_continuation::columns::ADDR_SEGMENT] = + F::from_canonical_usize(Segment::Code.unscale()); + row[crate::memory_continuation::columns::ADDR_VIRTUAL] = F::from_canonical_usize(i); + row[crate::memory_continuation::columns::value_limb(0)] = F::from_canonical_u8(byte); + trace.push(row); + } + let mut val = U256::one(); + // Push shift table. + for i in 0..256 { + let mut row = vec![F::ZERO; crate::memory_continuation::columns::NUM_COLUMNS]; + + row[crate::memory_continuation::columns::FILTER] = F::ONE; + row[crate::memory_continuation::columns::ADDR_CONTEXT] = F::ZERO; + row[crate::memory_continuation::columns::ADDR_SEGMENT] = + F::from_canonical_usize(Segment::ShiftTable.unscale()); + row[crate::memory_continuation::columns::ADDR_VIRTUAL] = F::from_canonical_usize(i); + for j in 0..crate::memory::VALUE_LIMBS { + row[j + 4] = F::from_canonical_u32((val >> (j * 32)).low_u32()); + } + trace.push(row); + val <<= 1; + } + + // Padding. + let num_rows = trace.len(); + let num_rows_padded = num_rows.next_power_of_two(); + trace.resize( + num_rows_padded, + vec![F::ZERO; crate::memory_continuation::columns::NUM_COLUMNS], + ); + + let cols = transpose(&trace); + let polys = cols + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect::>(); + + PolynomialBatch::::from_values( + polys, + rate_bits, + false, + cap_height, + &mut TimingTree::default(), + None, + ) + .merkle_tree + .cap +} + +fn verify_initial_memory< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + public_values: &PublicValues, + config: &StarkConfig, +) -> Result<()> { + for (hash1, hash2) in initial_memory_merkle_cap::( + config.fri_config.rate_bits, + config.fri_config.cap_height, + ) + .0 + .iter() + .zip(public_values.mem_before.mem_cap.iter()) + { + for (&limb1, limb2) in hash1.to_vec().iter().zip(hash2) { + ensure!( + limb1 == F::from_canonical_u64(limb2.as_u64()), + anyhow::Error::msg("Invalid initial MemBefore Merkle cap.") + ); + } + } + + Ok(()) +} + +fn verify_proof, C: GenericConfig, const D: usize>( all_stark: &AllStark, all_proof: AllProof, config: &StarkConfig, -) -> Result<()> -where -{ + is_initial: bool, +) -> Result<()> { let AllProofChallenges { stark_challenges, ctl_challenges, @@ -42,6 +139,8 @@ where keccak_sponge_stark, logic_stark, memory_stark, + mem_before_stark, + mem_after_stark, cross_table_lookups, } = all_stark; @@ -112,9 +211,30 @@ where &[], config, )?; + verify_stark_proof_with_challenges( + mem_before_stark, + &stark_proofs[Table::MemBefore as usize].proof, + &stark_challenges[Table::MemBefore as usize], + Some(&ctl_vars_per_table[Table::MemBefore as usize]), + &[], + config, + )?; + verify_stark_proof_with_challenges( + mem_after_stark, + &stark_proofs[Table::MemAfter as usize].proof, + &stark_challenges[Table::MemAfter as usize], + Some(&ctl_vars_per_table[Table::MemAfter as usize]), + &[], + config, + )?; let public_values = all_proof.public_values; + // Verify shift table and kernel code. + if is_initial { + verify_initial_memory::(&public_values, config)?; + } + // Extra sums to add to the looked last value. // Only necessary for the Memory values. let mut extra_looking_sums = vec![vec![F::ZERO; config.num_challenges]; NUM_TABLES]; @@ -268,6 +388,36 @@ where sum = add_data_write(challenge, block_hashes_segment, sum, index, val); } + let registers_segment = F::from_canonical_usize(Segment::RegistersStates.unscale()); + let registers_before = [ + public_values.registers_before.program_counter, + public_values.registers_before.is_kernel, + public_values.registers_before.stack_len, + public_values.registers_before.stack_top, + public_values.registers_before.context, + public_values.registers_before.gas_used, + ]; + for i in 0..registers_before.len() { + sum = add_data_write(challenge, registers_segment, sum, i, registers_before[i]); + } + let registers_after = [ + public_values.registers_after.program_counter, + public_values.registers_after.is_kernel, + public_values.registers_after.stack_len, + public_values.registers_after.stack_top, + public_values.registers_after.context, + public_values.registers_after.gas_used, + ]; + for i in 0..registers_before.len() { + sum = add_data_write( + challenge, + registers_segment, + sum, + registers_before.len() + i, + registers_after[i], + ); + } + sum } @@ -290,10 +440,35 @@ where for j in 0..VALUE_LIMBS { row[j + 4] = F::from_canonical_u32((val >> (j * 32)).low_u32()); } - row[12] = F::ONE; // timestamp + row[12] = F::TWO; // timestamp running_sum + challenge.combine(row.iter()).inverse() } +/// A utility module designed to verify proofs. +pub mod testing { + use super::*; + + pub fn verify_all_proofs< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + all_stark: &AllStark, + all_proofs: &[AllProof], + config: &StarkConfig, + ) -> Result<()> { + assert!(!all_proofs.is_empty()); + + verify_proof(all_stark, all_proofs[0].clone(), config, true)?; + + for all_proof in &all_proofs[1..] { + verify_proof(all_stark, all_proof.clone(), config, false)?; + } + + Ok(()) + } +} + #[cfg(debug_assertions)] pub(crate) mod debug_utils { use super::*; @@ -425,6 +600,39 @@ pub(crate) mod debug_utils { extra_looking_rows.push(add_extra_looking_row(block_hashes_segment, index, val)); } + // Add registers writes. + let registers_segment = F::from_canonical_usize(Segment::RegistersStates.unscale()); + let registers_before = [ + public_values.registers_before.program_counter, + public_values.registers_before.is_kernel, + public_values.registers_before.stack_len, + public_values.registers_before.stack_top, + public_values.registers_before.context, + public_values.registers_before.gas_used, + ]; + for i in 0..registers_before.len() { + extra_looking_rows.push(add_extra_looking_row( + registers_segment, + i, + registers_before[i], + )); + } + let registers_after = [ + public_values.registers_after.program_counter, + public_values.registers_after.is_kernel, + public_values.registers_after.stack_len, + public_values.registers_after.stack_top, + public_values.registers_after.context, + public_values.registers_after.gas_used, + ]; + for i in 0..registers_before.len() { + extra_looking_rows.push(add_extra_looking_row( + registers_segment, + registers_before.len() + i, + registers_after[i], + )); + } + extra_looking_rows } @@ -441,7 +649,7 @@ pub(crate) mod debug_utils { for j in 0..VALUE_LIMBS { row[j + 4] = F::from_canonical_u32((val >> (j * 32)).low_u32()); } - row[12] = F::ONE; // timestamp + row[12] = F::TWO; // timestamp row } } diff --git a/evm_arithmetization/src/witness/memory.rs b/evm_arithmetization/src/witness/memory.rs index de1515f40..ff616a348 100644 --- a/evm_arithmetization/src/witness/memory.rs +++ b/evm_arithmetization/src/witness/memory.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; use ethereum_types::U256; +use serde::{Deserialize, Serialize}; +use serde_big_array::BigArray; use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS}; @@ -15,8 +17,8 @@ use MemoryChannel::{Code, GeneralPurpose, PartialChannel}; use super::operation::CONTEXT_SCALING_FACTOR; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::memory::segments::{Segment, SEGMENT_SCALING_FACTOR}; -use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, VirtTooLarge}; +use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES, SEGMENT_SCALING_FACTOR}; +use crate::witness::errors::MemoryError::SegmentTooLarge; use crate::witness::errors::ProgramError; use crate::witness::errors::ProgramError::MemoryError; @@ -50,38 +52,21 @@ impl MemoryAddress { } } - pub(crate) fn new_u256s( - context: U256, - segment: U256, - virt: U256, - ) -> Result { - if context.bits() > 32 { - return Err(MemoryError(ContextTooLarge { context })); - } - if segment >= Segment::COUNT.into() { - return Err(MemoryError(SegmentTooLarge { segment })); - } - if virt.bits() > 32 { - return Err(MemoryError(VirtTooLarge { virt })); - } - - // Calling `as_usize` here is safe as those have been checked above. - Ok(Self { - context: context.as_usize(), - segment: segment.as_usize(), - virt: virt.as_usize(), - }) - } - /// Creates a new `MemoryAddress` from a bundled address fitting a `U256`. /// It will recover the virtual offset as the lowest 32-bit limb, the /// segment as the next limb, and the context as the next one. pub(crate) fn new_bundle(addr: U256) -> Result { - let virt = addr.low_u32().into(); - let segment = (addr >> SEGMENT_SCALING_FACTOR).low_u32().into(); - let context = (addr >> CONTEXT_SCALING_FACTOR).low_u32().into(); + let virt = addr.low_u32() as usize; + let segment = (addr >> SEGMENT_SCALING_FACTOR).low_u32() as usize; + let context = (addr >> CONTEXT_SCALING_FACTOR).low_u32() as usize; + + if segment >= Segment::COUNT { + return Err(MemoryError(SegmentTooLarge { + segment: segment.into(), + })); + } - Self::new_u256s(context, segment, virt) + Ok(Self::new(context, Segment::all()[segment], virt)) } pub(crate) fn increment(&mut self) { @@ -126,7 +111,9 @@ impl MemoryOp { kind: MemoryOpKind, value: U256, ) -> Self { - let timestamp = clock * NUM_CHANNELS + channel.index(); + // Since the CPU clock starts at 1, and the `clock` value is the CPU length, the + // timestamps is: `timestamp = clock * NUM_CHANNELS + 1 + channel` + let timestamp = clock * NUM_CHANNELS + 1 + channel.index(); MemoryOp { filter: true, timestamp, @@ -160,7 +147,7 @@ impl MemoryOp { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct MemoryState { pub(crate) contexts: Vec, pub(crate) preinitialized_segments: HashMap, @@ -211,6 +198,24 @@ impl MemoryState { Some(val) } + /// Returns the memory values associated with a preinitialized segment. We + /// need a specific behaviour here, since the values can be stored either in + /// `preinitialized_segments` or in the memory itself. + pub(crate) fn get_preinit_memory(&self, segment: Segment) -> Vec> { + assert!(PREINITIALIZED_SEGMENTS_INDICES.contains(&segment.unscale())); + let len = self + .preinitialized_segments + .get(&segment) + .unwrap_or(&MemorySegmentState { content: vec![] }) + .content + .len() + .max(self.contexts[0].segments[segment.unscale()].content.len()); + + (0..len) + .map(|i| Some(self.get_with_init(MemoryAddress::new(0, segment, i)))) + .collect::>() + } + /// Returns a memory value, or 0 if the memory is unset. If we have some /// preinitialized segments (in interpreter mode), then the values might not /// be stored in memory yet. If the value in memory is not set and the @@ -232,7 +237,7 @@ impl MemoryState { .len() { self.preinitialized_segments.get(&segment).unwrap().content[offset] - .expect("We checked that the offset is not out of bounds.") + .unwrap_or_default() } else { 0.into() } @@ -269,6 +274,7 @@ impl MemoryState { segment: Segment, values: MemorySegmentState, ) { + assert!(PREINITIALIZED_SEGMENTS_INDICES.contains(&segment.unscale())); self.preinitialized_segments.insert(segment, values); } @@ -293,8 +299,9 @@ impl Default for MemoryState { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub(crate) struct MemoryContextState { + #[serde(with = "BigArray")] /// The content of each memory segment. pub(crate) segments: [MemorySegmentState; Segment::COUNT], } @@ -307,7 +314,7 @@ impl Default for MemoryContextState { } } -#[derive(Clone, Default, Debug)] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] pub(crate) struct MemorySegmentState { pub(crate) content: Vec>, } diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index ac884c74a..0076388ce 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -159,7 +159,6 @@ pub(crate) fn generate_keccak_general>( push_no_write(generation_state, hash.into_uint()); state.log_debug(format!("Hashing {:?}", input)); - keccak_sponge_log(state, base_address, input); state.push_memory(log_in1); @@ -288,6 +287,9 @@ pub(crate) fn generate_set_context>( // The popped value needs to be scaled down. let new_ctx = u256_to_usize(ctx >> CONTEXT_SCALING_FACTOR)?; + // Flag indicating whether the old context should be pruned. + let flag = ctx & 1.into(); + let sp_field = ContextMetadata::StackSize.unscale(); let old_sp_addr = MemoryAddress::new(old_ctx, Segment::ContextMetadata, sp_field); let new_sp_addr = MemoryAddress::new(new_ctx, Segment::ContextMetadata, sp_field); @@ -308,9 +310,6 @@ pub(crate) fn generate_set_context>( ); (sp_to_save, op) } else { - // Even though we might be in the interpreter, `Stack` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. mem_read_with_log(GeneralPurpose(2), new_sp_addr, generation_state) }; @@ -343,6 +342,11 @@ pub(crate) fn generate_set_context>( None }; + if flag == 1.into() { + row.general.context_pruning_mut().pruning_flag = F::ONE; + generation_state.stale_contexts.push(old_ctx); + } + generation_state.registers.context = new_ctx; generation_state.registers.stack_len = new_sp; if let Some(mem_op) = log_read_new_top { @@ -462,9 +466,6 @@ pub(crate) fn generate_dup>( (stack_top, op) } else { - // Even though we might be in the interpreter, `Stack` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. mem_read_gp_with_log_and_fill(2, other_addr, generation_state, &mut row) }; push_no_write(generation_state, val); @@ -492,9 +493,7 @@ pub(crate) fn generate_swap>( ); let [(in0, _)] = stack_pop_with_log_and_fill::<1, _>(generation_state, &mut row)?; - // Even though we might be in the interpreter, `Stack` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. + let (in1, log_in1) = mem_read_gp_with_log_and_fill(1, other_addr, generation_state, &mut row); let log_out0 = mem_write_gp_log_and_fill(2, other_addr, generation_state, &mut row, in0); push_no_write(generation_state, in1); @@ -561,13 +560,9 @@ fn append_shift>( const LOOKUP_CHANNEL: usize = 2; let lookup_addr = MemoryAddress::new(0, Segment::ShiftTable, input0.low_u32() as usize); let read_op = if input0.bits() <= 32 { - // Even though we might be in the interpreter, `ShiftTable` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. let (_, read) = mem_read_gp_with_log_and_fill(LOOKUP_CHANNEL, lookup_addr, generation_state, &mut row); Some(read) - // state.push_memory(read); } else { // The shift constraints still expect the address to be set, even though no read // will occur. @@ -663,9 +658,7 @@ pub(crate) fn generate_syscall>( virt: base_address.virt + i, ..base_address }; - // Even though we might be in the interpreter, `Code` is not part of the - // preinitialized segments, so we don't need to carry out the additional checks - // when get the value from memory. + let val = generation_state.memory.get_with_init(address); val.low_u32() as u8 }) @@ -946,7 +939,10 @@ pub(crate) fn generate_exception>( let gas = U256::from(generation_state.registers.gas_used); - let exc_info = U256::from(generation_state.registers.program_counter) + (gas << 192); + // `is_kernel_mode` is only necessary for the halting `exc_stop` exception. + let exc_info = U256::from(generation_state.registers.program_counter) + + (U256::from(generation_state.registers.is_kernel as u64) << 32) + + (gas << 192); // Get the opcode so we can provide it to the range_check operation. let code_context = generation_state.registers.code_context(); diff --git a/evm_arithmetization/src/witness/state.rs b/evm_arithmetization/src/witness/state.rs index 5edcf69bb..92af12b32 100644 --- a/evm_arithmetization/src/witness/state.rs +++ b/evm_arithmetization/src/witness/state.rs @@ -1,10 +1,11 @@ use ethereum_types::U256; +use serde::{Deserialize, Serialize}; use crate::cpu::kernel::aggregator::KERNEL; pub(crate) const KERNEL_CONTEXT: usize = 0; -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)] pub struct RegistersState { pub program_counter: usize, pub is_kernel: bool, @@ -20,6 +21,8 @@ pub struct RegistersState { } impl RegistersState { + /// Returns the KERNEL context in kernel mode, and the + /// current context otherwise. pub(crate) const fn code_context(&self) -> usize { if self.is_kernel { KERNEL_CONTEXT @@ -27,12 +30,21 @@ impl RegistersState { self.context } } + + /// Returns a `RegisterState` corresponding to the start + /// of a full transaction proof. + pub(crate) fn new() -> Self { + Self { + program_counter: KERNEL.global_labels["main"], + ..Self::default() + } + } } impl Default for RegistersState { fn default() -> Self { Self { - program_counter: KERNEL.global_labels["main"], + program_counter: KERNEL.global_labels["init"], is_kernel: true, stack_len: 0, stack_top: U256::zero(), diff --git a/evm_arithmetization/src/witness/traces.rs b/evm_arithmetization/src/witness/traces.rs index 7dd973b2a..fe59bba59 100644 --- a/evm_arithmetization/src/witness/traces.rs +++ b/evm_arithmetization/src/witness/traces.rs @@ -10,7 +10,9 @@ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::arithmetic::{BinaryOperator, Operation}; use crate::byte_packing::byte_packing_stark::BytePackingOp; use crate::cpu::columns::CpuColumnsView; +use crate::generation::MemBeforeValues; use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; +use crate::memory_continuation::memory_continuation_stark::mem_before_values_to_rows; use crate::witness::memory::MemoryOp; use crate::{arithmetic, keccak, keccak_sponge, logic}; @@ -120,6 +122,9 @@ impl Traces { pub(crate) fn into_tables( self, all_stark: &AllStark, + mem_before_values: &MemBeforeValues, + stale_contexts: Vec, + mut trace_lengths: TraceCheckpoint, config: &StarkConfig, timing: &mut TimingTree, ) -> [Vec>; NUM_TABLES] @@ -172,10 +177,38 @@ impl Traces { .logic_stark .generate_trace(logic_ops, cap_elements, timing) ); - let memory_trace = timed!( + let (memory_trace, final_values, unpadded_memory_length) = timed!( timing, "generate memory trace", - all_stark.memory_stark.generate_trace(memory_ops, timing) + all_stark.memory_stark.generate_trace( + memory_ops, + mem_before_values, + stale_contexts, + timing + ) + ); + trace_lengths.memory_len = unpadded_memory_length; + + let mem_before_trace = timed!( + timing, + "generate mem_before trace", + all_stark + .mem_before_stark + .generate_trace(mem_before_values_to_rows(mem_before_values)) + ); + let mem_after_trace = timed!( + timing, + "generate mem_after trace", + all_stark + .mem_after_stark + .generate_trace(final_values.clone()) + ); + + log::info!( + "Trace lengths (before padding): {:?}, mem_before_len: {}, mem_after_len: {}", + trace_lengths, + mem_before_values.len(), + final_values.len() ); [ @@ -186,6 +219,8 @@ impl Traces { keccak_sponge_trace, logic_trace, memory_trace, + mem_before_trace, + mem_after_trace, ] } } diff --git a/evm_arithmetization/src/witness/transition.rs b/evm_arithmetization/src/witness/transition.rs index 71d4632bb..ae5ae6d97 100644 --- a/evm_arithmetization/src/witness/transition.rs +++ b/evm_arithmetization/src/witness/transition.rs @@ -20,6 +20,8 @@ use crate::witness::state::RegistersState; use crate::witness::util::mem_read_code_with_log_and_fill; use crate::{arithmetic, logic}; +pub(crate) const EXC_STOP_CODE: u8 = 6; + pub(crate) fn read_code_memory>( state: &mut T, row: &mut CpuColumnsView, @@ -296,12 +298,27 @@ pub(crate) fn log_kernel_instruction>(state: &mut S, op: O assert!(pc < KERNEL.code.len(), "Kernel PC is out of range: {}", pc); } -pub(crate) trait Transition: State { +pub(crate) trait Transition: State +where + Self: Sized, +{ /// When in jumpdest analysis, adds the offset `dst` to the jumpdest table. /// Returns a boolean indicating whether we are running the jumpdest /// analysis. fn generate_jumpdest_analysis(&mut self, dst: usize) -> bool; + fn final_exception(&mut self) -> anyhow::Result<()> { + let checkpoint = self.checkpoint(); + + let (row, _) = self.base_row(); + + generate_exception(EXC_STOP_CODE, self, row) + .map_err(|e| anyhow::anyhow!("Exception handling failed with error {:?}", e))?; + + self.apply_ops(checkpoint); + Ok(()) + } + /// Performs the next operation in the execution, and updates the gas used /// and program counter. fn perform_state_op( diff --git a/evm_arithmetization/src/witness/util.rs b/evm_arithmetization/src/witness/util.rs index 090a13981..bca6f580c 100644 --- a/evm_arithmetization/src/witness/util.rs +++ b/evm_arithmetization/src/witness/util.rs @@ -305,7 +305,7 @@ pub(crate) fn keccak_sponge_log>( address.increment(); } xor_into_sponge::(state, &mut sponge_state, block.try_into().unwrap()); - state.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS); + state.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS + 1); keccakf_u8s(&mut sponge_state); } @@ -330,11 +330,11 @@ pub(crate) fn keccak_sponge_log>( final_block[KECCAK_RATE_BYTES - 1] = 0b10000000; } xor_into_sponge::(state, &mut sponge_state, &final_block); - state.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS); + state.push_keccak_bytes(sponge_state, clock * NUM_CHANNELS + 1); state.push_keccak_sponge(KeccakSpongeOp { base_address, - timestamp: clock * NUM_CHANNELS, + timestamp: clock * NUM_CHANNELS + 1, input, }); } @@ -366,7 +366,7 @@ pub(crate) fn byte_packing_log>( state.push_byte_packing(BytePackingOp { is_read: true, base_address, - timestamp: clock * NUM_CHANNELS, + timestamp: clock * NUM_CHANNELS + 1, bytes, }); } @@ -399,7 +399,7 @@ pub(crate) fn byte_unpacking_log>( state.push_byte_packing(BytePackingOp { is_read: false, base_address, - timestamp: clock * NUM_CHANNELS, + timestamp: clock * NUM_CHANNELS + 1, bytes, }); } diff --git a/evm_arithmetization/tests/add11_yml.rs b/evm_arithmetization/tests/add11_yml.rs index dca625d36..8de4e36ae 100644 --- a/evm_arithmetization/tests/add11_yml.rs +++ b/evm_arithmetization/tests/add11_yml.rs @@ -4,16 +4,17 @@ use std::time::Duration; use ethereum_types::{Address, BigEndianHash, H256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; -use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; +use evm_arithmetization::generation::TrieInputs; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; -use evm_arithmetization::{AllStark, Node, StarkConfig}; +use evm_arithmetization::verifier::testing::verify_all_proofs; +use evm_arithmetization::StarkConfig; +use evm_arithmetization::{AllStark, GenerationInputs, Node}; use hex_literal::hex; use keccak_hash::keccak; use mpt_trie::nibbles::Nibbles; @@ -26,14 +27,7 @@ type F = GoldilocksField; const D: usize = 2; type C = KeccakGoldilocksConfig; -/// The `add11_yml` test case from https://github.com/ethereum/tests -#[test] -fn add11_yml() -> anyhow::Result<()> { - init_logger(); - - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - +fn get_generation_inputs() -> GenerationInputs { let beneficiary = hex!("2adc25665018aa1fe0e6bc666dac8fc2697ff9ba"); let sender = hex!("a94f5374fce5edbc8e2a8697c15331677e6ebf0b"); let to = hex!("095e7baea6a6c7c4c2dfeb977efac326af552d87"); @@ -63,19 +57,26 @@ fn add11_yml() -> anyhow::Result<()> { ..AccountRlp::default() }; - let (mut state_trie_before, mut storage_tries) = preinitialized_state_and_storage_tries()?; + let (mut state_trie_before, mut storage_tries) = + preinitialized_state_and_storage_tries().unwrap(); let mut beacon_roots_account_storage = storage_tries[0].1.clone(); - state_trie_before.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_before).to_vec(), - )?; - state_trie_before.insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec())?; - state_trie_before.insert(to_nibbles, rlp::encode(&to_account_before).to_vec())?; + state_trie_before + .insert( + beneficiary_nibbles, + rlp::encode(&beneficiary_account_before).to_vec(), + ) + .unwrap(); + state_trie_before + .insert(sender_nibbles, rlp::encode(&sender_account_before).to_vec()) + .unwrap(); + state_trie_before + .insert(to_nibbles, rlp::encode(&to_account_before).to_vec()) + .unwrap(); storage_tries.push((to_hashed, Node::Empty.into())); let tries_before = TrieInputs { - state_trie: state_trie_before, + state_trie: state_trie_before.clone(), transactions_trie: Node::Empty.into(), receipts_trie: Node::Empty.into(), storage_tries, @@ -105,7 +106,8 @@ fn add11_yml() -> anyhow::Result<()> { &mut beacon_roots_account_storage, block_metadata.block_timestamp, block_metadata.parent_beacon_block_root, - )?; + ) + .unwrap(); let beacon_roots_account = beacon_roots_contract_from_storage(&beacon_roots_account_storage); @@ -131,22 +133,30 @@ fn add11_yml() -> anyhow::Result<()> { }; let mut expected_state_trie_after = HashedPartialTrie::from(Node::Empty); - expected_state_trie_after.insert( - beneficiary_nibbles, - rlp::encode(&beneficiary_account_after).to_vec(), - )?; expected_state_trie_after - .insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec())?; - expected_state_trie_after.insert(to_nibbles, rlp::encode(&to_account_after).to_vec())?; - expected_state_trie_after.insert( - beacon_roots_account_nibbles(), - rlp::encode(&beacon_roots_account).to_vec(), - )?; - expected_state_trie_after.insert( - ger_account_nibbles(), - rlp::encode(&GLOBAL_EXIT_ROOT_ACCOUNT).to_vec(), - )?; - + .insert( + beneficiary_nibbles, + rlp::encode(&beneficiary_account_after).to_vec(), + ) + .unwrap(); + expected_state_trie_after + .insert(sender_nibbles, rlp::encode(&sender_account_after).to_vec()) + .unwrap(); + expected_state_trie_after + .insert(to_nibbles, rlp::encode(&to_account_after).to_vec()) + .unwrap(); + expected_state_trie_after + .insert( + beacon_roots_account_nibbles(), + rlp::encode(&beacon_roots_account).to_vec(), + ) + .unwrap(); + expected_state_trie_after + .insert( + ger_account_nibbles(), + rlp::encode(&GLOBAL_EXIT_ROOT_ACCOUNT).to_vec(), + ) + .unwrap(); expected_state_trie_after }; @@ -157,10 +167,12 @@ fn add11_yml() -> anyhow::Result<()> { logs: vec![], }; let mut receipts_trie = HashedPartialTrie::from(Node::Empty); - receipts_trie.insert( - Nibbles::from_str("0x80").unwrap(), - rlp::encode(&receipt_0).to_vec(), - )?; + receipts_trie + .insert( + Nibbles::from_str("0x80").unwrap(), + rlp::encode(&receipt_0).to_vec(), + ) + .unwrap(); let transactions_trie: HashedPartialTrie = Node::Leaf { nibbles: Nibbles::from_str("0x80").unwrap(), value: txn.to_vec(), @@ -172,15 +184,16 @@ fn add11_yml() -> anyhow::Result<()> { transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; - let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + + GenerationInputs { + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, trie_roots_after, contract_code, block_metadata, - checkpoint_state_trie_root: HashedPartialTrie::from(Node::Empty).hash(), + checkpoint_state_trie_root: state_trie_before.hash(), txn_number_before: 0.into(), gas_used_before: 0.into(), gas_used_after: 0xa868u64.into(), @@ -188,11 +201,31 @@ fn add11_yml() -> anyhow::Result<()> { prev_hashes: vec![H256::default(); 256], cur_hash: H256::default(), }, - }; + } +} +/// The `add11_yml` test case from https://github.com/ethereum/tests +#[test] +fn add11_yml() -> anyhow::Result<()> { + init_logger(); + + let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); + let inputs = get_generation_inputs(); + + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } diff --git a/evm_arithmetization/tests/erc20.rs b/evm_arithmetization/tests/erc20.rs index 1c829efc1..13ef8ee21 100644 --- a/evm_arithmetization/tests/erc20.rs +++ b/evm_arithmetization/tests/erc20.rs @@ -5,13 +5,13 @@ use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp, LogRlp}; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, create_account_storage, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, sd2u, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use hex_literal::hex; use keccak_hash::keccak; @@ -178,8 +178,9 @@ fn test_erc20() -> anyhow::Result<()> { transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; + let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -196,11 +197,21 @@ fn test_erc20() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } fn giver_bytecode() -> Vec { diff --git a/evm_arithmetization/tests/erc721.rs b/evm_arithmetization/tests/erc721.rs index 3a02d8968..4cf347afc 100644 --- a/evm_arithmetization/tests/erc721.rs +++ b/evm_arithmetization/tests/erc721.rs @@ -5,13 +5,13 @@ use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp, LogRlp}; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, create_account_storage, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, sd2u, sh2u, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use hex_literal::hex; use keccak_hash::keccak; @@ -183,7 +183,7 @@ fn test_erc721() -> anyhow::Result<()> { }; let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -200,11 +200,21 @@ fn test_erc721() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } fn contract_bytecode() -> Vec { diff --git a/evm_arithmetization/tests/global_exit_root.rs b/evm_arithmetization/tests/global_exit_root.rs index 507ffe0f7..302b1a143 100644 --- a/evm_arithmetization/tests/global_exit_root.rs +++ b/evm_arithmetization/tests/global_exit_root.rs @@ -4,13 +4,13 @@ use std::time::Duration; use ethereum_types::{H256, U256}; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, ger_account_nibbles, ger_contract_from_storage, init_logger, preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, update_ger_account_storage, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use keccak_hash::keccak; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -77,7 +77,7 @@ fn test_global_exit_root() -> anyhow::Result<()> { }; let inputs = GenerationInputs { - signed_txn: None, + signed_txns: vec![], withdrawals: vec![], global_exit_roots, tries: TrieInputs { @@ -99,9 +99,18 @@ fn test_global_exit_root() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; + let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } diff --git a/evm_arithmetization/tests/log_opcode.rs b/evm_arithmetization/tests/log_opcode.rs index 8cd5c57c0..5ac537c4e 100644 --- a/evm_arithmetization/tests/log_opcode.rs +++ b/evm_arithmetization/tests/log_opcode.rs @@ -10,13 +10,13 @@ use evm_arithmetization::generation::mpt::transaction_testing::{ use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp, LogRlp}; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use hex_literal::hex; use keccak_hash::keccak; @@ -238,7 +238,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { }; let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -256,22 +256,21 @@ fn test_log_opcodes() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; - timing.filter(Duration::from_millis(100)).print(); - // Assert that the proof leads to the correct state and receipt roots. - assert_eq!( - proof.public_values.trie_roots_after.state_root, - expected_state_trie_after.hash() - ); + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; - assert_eq!( - proof.public_values.trie_roots_after.receipts_root, - receipts_trie.hash() - ); + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } /// Values taken from the block 1000000 of Goerli: https://goerli.etherscan.io/txs?block=1000000 diff --git a/evm_arithmetization/tests/selfdestruct.rs b/evm_arithmetization/tests/selfdestruct.rs index 708646e16..97f41b78d 100644 --- a/evm_arithmetization/tests/selfdestruct.rs +++ b/evm_arithmetization/tests/selfdestruct.rs @@ -5,13 +5,13 @@ use ethereum_types::{Address, BigEndianHash, H256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, eth_to_wei, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use hex_literal::hex; use keccak_hash::keccak; @@ -152,8 +152,9 @@ fn test_selfdestruct() -> anyhow::Result<()> { transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; + let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -170,9 +171,19 @@ fn test_selfdestruct() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } diff --git a/evm_arithmetization/tests/simple_transfer.rs b/evm_arithmetization/tests/simple_transfer.rs index 030b2c3e1..81cd62113 100644 --- a/evm_arithmetization/tests/simple_transfer.rs +++ b/evm_arithmetization/tests/simple_transfer.rs @@ -6,13 +6,13 @@ use ethereum_types::{Address, BigEndianHash, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, eth_to_wei, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use hex_literal::hex; use keccak_hash::keccak; @@ -144,8 +144,9 @@ fn test_simple_transfer() -> anyhow::Result<()> { transactions_root: transactions_trie.hash(), receipts_root: receipts_trie.hash(), }; + let inputs = GenerationInputs { - signed_txn: Some(txn.to_vec()), + signed_txns: vec![txn.to_vec()], withdrawals: vec![], global_exit_roots: vec![], tries: tries_before, @@ -162,9 +163,19 @@ fn test_simple_transfer() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } diff --git a/evm_arithmetization/tests/two_to_one_block.rs b/evm_arithmetization/tests/two_to_one_block.rs index 5d4922fb4..59f0c377d 100644 --- a/evm_arithmetization/tests/two_to_one_block.rs +++ b/evm_arithmetization/tests/two_to_one_block.rs @@ -113,23 +113,30 @@ fn get_test_block_proof( let dummy1 = dummy_payload(timestamp, false)?; let timing = &mut TimingTree::new(&format!("Blockproof {timestamp}"), log::Level::Info); - let (dummy_proof0, dummy_pv0) = - all_circuits.prove_root(all_stark, config, dummy0, timing, None)?; - all_circuits.verify_root(dummy_proof0.clone())?; - let (dummy_proof1, dummy_pv1) = - all_circuits.prove_root(all_stark, config, dummy1, timing, None)?; - all_circuits.verify_root(dummy_proof1.clone())?; - - let (agg_proof0, pv0) = all_circuits.prove_aggregation( + let dummy0_proof0 = + all_circuits.prove_all_segments(all_stark, config, dummy0, 20, timing, None)?; + let dummy1_proof = + all_circuits.prove_all_segments(all_stark, config, dummy1, 20, timing, None)?; + + let inputs0_proof = all_circuits.prove_segment_aggregation( + false, + &dummy0_proof0[0], + false, + &dummy0_proof0[1], + )?; + let dummy0_proof = + all_circuits.prove_segment_aggregation(false, &dummy1_proof[0], false, &dummy1_proof[1])?; + + let (agg_proof0, pv0) = all_circuits.prove_transaction_aggregation( false, - &dummy_proof0, - dummy_pv0, + &inputs0_proof.proof_with_pis, + inputs0_proof.public_values, false, - &dummy_proof1, - dummy_pv1, + &dummy0_proof.proof_with_pis, + dummy0_proof.public_values, )?; - all_circuits.verify_aggregation(&agg_proof0)?; + all_circuits.verify_txn_aggregation(&agg_proof0)?; // Test retrieved public values from the proof public inputs. let retrieved_public_values0 = PublicValues::from_public_inputs(&agg_proof0.public_inputs); @@ -146,7 +153,7 @@ fn get_test_block_proof( )?; let pv_block = PublicValues::from_public_inputs(&block_proof0.public_inputs); - assert_eq!(block_public_values, pv_block); + assert_eq!(block_public_values, pv_block.into()); Ok(block_proof0) } @@ -161,7 +168,17 @@ fn test_two_to_one_block_aggregation() -> anyhow::Result<()> { let config = StarkConfig::standard_fast_config(); let all_circuits = AllRecursiveCircuits::::new( &all_stark, - &[16..17, 8..10, 14..15, 14..15, 9..10, 12..13, 17..18], + &[ + 16..17, + 9..15, + 12..18, + 14..15, + 9..10, + 12..13, + 17..20, + 16..17, + 7..8, + ], &config, ); diff --git a/evm_arithmetization/tests/withdrawals.rs b/evm_arithmetization/tests/withdrawals.rs index e17b775b1..b1ad4c715 100644 --- a/evm_arithmetization/tests/withdrawals.rs +++ b/evm_arithmetization/tests/withdrawals.rs @@ -5,13 +5,13 @@ use ethereum_types::{H160, H256, U256}; use evm_arithmetization::generation::mpt::AccountRlp; use evm_arithmetization::generation::{GenerationInputs, TrieInputs}; use evm_arithmetization::proof::{BlockHashes, BlockMetadata, TrieRoots}; -use evm_arithmetization::prover::prove; +use evm_arithmetization::prover::testing::prove_all_segments; use evm_arithmetization::testing_utils::{ beacon_roots_account_nibbles, beacon_roots_contract_from_storage, ger_account_nibbles, init_logger, preinitialized_state_and_storage_tries, update_beacon_roots_account_storage, GLOBAL_EXIT_ROOT_ACCOUNT, }; -use evm_arithmetization::verifier::verify_proof; +use evm_arithmetization::verifier::testing::verify_all_proofs; use evm_arithmetization::{AllStark, Node, StarkConfig}; use keccak_hash::keccak; use mpt_trie::nibbles::Nibbles; @@ -85,7 +85,7 @@ fn test_withdrawals() -> anyhow::Result<()> { }; let inputs = GenerationInputs { - signed_txn: None, + signed_txns: vec![], withdrawals, global_exit_roots: vec![], tries: TrieInputs { @@ -107,9 +107,19 @@ fn test_withdrawals() -> anyhow::Result<()> { }, }; + let max_cpu_len_log = 20; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; + + let proofs = prove_all_segments::( + &all_stark, + &config, + inputs, + max_cpu_len_log, + &mut timing, + None, + )?; + timing.filter(Duration::from_millis(100)).print(); - verify_proof(&all_stark, proof, &config) + verify_all_proofs(&all_stark, &proofs, &config) } diff --git a/proof_gen/Cargo.toml b/proof_gen/Cargo.toml index 785fbe8b0..07ac0fb9a 100644 --- a/proof_gen/Cargo.toml +++ b/proof_gen/Cargo.toml @@ -14,6 +14,7 @@ log = { workspace = true } paste = { workspace = true } plonky2 = { workspace = true } serde = { workspace = true } +hashbrown = { workspace = true } # Local dependencies evm_arithmetization = { workspace = true } diff --git a/proof_gen/src/constants.rs b/proof_gen/src/constants.rs index 808f9f2b7..e0b84387d 100644 --- a/proof_gen/src/constants.rs +++ b/proof_gen/src/constants.rs @@ -16,3 +16,9 @@ pub(crate) const DEFAULT_KECCAK_SPONGE_RANGE: Range = 9..25; pub(crate) const DEFAULT_LOGIC_RANGE: Range = 12..28; /// Default range to be used for the `MemoryStark` table. pub(crate) const DEFAULT_MEMORY_RANGE: Range = 17..30; +// TODO: adapt the ranges once we have a better idea of the more common ranges +// for the next two STARKs. +/// Default range to be used for the `MemoryBeforeStark` table. +pub(crate) const DEFAULT_MEMORY_BEFORE_RANGE: Range = 8..20; +/// Default range to be used for the `MemoryAfterStark` table. +pub(crate) const DEFAULT_MEMORY_AFTER_RANGE: Range = 16..30; diff --git a/proof_gen/src/lib.rs b/proof_gen/src/lib.rs index 2599f6360..c96491af4 100644 --- a/proof_gen/src/lib.rs +++ b/proof_gen/src/lib.rs @@ -44,10 +44,10 @@ //! This library handles the 3 kinds of proof generations necessary for the //! zkEVM: //! -//! ### Transaction proofs +//! ### Segment proofs //! //! From a `ProverState` and a transaction processed with some metadata in -//! Intermediate Representation, one can obtain a transaction proof by calling +//! Intermediate Representation, one can obtain a segment proof by calling //! the method below: //! //! ```compile_fail @@ -61,10 +61,10 @@ //! The obtained `GeneratedTxnProof` contains the actual proof and some //! additional data to be used when aggregating this transaction with others. //! -//! ### Aggregation proofs +//! ### Segment Aggregation proofs //! //! Two proofs can be aggregated together with a `ProverState`. These `child` -//! proofs can either be transaction proofs, or aggregated proofs themselves. +//! proofs can either be segment proofs, or aggregated proofs themselves. //! This library abstracts their type behind an `AggregatableProof` enum. //! //! ```compile_fail @@ -75,9 +75,25 @@ //! ) -> ProofGenResult { ... } //! ``` //! +//! ### Transaction Aggregation proofs +//! +//! Given a `GeneratedAggProof` corresponding to the entire set of segment +//! proofs within one transaction proof, the prover can wrap it into a +//! `GeneratedBlockProof`. The prover can pass an optional previous transaction +//! proof as argument to the `generate_transaction_agg_proof` method, to combine +//! both statements into one. +//! +//! ```compile_fail +//! pub fn generate_transaction_agg_proof( +//! p_state: &ProverState, +//! prev_opt_parent_b_proof: Option<&GeneratedBlockProof>, +//! curr_block_agg_proof: &GeneratedAggProof, +//! ) -> ProofGenResult { ... } +//! ``` +//! //! ### Block proofs //! -//! Once the prover has obtained a `GeneratedAggProof` corresponding to the +//! Once the prover has obtained a `GeneratedBlockProof` corresponding to the //! entire set of transactions within a block, they can then wrap it into a //! final `GeneratedBlockProof`. The prover can pass an optional previous //! block proof as argument to the `generate_block_proof` method, to combine diff --git a/proof_gen/src/proof_gen.rs b/proof_gen/src/proof_gen.rs index e2a1f0daa..916bc4e52 100644 --- a/proof_gen/src/proof_gen.rs +++ b/proof_gen/src/proof_gen.rs @@ -3,21 +3,25 @@ use std::sync::{atomic::AtomicBool, Arc}; -use evm_arithmetization::{AllStark, GenerationInputs, StarkConfig}; +use evm_arithmetization::{ + fixed_recursive_verifier::ProverOutputData, generation::TrimmedGenerationInputs, + prover::GenerationSegmentData, AllStark, StarkConfig, +}; +use hashbrown::HashMap; use plonky2::{ gates::noop::NoopGate, - iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, util::timing::TimingTree, }; use crate::{ proof_types::{ - AggregatableBlockProof, AggregatableProof, GeneratedAggBlockProof, GeneratedAggProof, - GeneratedBlockProof, GeneratedTxnProof, + AggregatableBlockProof, BatchAggregatableProof, GeneratedAggBlockProof, + GeneratedBlockProof, GeneratedSegmentAggProof, GeneratedSegmentProof, GeneratedTxnAggProof, + SegmentAggregatableProof, }, prover_state::ProverState, - types::{Config, Field, PlonkyProofIntern, EXTENSION_DEGREE}, + types::{Field, PlonkyProofIntern, EXTENSION_DEGREE}, }; /// A type alias for `Result`. @@ -44,36 +48,86 @@ impl From for ProofGenError { } /// Generates a transaction proof from some IR data. -pub fn generate_txn_proof( +pub fn generate_segment_proof( p_state: &ProverState, - gen_inputs: GenerationInputs, + gen_inputs: TrimmedGenerationInputs, + segment_data: &mut GenerationSegmentData, abort_signal: Option>, -) -> ProofGenResult { - let (intern, p_vals) = p_state +) -> ProofGenResult { + let output_data = p_state .state - .prove_root( + .prove_segment( &AllStark::default(), &StarkConfig::standard_fast_config(), gen_inputs, + segment_data, &mut TimingTree::default(), abort_signal, ) .map_err(|err| err.to_string())?; - Ok(GeneratedTxnProof { p_vals, intern }) + let p_vals = output_data.public_values; + let intern = output_data.proof_with_pis; + Ok(GeneratedSegmentProof { p_vals, intern }) } /// Generates an aggregation proof from two child proofs. /// /// Note that the child proofs may be either transaction or aggregation proofs. -pub fn generate_agg_proof( +/// +/// If a transaction only contains a single segment, this function must still be +/// called to generate a `GeneratedSegmentAggProof`. In that case, you can set +/// `has_dummy` to `true`, and provide an arbitrary proof for the right child. +pub fn generate_segment_agg_proof( p_state: &ProverState, - lhs_child: &AggregatableProof, - rhs_child: &AggregatableProof, -) -> ProofGenResult { - let (intern, p_vals) = p_state + lhs_child: &SegmentAggregatableProof, + rhs_child: &SegmentAggregatableProof, + has_dummy: bool, +) -> ProofGenResult { + if has_dummy { + assert!( + !lhs_child.is_agg(), + "Cannot have a dummy segment with an aggregation." + ); + } + + let lhs_prover_output_data = ProverOutputData { + is_dummy: false, + proof_with_pis: lhs_child.intern().clone(), + public_values: lhs_child.public_values(), + }; + let rhs_prover_output_data = ProverOutputData { + is_dummy: has_dummy, + proof_with_pis: rhs_child.intern().clone(), + public_values: rhs_child.public_values(), + }; + let agg_output_data = p_state .state - .prove_aggregation( + .prove_segment_aggregation( + lhs_child.is_agg(), + &lhs_prover_output_data, + rhs_child.is_agg(), + &rhs_prover_output_data, + ) + .map_err(|err| err.to_string())?; + + let p_vals = agg_output_data.public_values; + let intern = agg_output_data.proof_with_pis; + + Ok(GeneratedSegmentAggProof { p_vals, intern }) +} + +/// Generates a transaction aggregation proof from two child proofs. +/// +/// Note that the child proofs may be either transaction or aggregation proofs. +pub fn generate_transaction_agg_proof( + p_state: &ProverState, + lhs_child: &BatchAggregatableProof, + rhs_child: &BatchAggregatableProof, +) -> ProofGenResult { + let (b_proof_intern, p_vals) = p_state + .state + .prove_transaction_aggregation( lhs_child.is_agg(), lhs_child.intern(), lhs_child.public_values(), @@ -83,7 +137,10 @@ pub fn generate_agg_proof( ) .map_err(|err| err.to_string())?; - Ok(GeneratedAggProof { p_vals, intern }) + Ok(GeneratedTxnAggProof { + p_vals, + intern: b_proof_intern, + }) } /// Generates a block proof. @@ -93,7 +150,7 @@ pub fn generate_agg_proof( pub fn generate_block_proof( p_state: &ProverState, prev_opt_parent_b_proof: Option<&GeneratedBlockProof>, - curr_block_agg_proof: &GeneratedAggProof, + curr_block_agg_proof: &GeneratedTxnAggProof, ) -> ProofGenResult { let b_height = curr_block_agg_proof .p_vals @@ -145,13 +202,6 @@ pub fn dummy_proof() -> ProofGenResult { builder.add_gate(NoopGate, vec![]); let circuit_data = builder.build::<_>(); - let inputs = PartialWitness::new(); - - plonky2::plonk::prover::prove::( - &circuit_data.prover_only, - &circuit_data.common, - inputs, - &mut TimingTree::default(), - ) - .map_err(|e| ProofGenError(e.to_string())) + plonky2::recursion::dummy_circuit::dummy_proof(&circuit_data, HashMap::default()) + .map_err(|e| ProofGenError(e.to_string())) } diff --git a/proof_gen/src/proof_types.rs b/proof_gen/src/proof_types.rs index 037daa1c1..1c1c51edd 100644 --- a/proof_gen/src/proof_types.rs +++ b/proof_gen/src/proof_types.rs @@ -14,26 +14,39 @@ use crate::types::{Hash, Hasher, PlonkyProofIntern}; /// A transaction proof along with its public values, for proper connection with /// contiguous proofs. #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct GeneratedTxnProof { +pub struct GeneratedSegmentProof { /// Public values of this transaction proof. pub p_vals: PublicValues, /// Underlying plonky2 proof. pub intern: PlonkyProofIntern, } -/// An aggregation proof along with its public values, for proper connection -/// with contiguous proofs. +/// A segment aggregation proof along with its public values, for proper +/// connection with contiguous proofs. /// /// Aggregation proofs can represent any contiguous range of two or more -/// transactions, up to an entire block. +/// segments, up to an entire transaction. #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct GeneratedAggProof { +pub struct GeneratedSegmentAggProof { /// Public values of this aggregation proof. pub p_vals: PublicValues, /// Underlying plonky2 proof. pub intern: PlonkyProofIntern, } +/// A transaction aggregation proof along with its public values, for proper +/// connection with contiguous proofs. +/// +/// Transaction agregation proofs can represent any contiguous range of two or +/// more transactions, up to an entire block. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct GeneratedTxnAggProof { + /// Public values of this transaction aggregation proof. + pub p_vals: PublicValues, + /// Underlying plonky2 proof. + pub intern: PlonkyProofIntern, +} + /// A block proof along with the block height against which this proof ensures /// the validity since the last proof checkpoint. #[derive(Clone, Debug, Deserialize, Serialize)] @@ -59,48 +72,109 @@ pub struct GeneratedAggBlockProof { /// we can combine it into an agg proof. For these cases, we want to abstract /// away whether or not the proof was a txn or agg proof. #[derive(Clone, Debug, Deserialize, Serialize)] -pub enum AggregatableProof { +pub enum SegmentAggregatableProof { + /// The underlying proof is a segment proof. + Seg(GeneratedSegmentProof), + /// The underlying proof is an aggregation proof. + Agg(GeneratedSegmentAggProof), +} + +/// Sometimes we don't care about the underlying proof type and instead only if +/// we can combine it into an agg proof. For these cases, we want to abstract +/// away whether or not the proof was a txn or agg proof. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub enum BatchAggregatableProof { + /// The underlying proof is a segment proof. It first needs to be aggregated + /// with another segment proof, or a dummy one. + Segment(GeneratedSegmentProof), /// The underlying proof is a transaction proof. - Txn(GeneratedTxnProof), + Txn(GeneratedSegmentAggProof), /// The underlying proof is an aggregation proof. - Agg(GeneratedAggProof), + Agg(GeneratedTxnAggProof), } -impl AggregatableProof { +impl SegmentAggregatableProof { pub(crate) fn public_values(&self) -> PublicValues { match self { - AggregatableProof::Txn(info) => info.p_vals.clone(), - AggregatableProof::Agg(info) => info.p_vals.clone(), + SegmentAggregatableProof::Seg(info) => info.p_vals.clone(), + SegmentAggregatableProof::Agg(info) => info.p_vals.clone(), } } pub(crate) const fn is_agg(&self) -> bool { match self { - AggregatableProof::Txn(_) => false, - AggregatableProof::Agg(_) => true, + SegmentAggregatableProof::Seg(_) => false, + SegmentAggregatableProof::Agg(_) => true, } } pub(crate) const fn intern(&self) -> &PlonkyProofIntern { match self { - AggregatableProof::Txn(info) => &info.intern, - AggregatableProof::Agg(info) => &info.intern, + SegmentAggregatableProof::Seg(info) => &info.intern, + SegmentAggregatableProof::Agg(info) => &info.intern, } } } -impl From for AggregatableProof { - fn from(v: GeneratedTxnProof) -> Self { +impl BatchAggregatableProof { + pub(crate) fn public_values(&self) -> PublicValues { + match self { + BatchAggregatableProof::Segment(info) => info.p_vals.clone(), + BatchAggregatableProof::Txn(info) => info.p_vals.clone(), + BatchAggregatableProof::Agg(info) => info.p_vals.clone(), + } + } + + pub(crate) fn is_agg(&self) -> bool { + match self { + BatchAggregatableProof::Segment(_) => false, + BatchAggregatableProof::Txn(_) => false, + BatchAggregatableProof::Agg(_) => true, + } + } + + pub(crate) fn intern(&self) -> &PlonkyProofIntern { + match self { + BatchAggregatableProof::Segment(info) => &info.intern, + BatchAggregatableProof::Txn(info) => &info.intern, + BatchAggregatableProof::Agg(info) => &info.intern, + } + } +} + +impl From for SegmentAggregatableProof { + fn from(v: GeneratedSegmentProof) -> Self { + Self::Seg(v) + } +} + +impl From for SegmentAggregatableProof { + fn from(v: GeneratedSegmentAggProof) -> Self { + Self::Agg(v) + } +} + +impl From for BatchAggregatableProof { + fn from(v: GeneratedSegmentAggProof) -> Self { Self::Txn(v) } } -impl From for AggregatableProof { - fn from(v: GeneratedAggProof) -> Self { +impl From for BatchAggregatableProof { + fn from(v: GeneratedTxnAggProof) -> Self { Self::Agg(v) } } +impl From for BatchAggregatableProof { + fn from(v: SegmentAggregatableProof) -> Self { + match v { + SegmentAggregatableProof::Agg(agg) => BatchAggregatableProof::Txn(agg), + SegmentAggregatableProof::Seg(seg) => BatchAggregatableProof::Segment(seg), + } + } +} + #[derive(Clone, Debug, Deserialize, Serialize)] pub enum AggregatableBlockProof { /// The underlying proof is a single block proof. diff --git a/proof_gen/src/prover_state.rs b/proof_gen/src/prover_state.rs index 338cb52c7..bb3e5656f 100644 --- a/proof_gen/src/prover_state.rs +++ b/proof_gen/src/prover_state.rs @@ -29,6 +29,8 @@ pub struct ProverStateBuilder { pub(crate) keccak_sponge_circuit_size: Range, pub(crate) logic_circuit_size: Range, pub(crate) memory_circuit_size: Range, + pub(crate) memory_before_circuit_size: Range, + pub(crate) memory_after_circuit_size: Range, } impl Default for ProverStateBuilder { @@ -48,6 +50,8 @@ impl Default for ProverStateBuilder { keccak_sponge_circuit_size: DEFAULT_KECCAK_SPONGE_RANGE, logic_circuit_size: DEFAULT_LOGIC_RANGE, memory_circuit_size: DEFAULT_MEMORY_RANGE, + memory_before_circuit_size: DEFAULT_MEMORY_BEFORE_RANGE, + memory_after_circuit_size: DEFAULT_MEMORY_AFTER_RANGE, } } } @@ -73,6 +77,8 @@ impl ProverStateBuilder { define_set_circuit_size_method!(keccak_sponge); define_set_circuit_size_method!(logic); define_set_circuit_size_method!(memory); + define_set_circuit_size_method!(memory_before); + define_set_circuit_size_method!(memory_after); // TODO: Consider adding async version? /// Instantiate the prover state from the builder. Note that this is a very @@ -90,6 +96,8 @@ impl ProverStateBuilder { self.keccak_sponge_circuit_size, self.logic_circuit_size, self.memory_circuit_size, + self.memory_before_circuit_size, + self.memory_after_circuit_size, ], &StarkConfig::standard_fast_config(), ); diff --git a/trace_decoder/benches/block_processing.rs b/trace_decoder/benches/block_processing.rs index d2862d11f..6f3319d94 100644 --- a/trace_decoder/benches/block_processing.rs +++ b/trace_decoder/benches/block_processing.rs @@ -19,16 +19,28 @@ fn criterion_benchmark(c: &mut Criterion) { serde_json::from_slice::(include_bytes!("block_input.json").as_slice()) .unwrap(); - c.bench_function("Block 19778575 processing", |b| { - b.iter_batched( - || prover_input.clone(), - |ProverInput { - block_trace, - other_data, - }| { trace_decoder::entrypoint(block_trace, other_data).unwrap() }, - BatchSize::LargeInput, - ) - }); + let batch_sizes = vec![1, 2, 4, 8]; + + let mut group = c.benchmark_group("Benchmark group"); + + for batch_size in batch_sizes { + let batch_size_string = + format!("Block 19240650 processing, with batch_size = {batch_size}"); + group.bench_function(batch_size_string, |b| { + b.iter_batched( + || prover_input.clone(), + |ProverInput { + block_trace, + other_data, + }| { + trace_decoder::entrypoint(block_trace, other_data, batch_size).unwrap() + }, + BatchSize::LargeInput, + ) + }); + } + + group.finish() } criterion_group!( diff --git a/trace_decoder/src/decoding.rs b/trace_decoder/src/decoding.rs index aa755fadf..ce669a581 100644 --- a/trace_decoder/src/decoding.rs +++ b/trace_decoder/src/decoding.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{cmp::min, collections::HashMap, ops::Range}; use anyhow::{anyhow, Context as _}; use ethereum_types::{Address, BigEndianHash, H256, U256, U512}; @@ -52,6 +52,7 @@ pub fn into_txn_proof_gen_ir( withdrawals, }: ProcessedBlockTrace, other_data: OtherBlockData, + batch_size: usize, ) -> anyhow::Result> { let mut curr_block_tries = PartialTrieState { state: state.clone(), @@ -67,29 +68,31 @@ pub fn into_txn_proof_gen_ir( gas_used_after: U256::zero(), }; - // Dummy payloads do not increment this accumulator. - // For actual transactions, it will match their position in the block. - let mut txn_idx = 0; + let num_txs = txn_info + .iter() + .map(|tx_info| tx_info.meta.len()) + .sum::(); let mut txn_gen_inputs = txn_info .into_iter() - .map(|txn_info| { - let is_initial_payload = txn_idx == 0; - - let current_idx = txn_idx; - if !txn_info.meta.is_dummy() { - txn_idx += 1; - } + .enumerate() + .map(|(txn_idx, txn_info)| { + let txn_range = + min(txn_idx * batch_size, num_txs)..min(txn_idx * batch_size + batch_size, num_txs); + let is_initial_payload = txn_range.start == 0; process_txn_info( - current_idx, + txn_range.clone(), is_initial_payload, txn_info, &mut curr_block_tries, &mut extra_data, &other_data, ) - .context(format!("at transaction index {}", current_idx)) + .context(format!( + "at transaction range {}..{}", + txn_range.start, txn_range.end + )) }) .collect::>>() .context(format!( @@ -130,7 +133,7 @@ fn update_beacon_block_root_contract_storage( .get_mut(&ADDRESS) .context(format!("missing account storage trie {:x}", ADDRESS))?; - let mut slots_nibbles = vec![]; + let slots_nibbles = nodes_used.storage_accesses.entry(ADDRESS).or_default(); for (ix, val) in [(timestamp_idx, timestamp), (root_idx, calldata)] { // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/275 @@ -173,8 +176,6 @@ fn update_beacon_block_root_contract_storage( } } - nodes_used.storage_accesses.push((ADDRESS, slots_nibbles)); - let addr_nibbles = TrieKey::from_hash(ADDRESS); delta_out .additional_state_trie_paths_to_not_hash @@ -238,7 +239,7 @@ fn init_any_needed_empty_storage_tries<'a>( fn create_minimal_partial_tries_needed_by_txn( curr_block_tries: &PartialTrieState, nodes_used_by_txn: &NodesUsedByTxn, - txn_idx: usize, + txn_range: Range, delta_application_out: TrieDeltaApplicationOutput, ) -> anyhow::Result { let state_trie = create_minimal_state_partial_trie( @@ -251,17 +252,17 @@ fn create_minimal_partial_tries_needed_by_txn( .as_hashed_partial_trie() .clone(); - let txn_k = TrieKey::from_txn_ix(txn_idx); + let txn_keys = txn_range.map(TrieKey::from_txn_ix); let transactions_trie = create_trie_subset_wrapped( curr_block_tries.txn.as_hashed_partial_trie(), - [txn_k], + txn_keys.clone(), TrieType::Txn, )?; let receipts_trie = create_trie_subset_wrapped( curr_block_tries.receipt.as_hashed_partial_trie(), - [txn_k], + txn_keys, TrieType::Receipt, )?; @@ -282,11 +283,11 @@ fn create_minimal_partial_tries_needed_by_txn( fn apply_deltas_to_trie_state( trie_state: &mut PartialTrieState, deltas: &NodesUsedByTxn, - meta: &TxnMetaState, + meta: &[TxnMetaState], ) -> anyhow::Result { let mut out = TrieDeltaApplicationOutput::default(); - for (hashed_acc_addr, storage_writes) in &deltas.storage_writes { + for (hashed_acc_addr, storage_writes) in deltas.storage_writes.iter() { let storage_trie = trie_state .storage .get_mut(hashed_acc_addr) @@ -336,22 +337,39 @@ fn apply_deltas_to_trie_state( &trie_state.storage, )?; + trie_state.state.insert_by_key(val_k, account)?; + if is_created { // If the account did not exist prior this transaction, we // need to make sure the transaction didn't revert. - // Check status in the receipt. - let (_, _, receipt) = decode_receipt(&meta.receipt_node_bytes) + // We will check the status of the last receipt that attempted to create the + // account in this batch. + let last_creation_receipt = &meta + .iter() + .rev() + .find(|tx| tx.created_accounts.contains(hashed_acc_addr)) + .expect("We should have found a matching transaction") + .receipt_node_bytes; + + let (_, _, receipt) = decode_receipt(last_creation_receipt) .map_err(|_| anyhow!("couldn't RLP-decode receipt node bytes"))?; if !receipt.status { // The transaction failed, hence any created account should be removed. - trie_state.state.remove(val_k)?; - trie_state.storage.remove(hashed_acc_addr); - continue; + if let Some(remaining_account_key) = + delete_node_and_report_remaining_key_if_branch_collapsed( + trie_state.state.as_mut_hashed_partial_trie_unchecked(), + &val_k, + )? + { + out.additional_state_trie_paths_to_not_hash + .push(remaining_account_key); + trie_state.storage.remove(hashed_acc_addr); + continue; + } } } - trie_state.state.insert_by_key(val_k, account)?; } // Remove any accounts that self-destructed. @@ -440,7 +458,7 @@ fn add_withdrawals_to_txns( .last_mut() .expect("We cannot have an empty list of payloads."); - if last_inputs.signed_txn.is_none() { + if last_inputs.signed_txns.is_empty() { // This is a dummy payload, hence it does not contain yet // state accesses to the withdrawal addresses. let withdrawal_addrs = withdrawals_with_hashed_addrs_iter().map(|(_, h_addr, _)| h_addr); @@ -500,35 +518,42 @@ fn update_trie_state_from_withdrawals<'a>( /// Processes a single transaction in the trace. fn process_txn_info( - txn_idx: usize, + txn_range: Range, is_initial_payload: bool, txn_info: ProcessedTxnInfo, curr_block_tries: &mut PartialTrieState, extra_data: &mut ExtraBlockData, other_data: &OtherBlockData, ) -> anyhow::Result { - log::trace!("Generating proof IR for txn {}...", txn_idx); + log::trace!( + "Generating proof IR for txn {} through {}...", + txn_range.start, + txn_range.end - 1 + ); init_any_needed_empty_storage_tries( &mut curr_block_tries.storage, - txn_info - .nodes_used_by_txn - .storage_accesses - .iter() - .map(|(k, _)| k), + txn_info.nodes_used_by_txn.storage_accesses.keys(), &txn_info.nodes_used_by_txn.accts_with_unaccessed_storage, ); - // For each non-dummy txn, we increment `txn_number_after` by 1, and + + // For each non-dummy txn, we increment `txn_number_after` and // update `gas_used_after` accordingly. - extra_data.txn_number_after += U256::from(!txn_info.meta.is_dummy() as u8); - extra_data.gas_used_after += txn_info.meta.gas_used.into(); + extra_data.txn_number_after += txn_info.meta.len().into(); + extra_data.gas_used_after += txn_info.meta.iter().map(|i| i.gas_used).sum::().into(); // Because we need to run delta application before creating the minimal // sub-tries (we need to detect if deletes collapsed any branches), we need to // do this clone every iteration. let tries_at_start_of_txn = curr_block_tries.clone(); - update_txn_and_receipt_tries(curr_block_tries, &txn_info.meta, txn_idx)?; + for (i, meta) in txn_info.meta.iter().enumerate() { + update_txn_and_receipt_tries( + curr_block_tries, + meta, + extra_data.txn_number_before.as_usize() + i, + )?; + } let mut delta_out = apply_deltas_to_trie_state( curr_block_tries, @@ -553,7 +578,7 @@ fn process_txn_info( let tries = create_minimal_partial_tries_needed_by_txn( &tries_at_start_of_txn, &nodes_used_by_txn, - txn_idx, + txn_range, delta_out, )?; @@ -561,7 +586,11 @@ fn process_txn_info( txn_number_before: extra_data.txn_number_before, gas_used_before: extra_data.gas_used_before, gas_used_after: extra_data.gas_used_after, - signed_txn: txn_info.meta.txn_bytes, + signed_txns: txn_info + .meta + .iter() + .filter_map(|t| t.txn_bytes.clone()) + .collect::>(), withdrawals: Vec::default(), /* Only ever set in a dummy txn at the end of * the block (see `[add_withdrawals_to_txns]` * for more info). */ @@ -637,7 +666,7 @@ fn create_minimal_state_partial_trie( // trie somewhere else! This is a big hack! fn create_minimal_storage_partial_tries<'a>( storage_tries: &HashMap, - accesses_per_account: impl IntoIterator)>, + accesses_per_account: impl IntoIterator)>, additional_storage_trie_paths_to_not_hash: &HashMap>, ) -> anyhow::Result> { accesses_per_account diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index a71cd38ee..652d25ba4 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -284,6 +284,7 @@ pub struct BlockLevelData { pub fn entrypoint( trace: BlockTrace, other: OtherBlockData, + batch_size: usize, ) -> anyhow::Result> { use anyhow::Context as _; use mpt_trie::partial_trie::PartialTrie as _; @@ -400,10 +401,10 @@ pub fn entrypoint( ) .collect::(); - let last_tx_idx = txn_info.len().saturating_sub(1); + let last_tx_idx = txn_info.len().saturating_sub(1) / batch_size; let mut txn_info = txn_info - .into_iter() + .chunks(batch_size) .enumerate() .map(|(i, t)| { let extra_state_accesses = if last_tx_idx == i { @@ -419,7 +420,8 @@ pub fn entrypoint( Vec::new() }; - t.into_processed_txn_info( + TxnInfo::into_processed_txn_info( + t, &pre_images.tries, &all_accounts_in_pre_images, &extra_state_accesses, @@ -429,7 +431,7 @@ pub fn entrypoint( .collect::, _>>()?; while txn_info.len() < 2 { - txn_info.insert(0, ProcessedTxnInfo::default()); + txn_info.push(ProcessedTxnInfo::default()); } decoding::into_txn_proof_gen_ir( @@ -439,6 +441,7 @@ pub fn entrypoint( withdrawals: other.b_data.withdrawals.clone(), }, other, + batch_size, ) } diff --git a/trace_decoder/src/processed_block_trace.rs b/trace_decoder/src/processed_block_trace.rs index 5dcd9f109..6472b18b0 100644 --- a/trace_decoder/src/processed_block_trace.rs +++ b/trace_decoder/src/processed_block_trace.rs @@ -1,8 +1,9 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use anyhow::{bail, Context as _}; use ethereum_types::{Address, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; +use itertools::Itertools; use zk_evm_common::EMPTY_TRIE_HASH; use crate::typed_mpt::TrieKey; @@ -30,7 +31,7 @@ pub(crate) struct ProcessedBlockTracePreImages { pub(crate) struct ProcessedTxnInfo { pub nodes_used_by_txn: NodesUsedByTxn, pub contract_code_accessed: HashSet>, - pub meta: TxnMetaState, + pub meta: Vec, } /// Code hash mappings that we have constructed from parsing the block @@ -71,7 +72,7 @@ impl FromIterator> for Hash2Code { impl TxnInfo { pub(crate) fn into_processed_txn_info( - self, + tx_infos: &[Self], tries: &PartialTriePreImages, all_accounts_in_pre_image: &[(H256, AccountRlp)], extra_state_accesses: &[H256], @@ -79,122 +80,190 @@ impl TxnInfo { ) -> anyhow::Result { let mut nodes_used_by_txn = NodesUsedByTxn::default(); let mut contract_code_accessed = HashSet::from([vec![]]); // we always "access" empty code + let mut meta = Vec::with_capacity(tx_infos.len()); - for ( - addr, - TxnTrace { - balance, - nonce, - storage_read, - storage_written, - code_usage, - self_destructed, - }, - ) in self.traces - { - let hashed_addr = hash(addr.as_bytes()); - - // record storage changes - let storage_written = storage_written.unwrap_or_default(); - nodes_used_by_txn.storage_accesses.push(( - hashed_addr, - storage_read - .into_iter() - .flatten() - .chain(storage_written.keys().copied()) - .map(|H256(bytes)| TrieKey::from_hash(hash(bytes))) - .collect(), - )); - nodes_used_by_txn.storage_writes.push(( - hashed_addr, - storage_written - .iter() - .map(|(k, v)| (TrieKey::from_hash(*k), rlp::encode(v).to_vec())) - .collect(), - )); - - // record state changes - let state_write = StateWrite { - balance, - nonce, - storage_trie_change: !storage_written.is_empty(), - code_hash: code_usage.as_ref().map(|it| match it { - ContractCodeUsage::Read(hash) => *hash, - ContractCodeUsage::Write(bytes) => hash(bytes), - }), - }; - - if state_write != StateWrite::default() { - // a write occurred - nodes_used_by_txn - .state_writes - .push((hashed_addr, state_write)) - } + let all_accounts: BTreeSet = + all_accounts_in_pre_image.iter().map(|(h, _)| *h).collect(); + + for txn in tx_infos.iter() { + let mut created_accounts = BTreeSet::new(); - let is_precompile = (FIRST_PRECOMPILE_ADDRESS..LAST_PRECOMPILE_ADDRESS) - .contains(&U256::from_big_endian(&addr.0)); - - // Trie witnesses will only include accessed precompile accounts as hash - // nodes if the transaction calling them reverted. If this is the case, we - // shouldn't include them in this transaction's `state_accesses` to allow the - // decoder to build a minimal state trie without hitting any hash node. - if !is_precompile - || tries - .state - .get_by_key(TrieKey::from_hash(hashed_addr)) - .is_some() + for ( + addr, + TxnTrace { + balance, + nonce, + storage_read, + storage_written, + code_usage, + self_destructed, + }, + ) in txn.traces.iter() { - nodes_used_by_txn.state_accesses.push(hashed_addr); - } + let hashed_addr = hash(addr.as_bytes()); + + // record storage changes + let storage_written = storage_written.clone().unwrap_or_default(); + + let storage_read_keys = storage_read + .clone() + .into_iter() + .flat_map(|reads| reads.into_iter()); + + let storage_written_keys = storage_written.keys(); + let storage_access_keys = storage_read_keys.chain(storage_written_keys.copied()); - match code_usage { - Some(ContractCodeUsage::Read(hash)) => { - contract_code_accessed.insert(hash2code.get(hash)?); + if let Some(storage) = nodes_used_by_txn.storage_accesses.get_mut(&hashed_addr) { + storage.extend( + storage_access_keys + .map(|H256(bytes)| TrieKey::from_hash(hash(bytes))) + .collect_vec(), + ) + } else { + nodes_used_by_txn.storage_accesses.insert( + hashed_addr, + storage_access_keys + .map(|H256(bytes)| TrieKey::from_hash(hash(bytes))) + .collect(), + ); + }; + + // record state changes + let state_write = StateWrite { + balance: *balance, + nonce: *nonce, + storage_trie_change: !storage_written.is_empty(), + code_hash: code_usage.as_ref().map(|it| match it { + ContractCodeUsage::Read(hash) => *hash, + ContractCodeUsage::Write(bytes) => hash(bytes), + }), + }; + + if state_write != StateWrite::default() { + // a write occurred + + // Account creations are flagged to handle reverts. + if !all_accounts.contains(&hashed_addr) { + created_accounts.insert(hashed_addr); + } + + // Some edge case may see a contract creation followed by a `SELFDESTRUCT`, with + // then a follow-up transaction within the same batch updating the state of the + // account. If that happens, we should not delete the account after processing + // this batch. + nodes_used_by_txn + .self_destructed_accounts + .remove(&hashed_addr); + + if let Some(existing_state_write) = + nodes_used_by_txn.state_writes.get_mut(&hashed_addr) + { + // The entry already exists, so we update only the relevant fields. + if state_write.balance.is_some() { + existing_state_write.balance = state_write.balance; + } + if state_write.nonce.is_some() { + existing_state_write.nonce = state_write.nonce; + } + if state_write.storage_trie_change { + existing_state_write.storage_trie_change = + state_write.storage_trie_change; + } + if state_write.code_hash.is_some() { + existing_state_write.code_hash = state_write.code_hash; + } + } else { + nodes_used_by_txn + .state_writes + .insert(hashed_addr, state_write); + } } - Some(ContractCodeUsage::Write(code)) => { - contract_code_accessed.insert(code.clone()); - hash2code.insert(code); + + for (k, v) in storage_written.into_iter() { + if let Some(storage) = nodes_used_by_txn.storage_writes.get_mut(&hashed_addr) { + storage.insert(TrieKey::from_hash(k), rlp::encode(&v).to_vec()); + } else { + nodes_used_by_txn.storage_writes.insert( + hashed_addr, + HashMap::from_iter([(TrieKey::from_hash(k), rlp::encode(&v).to_vec())]), + ); + } + } + + let is_precompile = (FIRST_PRECOMPILE_ADDRESS..LAST_PRECOMPILE_ADDRESS) + .contains(&U256::from_big_endian(&addr.0)); + + // Trie witnesses will only include accessed precompile accounts as hash + // nodes if the transaction calling them reverted. If this is the case, we + // shouldn't include them in this transaction's `state_accesses` to allow the + // decoder to build a minimal state trie without hitting any hash node. + if !is_precompile + || tries + .state + .get_by_key(TrieKey::from_hash(hashed_addr)) + .is_some() + { + nodes_used_by_txn.state_accesses.insert(hashed_addr); + } + + match code_usage { + Some(ContractCodeUsage::Read(hash)) => { + contract_code_accessed.insert(hash2code.get(*hash)?); + } + Some(ContractCodeUsage::Write(code)) => { + contract_code_accessed.insert(code.clone()); + hash2code.insert(code.to_vec()); + } + None => {} + } + + if self_destructed.unwrap_or_default() { + nodes_used_by_txn + .self_destructed_accounts + .insert(hashed_addr); } - None => {} } - if self_destructed.unwrap_or_default() { - nodes_used_by_txn.self_destructed_accounts.push(hashed_addr); + for &hashed_addr in extra_state_accesses { + nodes_used_by_txn.state_accesses.insert(hashed_addr); } - } - for &hashed_addr in extra_state_accesses { - nodes_used_by_txn.state_accesses.push(hashed_addr); - } + let accounts_with_storage_accesses = nodes_used_by_txn + .storage_accesses + .iter() + .filter(|(_, slots)| !slots.is_empty()) + .map(|(addr, _)| *addr) + .collect::>(); - let accounts_with_storage_accesses = nodes_used_by_txn - .storage_accesses - .iter() - .filter(|(_, slots)| !slots.is_empty()) - .map(|(addr, _)| *addr) - .collect::>(); + let all_accounts_with_non_empty_storage = all_accounts_in_pre_image + .iter() + .filter(|(_, data)| data.storage_root != EMPTY_TRIE_HASH); - for (addr, state) in all_accounts_in_pre_image { - if state.storage_root != EMPTY_TRIE_HASH - && !accounts_with_storage_accesses.contains(addr) - { - nodes_used_by_txn - .accts_with_unaccessed_storage - .insert(*addr, state.storage_root); - } + let accounts_with_storage_but_no_storage_accesses = all_accounts_with_non_empty_storage + .filter(|&(addr, _data)| !accounts_with_storage_accesses.contains(addr)) + .map(|(addr, data)| (*addr, data.storage_root)); + + nodes_used_by_txn + .accts_with_unaccessed_storage + .extend(accounts_with_storage_but_no_storage_accesses); + + meta.push(TxnMetaState { + txn_bytes: match txn.meta.byte_code.is_empty() { + false => Some(txn.meta.byte_code.clone()), + true => None, + }, + receipt_node_bytes: check_receipt_bytes( + txn.meta.new_receipt_trie_node_byte.clone(), + )?, + gas_used: txn.meta.gas_used, + created_accounts, + }); } Ok(ProcessedTxnInfo { nodes_used_by_txn, contract_code_accessed, - meta: TxnMetaState { - txn_bytes: match self.meta.byte_code.is_empty() { - false => Some(self.meta.byte_code), - true => None, - }, - receipt_node_bytes: check_receipt_bytes(self.meta.new_receipt_trie_node_byte)?, - gas_used: self.meta.gas_used, - }, + meta, }) } } @@ -211,16 +280,15 @@ fn check_receipt_bytes(bytes: Vec) -> anyhow::Result> { /// Note that "*_accesses" includes writes. #[derive(Debug, Default)] pub(crate) struct NodesUsedByTxn { - pub state_accesses: Vec, - pub state_writes: Vec<(H256, StateWrite)>, + pub state_accesses: HashSet, + pub state_writes: HashMap, // Note: All entries in `storage_writes` also appear in `storage_accesses`. - pub storage_accesses: Vec<(H256, Vec)>, - #[allow(clippy::type_complexity)] - pub storage_writes: Vec<(H256, Vec<(TrieKey, Vec)>)>, + pub storage_accesses: HashMap>, + pub storage_writes: HashMap>>, /// Hashed address -> storage root. pub accts_with_unaccessed_storage: HashMap, - pub self_destructed_accounts: Vec, + pub self_destructed_accounts: HashSet, } #[derive(Debug, Default, PartialEq)] @@ -237,10 +305,5 @@ pub(crate) struct TxnMetaState { pub txn_bytes: Option>, pub receipt_node_bytes: Vec, pub gas_used: u64, -} - -impl TxnMetaState { - pub fn is_dummy(&self) -> bool { - self.txn_bytes.is_none() - } + pub created_accounts: BTreeSet, } diff --git a/trace_decoder/tests/trace_decoder_tests.rs b/trace_decoder/tests/trace_decoder_tests.rs index 5875bdf27..3d2ca96b9 100644 --- a/trace_decoder/tests/trace_decoder_tests.rs +++ b/trace_decoder/tests/trace_decoder_tests.rs @@ -9,7 +9,7 @@ use std::{ use alloy::rpc::types::eth::Header; use anyhow::Context as _; -use evm_arithmetization::prover::testing::simulate_execution; +use evm_arithmetization::prover::testing::simulate_execution_all_segments; use evm_arithmetization::GenerationInputs; use itertools::Itertools; use log::info; @@ -82,6 +82,7 @@ fn decode_generation_inputs( let trace_decoder_output = trace_decoder::entrypoint( block_prover_input.block_trace, block_prover_input.other_data.clone(), + 3, ) .context(format!( "Failed to execute trace decoder on block {}", @@ -189,13 +190,15 @@ fn test_parsing_decoding_proving(#[case] test_witness_directory: &str) { // with setting env variable RAYON_NUM_THREADS=. let timing = TimingTree::new( &format!( - "Simulating zkEVM CPU for block {} txn {:?}", + "simulate zkEVM CPU for block {}, txns {:?}..{:?}.", generation_inputs.block_metadata.block_number, + generation_inputs.txn_number_before, generation_inputs.txn_number_before + + generation_inputs.signed_txns.len() ), log::Level::Info, ); - simulate_execution::(generation_inputs)?; + simulate_execution_all_segments::(generation_inputs, 19)?; timing.filter(Duration::from_millis(100)).print(); Ok::<(), anyhow::Error>(()) }) diff --git a/zero_bin/common/Cargo.toml b/zero_bin/common/Cargo.toml index d5a317bf1..2ae22f2e3 100644 --- a/zero_bin/common/Cargo.toml +++ b/zero_bin/common/Cargo.toml @@ -12,6 +12,7 @@ categories.workspace = true directories = "5.0.1" thiserror = { workspace = true } +trace_decoder = { workspace = true } tracing = { workspace = true } proof_gen = { workspace = true } plonky2 = { workspace = true } diff --git a/zero_bin/common/src/prover_state/circuit.rs b/zero_bin/common/src/prover_state/circuit.rs index f23d6ebe9..94596c8c9 100644 --- a/zero_bin/common/src/prover_state/circuit.rs +++ b/zero_bin/common/src/prover_state/circuit.rs @@ -13,7 +13,7 @@ use crate::parsing::{parse_range_exclusive, RangeParseError}; /// Number of tables defined in plonky2. /// /// TODO: This should be made public in the evm_arithmetization crate. -pub(crate) const NUM_TABLES: usize = 7; +pub(crate) const NUM_TABLES: usize = 9; /// New type wrapper for [`Range`] that implements [`FromStr`] and [`Display`]. /// @@ -66,6 +66,8 @@ pub enum Circuit { KeccakSponge, Logic, Memory, + MemoryBefore, + MemoryAfter, } impl Display for Circuit { @@ -85,6 +87,8 @@ impl Circuit { Circuit::KeccakSponge => 9..15, Circuit::Logic => 12..18, Circuit::Memory => 17..28, + Circuit::MemoryBefore => 7..23, + Circuit::MemoryAfter => 7..27, } } @@ -98,6 +102,8 @@ impl Circuit { Circuit::KeccakSponge => "KECCAK_SPONGE_CIRCUIT_SIZE", Circuit::Logic => "LOGIC_CIRCUIT_SIZE", Circuit::Memory => "MEMORY_CIRCUIT_SIZE", + Circuit::MemoryBefore => "MEMORY_BEFORE_CIRCUIT_SIZE", + Circuit::MemoryAfter => "MEMORY_AFTER_CIRCUIT_SIZE", } } @@ -111,6 +117,8 @@ impl Circuit { Circuit::KeccakSponge => "keccak sponge", Circuit::Logic => "logic", Circuit::Memory => "memory", + Circuit::MemoryBefore => "memory before", + Circuit::MemoryAfter => "memory after", } } @@ -124,6 +132,8 @@ impl Circuit { Circuit::KeccakSponge => "ks", Circuit::Logic => "l", Circuit::Memory => "m", + Circuit::MemoryBefore => "m_b", + Circuit::MemoryAfter => "m_a", } } } @@ -138,6 +148,8 @@ impl From for Circuit { 4 => Circuit::KeccakSponge, 5 => Circuit::Logic, 6 => Circuit::Memory, + 7 => Circuit::MemoryBefore, + 8 => Circuit::MemoryAfter, _ => unreachable!(), } } @@ -175,6 +187,8 @@ impl Default for CircuitConfig { Circuit::KeccakSponge.default_size(), Circuit::Logic.default_size(), Circuit::Memory.default_size(), + Circuit::MemoryBefore.default_size(), + Circuit::MemoryAfter.default_size(), ], } } diff --git a/zero_bin/common/src/prover_state/cli.rs b/zero_bin/common/src/prover_state/cli.rs index 5355d7f4e..b0bebb331 100644 --- a/zero_bin/common/src/prover_state/cli.rs +++ b/zero_bin/common/src/prover_state/cli.rs @@ -84,7 +84,9 @@ gen_prover_state_config!( keccak: Circuit::Keccak, keccak_sponge: Circuit::KeccakSponge, logic: Circuit::Logic, - memory: Circuit::Memory + memory: Circuit::Memory, + mem_before: Circuit::MemoryBefore, + mem_after: Circuit::MemoryAfter ); impl CliProverStateConfig { @@ -99,6 +101,8 @@ impl CliProverStateConfig { (Circuit::KeccakSponge, self.keccak_sponge), (Circuit::Logic, self.logic), (Circuit::Memory, self.memory), + (Circuit::MemoryBefore, self.mem_before), + (Circuit::MemoryAfter, self.mem_after), ] .into_iter() .filter_map(|(circuit, range)| range.map(|range| (circuit, range))) diff --git a/zero_bin/common/src/prover_state/mod.rs b/zero_bin/common/src/prover_state/mod.rs index aacd7c12e..638ca20bb 100644 --- a/zero_bin/common/src/prover_state/mod.rs +++ b/zero_bin/common/src/prover_state/mod.rs @@ -15,13 +15,17 @@ use std::{fmt::Display, sync::OnceLock}; use clap::ValueEnum; use evm_arithmetization::{ - proof::AllProof, prover::prove, AllStark, GenerationInputs, StarkConfig, + fixed_recursive_verifier::ProverOutputData, + generation::TrimmedGenerationInputs, + proof::AllProof, + prover::{prove, GenerationSegmentData}, + AllStark, StarkConfig, }; use plonky2::{ field::goldilocks_field::GoldilocksField, plonk::config::PoseidonGoldilocksConfig, util::timing::TimingTree, }; -use proof_gen::{proof_types::GeneratedTxnProof, prover_state::ProverState, VerifierState}; +use proof_gen::{proof_types::GeneratedSegmentProof, prover_state::ProverState, VerifierState}; use tracing::info; use self::circuit::{CircuitConfig, NUM_TABLES}; @@ -182,42 +186,67 @@ impl ProverStateManager { circuit!(4), circuit!(5), circuit!(6), + circuit!(7), + circuit!(8), ]) } - /// Generate a transaction proof using the specified input, loading the - /// circuit tables as needed to shrink the individual STARK proofs, and - /// finally aggregating them to a final transaction proof. - fn txn_proof_on_demand(&self, input: GenerationInputs) -> anyhow::Result { + /// Generate a segment proof using the specified input, loading + /// the circuit tables as needed to shrink the individual STARK proofs, + /// and finally aggregating them to a final transaction proof. + fn segment_proof_on_demand( + &self, + input: TrimmedGenerationInputs, + segment_data: &mut GenerationSegmentData, + ) -> anyhow::Result { let config = StarkConfig::standard_fast_config(); let all_stark = AllStark::default(); - let all_proof = prove(&all_stark, &config, input, &mut TimingTree::default(), None)?; + + let all_proof = prove( + &all_stark, + &config, + input, + segment_data, + &mut TimingTree::default(), + None, + )?; let table_circuits = self.load_table_circuits(&config, &all_proof)?; let (intern, p_vals) = p_state() .state - .prove_root_after_initial_stark(all_proof, &table_circuits, None)?; + .prove_segment_after_initial_stark(all_proof, &table_circuits, None)?; - Ok(GeneratedTxnProof { intern, p_vals }) + Ok(GeneratedSegmentProof { p_vals, intern }) } - /// Generate a transaction proof using the specified input on the monolithic + /// Generate a segment proof using the specified input on the monolithic /// circuit. - fn txn_proof_monolithic(&self, input: GenerationInputs) -> anyhow::Result { - let (intern, p_vals) = p_state().state.prove_root( + fn segment_proof_monolithic( + &self, + input: TrimmedGenerationInputs, + segment_data: &mut GenerationSegmentData, + ) -> anyhow::Result { + let p_out = p_state().state.prove_segment( &AllStark::default(), &StarkConfig::standard_fast_config(), input, + segment_data, &mut TimingTree::default(), None, )?; - Ok(GeneratedTxnProof { p_vals, intern }) + let ProverOutputData { + is_dummy: _, + proof_with_pis: intern, + public_values: p_vals, + } = p_out; + + Ok(GeneratedSegmentProof { p_vals, intern }) } - /// Generate a transaction proof using the specified input. + /// Generate a segment proof using the specified input. /// /// The specific implementation depends on the persistence strategy. /// - If the persistence strategy is [`CircuitPersistence::None`] or @@ -226,15 +255,20 @@ impl ProverStateManager { /// - If the persistence strategy is [`CircuitPersistence::Disk`] with /// [`TableLoadStrategy::OnDemand`], the table circuits are loaded as /// needed. - pub fn generate_txn_proof(&self, input: GenerationInputs) -> anyhow::Result { + pub fn generate_segment_proof( + &self, + input: (TrimmedGenerationInputs, GenerationSegmentData), + ) -> anyhow::Result { + let (generation_inputs, mut segment_data) = input; + match self.persistence { CircuitPersistence::None | CircuitPersistence::Disk(TableLoadStrategy::Monolithic) => { info!("using monolithic circuit {:?}", self); - self.txn_proof_monolithic(input) + self.segment_proof_monolithic(generation_inputs, &mut segment_data) } CircuitPersistence::Disk(TableLoadStrategy::OnDemand) => { info!("using on demand circuit {:?}", self); - self.txn_proof_on_demand(input) + self.segment_proof_on_demand(generation_inputs, &mut segment_data) } } } diff --git a/zero_bin/leader/Cargo.toml b/zero_bin/leader/Cargo.toml index 5d0881da1..7f3655961 100644 --- a/zero_bin/leader/Cargo.toml +++ b/zero_bin/leader/Cargo.toml @@ -34,7 +34,6 @@ zero_bin_common = { workspace = true } [features] default = [] -test_only = ["ops/test_only", "prover/test_only"] [build-dependencies] cargo_metadata = { workspace = true } diff --git a/zero_bin/leader/src/cli.rs b/zero_bin/leader/src/cli.rs index eb3ea08f8..ccb09fd1f 100644 --- a/zero_bin/leader/src/cli.rs +++ b/zero_bin/leader/src/cli.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use alloy::transports::http::reqwest::Url; use clap::{Parser, Subcommand, ValueHint}; +use prover::cli::CliProverConfig; use rpc::RpcType; use zero_bin_common::prover_state::cli::CliProverStateConfig; @@ -14,6 +15,9 @@ pub(crate) struct Cli { #[clap(flatten)] pub(crate) paladin: paladin::config::Config, + #[clap(flatten)] + pub(crate) prover_config: CliProverConfig, + // Note this is only relevant for the leader when running in in-memory // mode. #[clap(flatten)] @@ -27,9 +31,6 @@ pub(crate) enum Command { /// The previous proof output. #[arg(long, short = 'f', value_hint = ValueHint::FilePath)] previous_proof: Option, - /// If true, save the public inputs to disk on error. - #[arg(short, long, default_value_t = false)] - save_inputs_on_error: bool, }, /// Reads input from a node rpc and writes output to stdout. Rpc { @@ -52,9 +53,6 @@ pub(crate) enum Command { /// stdout. #[arg(long, short = 'o', value_hint = ValueHint::FilePath)] proof_output_dir: Option, - /// If true, save the public inputs to disk on error. - #[arg(short, long, default_value_t = false)] - save_inputs_on_error: bool, /// Network block time in milliseconds. This value is used /// to determine the blockchain node polling interval. #[arg(short, long, env = "ZERO_BIN_BLOCK_TIME", default_value_t = 2000)] @@ -83,8 +81,5 @@ pub(crate) enum Command { /// The directory to which output should be written. #[arg(short, long, value_hint = ValueHint::DirPath)] output_dir: PathBuf, - /// If true, save the public inputs to disk on error. - #[arg(short, long, default_value_t = false)] - save_inputs_on_error: bool, }, } diff --git a/zero_bin/leader/src/client.rs b/zero_bin/leader/src/client.rs index 8fbcf1bd8..ecf8a969c 100644 --- a/zero_bin/leader/src/client.rs +++ b/zero_bin/leader/src/client.rs @@ -7,6 +7,7 @@ use alloy::transports::http::reqwest::Url; use anyhow::Result; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; +use prover::ProverConfig; use rpc::{retry::build_http_retry_provider, RpcType}; use tracing::{error, info, warn}; use zero_bin_common::block_interval::BlockInterval; @@ -20,12 +21,12 @@ pub struct RpcParams { pub max_retries: u32, } -#[derive(Debug, Default)] +#[derive(Debug)] pub struct ProofParams { pub checkpoint_block_number: u64, pub previous_proof: Option, pub proof_output_dir: Option, - pub save_inputs_on_error: bool, + pub prover_config: ProverConfig, pub keep_intermediate_proofs: bool, } @@ -78,54 +79,56 @@ pub(crate) async fn client_main( block_prover_inputs, &runtime, params.previous_proof.take(), - params.save_inputs_on_error, + params.prover_config, params.proof_output_dir.clone(), ) .await; runtime.close().await?; let proved_blocks = proved_blocks?; - if cfg!(feature = "test_only") { + if params.prover_config.test_only { info!("All proof witnesses have been generated successfully."); } else { info!("All proofs have been generated successfully."); } - if params.keep_intermediate_proofs { - if params.proof_output_dir.is_some() { - // All proof files (including intermediary) are written to disk and kept - warn!("Skipping cleanup, intermediate proof files are kept"); + if !params.prover_config.test_only { + if params.keep_intermediate_proofs { + if params.proof_output_dir.is_some() { + // All proof files (including intermediary) are written to disk and kept + warn!("Skipping cleanup, intermediate proof files are kept"); + } else { + // Output all proofs to stdout + std::io::stdout().write_all(&serde_json::to_vec( + &proved_blocks + .into_iter() + .filter_map(|(_, block)| block) + .collect::>(), + )?)?; + } + } else if let Some(proof_output_dir) = params.proof_output_dir.as_ref() { + // Remove intermediary proof files + proved_blocks + .into_iter() + .rev() + .skip(1) + .map(|(block_number, _)| { + generate_block_proof_file_name(&proof_output_dir.to_str(), block_number) + }) + .for_each(|path| { + if let Err(e) = std::fs::remove_file(path) { + error!("Failed to remove intermediate proof file: {e}"); + } + }); } else { - // Output all proofs to stdout - std::io::stdout().write_all(&serde_json::to_vec( - &proved_blocks - .into_iter() - .filter_map(|(_, block)| block) - .collect::>(), - )?)?; - } - } else if let Some(proof_output_dir) = params.proof_output_dir.as_ref() { - // Remove intermediary proof files - proved_blocks - .into_iter() - .rev() - .skip(1) - .map(|(block_number, _)| { - generate_block_proof_file_name(&proof_output_dir.to_str(), block_number) - }) - .for_each(|path| { - if let Err(e) = std::fs::remove_file(path) { - error!("Failed to remove intermediate proof file: {e}"); - } - }); - } else { - // Output only last proof to stdout - if let Some(last_block) = proved_blocks - .into_iter() - .filter_map(|(_, block)| block) - .last() - { - std::io::stdout().write_all(&serde_json::to_vec(&last_block)?)?; + // Output only last proof to stdout + if let Some(last_block) = proved_blocks + .into_iter() + .filter_map(|(_, block)| block) + .last() + { + std::io::stdout().write_all(&serde_json::to_vec(&last_block)?)?; + } } } diff --git a/zero_bin/leader/src/http.rs b/zero_bin/leader/src/http.rs index 63f9543f5..39c7333e1 100644 --- a/zero_bin/leader/src/http.rs +++ b/zero_bin/leader/src/http.rs @@ -5,7 +5,7 @@ use anyhow::{bail, Result}; use axum::{http::StatusCode, routing::post, Json, Router}; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; -use prover::BlockProverInput; +use prover::{BlockProverInput, ProverConfig}; use serde::{Deserialize, Serialize}; use serde_json::to_writer; use tracing::{debug, error, info}; @@ -15,7 +15,7 @@ pub(crate) async fn http_main( runtime: Runtime, port: u16, output_dir: PathBuf, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result<()> { let addr = SocketAddr::from(([0, 0, 0, 0], port)); debug!("listening on {}", addr); @@ -25,7 +25,7 @@ pub(crate) async fn http_main( "/prove", post({ let runtime = runtime.clone(); - move |body| prove(body, runtime, output_dir.clone(), save_inputs_on_error) + move |body| prove(body, runtime, output_dir.clone(), prover_config) }), ); let listener = tokio::net::TcpListener::bind(&addr).await?; @@ -65,21 +65,33 @@ async fn prove( Json(payload): Json, runtime: Arc, output_dir: PathBuf, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> StatusCode { debug!("Received payload: {:#?}", payload); let block_number = payload.prover_input.get_block_number(); - match payload - .prover_input - .prove( - &runtime, - payload.previous.map(futures::future::ok), - save_inputs_on_error, - ) - .await - { + let proof_res = if prover_config.test_only { + payload + .prover_input + .prove_test( + &runtime, + payload.previous.map(futures::future::ok), + prover_config, + ) + .await + } else { + payload + .prover_input + .prove( + &runtime, + payload.previous.map(futures::future::ok), + prover_config, + ) + .await + }; + + match proof_res { Ok(b_proof) => match write_to_file(output_dir, block_number, &b_proof) { Ok(file) => { info!("Successfully wrote proof to {}", file.display()); diff --git a/zero_bin/leader/src/main.rs b/zero_bin/leader/src/main.rs index eb250d199..32c9baa20 100644 --- a/zero_bin/leader/src/main.rs +++ b/zero_bin/leader/src/main.rs @@ -9,6 +9,7 @@ use dotenvy::dotenv; use ops::register; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; +use prover::ProverConfig; use tracing::{info, warn}; use zero_bin_common::{ block_interval::BlockInterval, prover_state::persistence::set_circuit_cache_dir_env_if_not_set, @@ -53,29 +54,27 @@ async fn main() -> Result<()> { } let args = cli::Cli::parse(); - if let paladin::config::Runtime::InMemory = args.paladin.runtime { - // If running in emulation mode, we'll need to initialize the prover - // state here. - args.prover_state_config - .into_prover_state_manager() - .initialize()?; + + let runtime = Runtime::from_config(&args.paladin, register()).await?; + + let prover_config: ProverConfig = args.prover_config.into(); + + // If not in test_only mode and running in emulation mode, we'll need to + // initialize the prover state here. + if !prover_config.test_only { + if let paladin::config::Runtime::InMemory = args.paladin.runtime { + args.prover_state_config + .into_prover_state_manager() + .initialize()?; + } } match args.command { - Command::Stdio { - previous_proof, - save_inputs_on_error, - } => { - let runtime = Runtime::from_config(&args.paladin, register()).await?; + Command::Stdio { previous_proof } => { let previous_proof = get_previous_proof(previous_proof)?; - stdio::stdio_main(runtime, previous_proof, save_inputs_on_error).await?; + stdio::stdio_main(runtime, previous_proof, prover_config).await?; } - Command::Http { - port, - output_dir, - save_inputs_on_error, - } => { - let runtime = Runtime::from_config(&args.paladin, register()).await?; + Command::Http { port, output_dir } => { // check if output_dir exists, is a directory, and is writable let output_dir_metadata = std::fs::metadata(&output_dir); if output_dir_metadata.is_err() { @@ -85,7 +84,7 @@ async fn main() -> Result<()> { panic!("output-dir is not a writable directory"); } - http::http_main(runtime, port, output_dir, save_inputs_on_error).await?; + http::http_main(runtime, port, output_dir, prover_config).await?; } Command::Rpc { rpc_url, @@ -94,7 +93,6 @@ async fn main() -> Result<()> { checkpoint_block_number, previous_proof, proof_output_dir, - save_inputs_on_error, block_time, keep_intermediate_proofs, backoff, @@ -126,7 +124,7 @@ async fn main() -> Result<()> { checkpoint_block_number, previous_proof, proof_output_dir, - save_inputs_on_error, + prover_config, keep_intermediate_proofs, }, ) diff --git a/zero_bin/leader/src/stdio.rs b/zero_bin/leader/src/stdio.rs index 403ea2a6a..88dd20aac 100644 --- a/zero_bin/leader/src/stdio.rs +++ b/zero_bin/leader/src/stdio.rs @@ -3,14 +3,14 @@ use std::io::{Read, Write}; use anyhow::Result; use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; -use prover::{BlockProverInput, BlockProverInputFuture}; +use prover::{BlockProverInput, BlockProverInputFuture, ProverConfig}; use tracing::info; /// The main function for the stdio mode. pub(crate) async fn stdio_main( runtime: Runtime, previous: Option, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result<()> { let mut buffer = String::new(); std::io::stdin().read_to_string(&mut buffer)?; @@ -21,18 +21,12 @@ pub(crate) async fn stdio_main( .map(Into::into) .collect::>(); - let proved_blocks = prover::prove( - block_prover_inputs, - &runtime, - previous, - save_inputs_on_error, - None, - ) - .await; + let proved_blocks = + prover::prove(block_prover_inputs, &runtime, previous, prover_config, None).await; runtime.close().await?; let proved_blocks = proved_blocks?; - if cfg!(feature = "test_only") { + if prover_config.test_only { info!("All proof witnesses have been generated successfully."); } else { info!("All proofs have been generated successfully."); diff --git a/zero_bin/ops/Cargo.toml b/zero_bin/ops/Cargo.toml index 05cfecb2e..8975cc8d5 100644 --- a/zero_bin/ops/Cargo.toml +++ b/zero_bin/ops/Cargo.toml @@ -14,10 +14,10 @@ serde = { workspace = true } evm_arithmetization = { workspace = true } proof_gen = { workspace = true } tracing = { workspace = true } +trace_decoder = { workspace = true } keccak-hash = { workspace = true } zero_bin_common = { path = "../common" } [features] default = [] -test_only = [] diff --git a/zero_bin/ops/src/lib.rs b/zero_bin/ops/src/lib.rs index ee369bbac..4de5e0c30 100644 --- a/zero_bin/ops/src/lib.rs +++ b/zero_bin/ops/src/lib.rs @@ -1,41 +1,53 @@ use std::time::Instant; -use evm_arithmetization::{proof::PublicValues, GenerationInputs}; -use keccak_hash::keccak; +use evm_arithmetization::generation::TrimmedGenerationInputs; +use evm_arithmetization::proof::PublicValues; +use evm_arithmetization::{prover::testing::simulate_execution_all_segments, GenerationInputs}; use paladin::{ operation::{FatalError, FatalStrategy, Monoid, Operation, Result}, registry, RemoteExecute, }; +use proof_gen::types::Field; use proof_gen::{ - proof_gen::{generate_agg_proof, generate_block_proof}, - proof_types::{AggregatableProof, GeneratedAggProof, GeneratedBlockProof}, + proof_gen::{generate_block_proof, generate_segment_agg_proof, generate_transaction_agg_proof}, + proof_types::{ + BatchAggregatableProof, GeneratedBlockProof, GeneratedTxnAggProof, SegmentAggregatableProof, + }, }; use serde::{Deserialize, Serialize}; -use tracing::{error, event, info_span, Level}; +use tracing::error; +use tracing::{event, info_span, Level}; use zero_bin_common::{debug_utils::save_inputs_to_disk, prover_state::p_state}; registry!(); #[derive(Deserialize, Serialize, RemoteExecute)] -pub struct TxProof { +pub struct SegmentProof { pub save_inputs_on_error: bool, } -#[cfg(not(feature = "test_only"))] -impl Operation for TxProof { - type Input = GenerationInputs; - type Output = proof_gen::proof_types::AggregatableProof; +impl Operation for SegmentProof { + type Input = evm_arithmetization::AllData; + type Output = proof_gen::proof_types::SegmentAggregatableProof; - fn execute(&self, input: Self::Input) -> Result { - let _span = TxProofSpan::new(&input); + fn execute(&self, all_data: Self::Input) -> Result { + let all_data = + all_data.map_err(|err| FatalError::from_str(&err.0, FatalStrategy::Terminate))?; + + let input = all_data.0.clone(); + let segment_index = all_data.1.segment_index(); + let _span = SegmentProofSpan::new(&input, all_data.1.segment_index()); let proof = if self.save_inputs_on_error { zero_bin_common::prover_state::p_manager() - .generate_txn_proof(input.clone()) + .generate_segment_proof(all_data) .map_err(|err| { if let Err(write_err) = save_inputs_to_disk( format!( - "b{}_txn_{}_input.json", - input.block_metadata.block_number, input.txn_number_before + "b{}_txns_{}..{}-({})_input.json", + input.block_metadata.block_number, + input.txn_number_before, + input.txn_number_before + input.txn_hashes.len(), + segment_index ), input, ) { @@ -46,7 +58,7 @@ impl Operation for TxProof { })? } else { zero_bin_common::prover_state::p_manager() - .generate_txn_proof(input) + .generate_segment_proof(all_data) .map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))? }; @@ -54,36 +66,35 @@ impl Operation for TxProof { } } -#[cfg(feature = "test_only")] -impl Operation for TxProof { - type Input = GenerationInputs; - type Output = (); +#[derive(Deserialize, Serialize, RemoteExecute)] +pub struct SegmentProofTestOnly { + pub save_inputs_on_error: bool, +} - fn execute(&self, input: Self::Input) -> Result { - let _span = TxProofSpan::new(&input); +impl Operation for SegmentProofTestOnly { + type Input = (GenerationInputs, usize); + type Output = (); + fn execute(&self, inputs: Self::Input) -> Result { if self.save_inputs_on_error { - evm_arithmetization::prover::testing::simulate_execution::( - input.clone(), - ) - .map_err(|err| { + simulate_execution_all_segments::(inputs.0.clone(), inputs.1).map_err(|err| { if let Err(write_err) = save_inputs_to_disk( format!( - "b{}_txn_{}_input.json", - input.block_metadata.block_number, input.txn_number_before + "b{}_txns_{}..{}_input.json", + inputs.0.block_metadata.block_number, + inputs.0.txn_number_before, + inputs.0.txn_number_before + inputs.0.signed_txns.len(), ), - input, + inputs.0, ) { error!("Failed to save txn proof input to disk: {:?}", write_err); } FatalError::from_anyhow(err, FatalStrategy::Terminate) - })?; + })? } else { - evm_arithmetization::prover::testing::simulate_execution::( - input, - ) - .map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?; + simulate_execution_all_segments::(inputs.0, inputs.1) + .map_err(|err| FatalError::from_anyhow(err, FatalStrategy::Terminate))?; } Ok(()) @@ -94,37 +105,60 @@ impl Operation for TxProof { /// /// - When created, it starts a span with the transaction proof id. /// - When dropped, it logs the time taken by the transaction proof. -struct TxProofSpan { +struct SegmentProofSpan { _span: tracing::span::EnteredSpan, start: Instant, descriptor: String, } -impl TxProofSpan { +impl SegmentProofSpan { /// Get a unique id for the transaction proof. - fn get_id(ir: &GenerationInputs) -> String { - format!( - "b{} - {}", - ir.block_metadata.block_number, ir.txn_number_before - ) + fn get_id(ir: &TrimmedGenerationInputs, segment_index: usize) -> String { + if ir.txn_hashes.len() == 1 { + format!( + "b{} - {} ({})", + ir.block_metadata.block_number, ir.txn_number_before, segment_index + ) + } else { + format!( + "b{} - {}_{} ({})", + ir.block_metadata.block_number, + ir.txn_number_before, + ir.txn_number_before + ir.txn_hashes.len(), + segment_index + ) + } } /// Get a textual descriptor for the transaction proof. /// - /// Either the hex-encoded hash of the transaction or "Dummy" if the - /// transaction is not present. - fn get_descriptor(ir: &GenerationInputs) -> String { - ir.signed_txn - .as_ref() - .map(|txn| format!("{:x}", keccak(txn))) - .unwrap_or_else(|| "Dummy".to_string()) + /// Either the first 8 characters of the hex-encoded hash of the first and + /// last transactions, or "Dummy" if there is no transaction. + fn get_descriptor(ir: &TrimmedGenerationInputs) -> String { + if ir.txn_hashes.is_empty() { + "Dummy".to_string() + } else if ir.txn_hashes.len() == 1 { + format!("{:x?}", ir.txn_hashes[0]) + } else { + let first_encoding = u64::from_be_bytes(ir.txn_hashes[0].0[0..8].try_into().unwrap()); + let last_encoding = u64::from_be_bytes( + ir.txn_hashes + .last() + .expect("We have at least 2 transactions.") + .0[0..8] + .try_into() + .unwrap(), + ); + + format!("[0x{:x?}..0x{:x?}]", first_encoding, last_encoding) + } } /// Create a new transaction proof span. /// /// When dropped, it logs the time taken by the transaction proof. - fn new(ir: &GenerationInputs) -> Self { - let id = Self::get_id(ir); + fn new(ir: &TrimmedGenerationInputs, segment_index: usize) -> Self { + let id = Self::get_id(ir, segment_index); let span = info_span!("p_gen", id).entered(); let start = Instant::now(); let descriptor = Self::get_descriptor(ir); @@ -136,11 +170,11 @@ impl TxProofSpan { } } -impl Drop for TxProofSpan { +impl Drop for SegmentProofSpan { fn drop(&mut self) { event!( Level::INFO, - "txn proof ({}) took {:?}", + "segment proof ({}) took {:?}", self.descriptor, self.start.elapsed() ); @@ -148,26 +182,97 @@ impl Drop for TxProofSpan { } #[derive(Deserialize, Serialize, RemoteExecute)] -pub struct AggProof { +pub struct SegmentAggProof { pub save_inputs_on_error: bool, } -fn get_agg_proof_public_values(elem: AggregatableProof) -> PublicValues { +fn get_seg_agg_proof_public_values(elem: SegmentAggregatableProof) -> PublicValues { + match elem { + SegmentAggregatableProof::Seg(info) => info.p_vals, + SegmentAggregatableProof::Agg(info) => info.p_vals, + } +} + +impl Monoid for SegmentAggProof { + type Elem = SegmentAggregatableProof; + + fn combine(&self, a: Self::Elem, b: Self::Elem) -> Result { + let result = generate_segment_agg_proof(p_state(), &a, &b, false).map_err(|e| { + if self.save_inputs_on_error { + let pv = vec![ + get_seg_agg_proof_public_values(a), + get_seg_agg_proof_public_values(b), + ]; + if let Err(write_err) = save_inputs_to_disk( + format!( + "b{}_agg_lhs_rhs_inputs.log", + pv[0].block_metadata.block_number + ), + pv, + ) { + error!("Failed to save agg proof inputs to disk: {:?}", write_err); + } + } + + FatalError::from(e) + })?; + + Ok(result.into()) + } + + fn empty(&self) -> Self::Elem { + // Expect that empty blocks are padded. + unimplemented!("empty agg proof") + } +} + +#[derive(Deserialize, Serialize, RemoteExecute)] +pub struct BatchAggProof { + pub save_inputs_on_error: bool, +} +fn get_agg_proof_public_values(elem: BatchAggregatableProof) -> PublicValues { match elem { - AggregatableProof::Txn(info) => info.p_vals, - AggregatableProof::Agg(info) => info.p_vals, + BatchAggregatableProof::Segment(info) => info.p_vals, + BatchAggregatableProof::Txn(info) => info.p_vals, + BatchAggregatableProof::Agg(info) => info.p_vals, } } -impl Monoid for AggProof { - type Elem = AggregatableProof; +impl Monoid for BatchAggProof { + type Elem = BatchAggregatableProof; fn combine(&self, a: Self::Elem, b: Self::Elem) -> Result { - let result = generate_agg_proof(p_state(), &a, &b).map_err(|e| { + let lhs = match a { + BatchAggregatableProof::Segment(segment) => BatchAggregatableProof::from( + generate_segment_agg_proof( + p_state(), + &SegmentAggregatableProof::from(segment.clone()), + &SegmentAggregatableProof::from(segment), + true, + ) + .map_err(FatalError::from)?, + ), + _ => a, + }; + + let rhs = match b { + BatchAggregatableProof::Segment(segment) => BatchAggregatableProof::from( + generate_segment_agg_proof( + p_state(), + &SegmentAggregatableProof::from(segment.clone()), + &SegmentAggregatableProof::from(segment), + true, + ) + .map_err(FatalError::from)?, + ), + _ => b, + }; + + let result = generate_transaction_agg_proof(p_state(), &lhs, &rhs).map_err(|e| { if self.save_inputs_on_error { let pv = vec![ - get_agg_proof_public_values(a), - get_agg_proof_public_values(b), + get_agg_proof_public_values(lhs), + get_agg_proof_public_values(rhs), ]; if let Err(write_err) = save_inputs_to_disk( format!( @@ -199,7 +304,7 @@ pub struct BlockProof { } impl Operation for BlockProof { - type Input = GeneratedAggProof; + type Input = GeneratedTxnAggProof; type Output = GeneratedBlockProof; fn execute(&self, input: Self::Input) -> Result { diff --git a/zero_bin/prover/Cargo.toml b/zero_bin/prover/Cargo.toml index d1eaab7b3..9f2050746 100644 --- a/zero_bin/prover/Cargo.toml +++ b/zero_bin/prover/Cargo.toml @@ -11,10 +11,13 @@ categories.workspace = true [dependencies] serde = { workspace = true } proof_gen = { workspace = true } +plonky2 = { workspace = true } +plonky2_maybe_rayon = { workspace = true } trace_decoder = { workspace = true } tracing = { workspace = true } paladin-core = { workspace = true } anyhow = { workspace = true } +evm_arithmetization = { workspace = true } futures = { workspace = true } alloy.workspace = true tokio = { workspace = true } @@ -23,7 +26,7 @@ ruint = { workspace = true, features = ["num-traits", "primitive-types"] } ops = { workspace = true } zero_bin_common = { workspace = true } num-traits = { workspace = true } +clap = {workspace = true} [features] default = [] -test_only = ["ops/test_only"] diff --git a/zero_bin/prover/src/cli.rs b/zero_bin/prover/src/cli.rs new file mode 100644 index 000000000..a49eb2d1d --- /dev/null +++ b/zero_bin/prover/src/cli.rs @@ -0,0 +1,32 @@ +use clap::Args; + +const HELP_HEADING: &str = "Prover options"; + +/// Represents the main configuration structure for the runtime. +#[derive(Args, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default)] +pub struct CliProverConfig { + /// The log of the max number of CPU cycles per proof. + #[arg(short, long, help_heading = HELP_HEADING, default_value_t = 19)] + max_cpu_len_log: usize, + /// Number of transactions in a batch to process at once. + #[arg(short, long, help_heading = HELP_HEADING, default_value_t = 10)] + batch_size: usize, + /// If true, save the public inputs to disk on error. + #[arg(short='i', long, help_heading = HELP_HEADING, default_value_t = false)] + save_inputs_on_error: bool, + /// If true, only test the trace decoder and witness generation without + /// generating a proof. + #[arg(long, help_heading = HELP_HEADING, default_value_t = false)] + test_only: bool, +} + +impl From for crate::ProverConfig { + fn from(cli: CliProverConfig) -> Self { + Self { + batch_size: cli.batch_size, + max_cpu_len_log: cli.max_cpu_len_log, + save_inputs_on_error: cli.save_inputs_on_error, + test_only: cli.test_only, + } + } +} diff --git a/zero_bin/prover/src/lib.rs b/zero_bin/prover/src/lib.rs index a30a4d3f3..117bbffd9 100644 --- a/zero_bin/prover/src/lib.rs +++ b/zero_bin/prover/src/lib.rs @@ -1,3 +1,5 @@ +pub mod cli; + use std::future::Future; use std::path::PathBuf; @@ -5,11 +7,7 @@ use alloy::primitives::{BlockNumber, U256}; use anyhow::{Context, Result}; use futures::{future::BoxFuture, stream::FuturesOrdered, FutureExt, TryFutureExt, TryStreamExt}; use num_traits::ToPrimitive as _; -use ops::TxProof; -use paladin::{ - directive::{Directive, IndexedStream}, - runtime::Runtime, -}; +use paladin::runtime::Runtime; use proof_gen::proof_types::GeneratedBlockProof; use serde::{Deserialize, Serialize}; use tokio::io::AsyncWriteExt; @@ -18,6 +16,14 @@ use trace_decoder::{BlockTrace, OtherBlockData}; use tracing::info; use zero_bin_common::fs::generate_block_proof_file_name; +#[derive(Debug, Clone, Copy)] +pub struct ProverConfig { + pub batch_size: usize, + pub max_cpu_len_log: usize, + pub save_inputs_on_error: bool, + pub test_only: bool, +} + pub type BlockProverInputFuture = std::pin::Pin< Box> + Send>, >; @@ -42,30 +48,71 @@ impl BlockProverInput { self.other_data.b_data.b_meta.block_number.into() } - #[cfg(not(feature = "test_only"))] pub async fn prove( self, runtime: &Runtime, previous: Option>>, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result { use anyhow::Context as _; + use evm_arithmetization::prover::SegmentDataIterator; + use futures::{stream::FuturesUnordered, FutureExt}; + use paladin::directive::{Directive, IndexedStream}; + + let ProverConfig { + max_cpu_len_log, + batch_size, + save_inputs_on_error, + test_only: _, + } = prover_config; let block_number = self.get_block_number(); - let txs = trace_decoder::entrypoint(self.block_trace, self.other_data)?; + let block_generation_inputs = + trace_decoder::entrypoint(self.block_trace, self.other_data, batch_size)?; - let agg_proof = IndexedStream::from(txs) - .map(&TxProof { - save_inputs_on_error, - }) - .fold(&ops::AggProof { - save_inputs_on_error, + // Create segment proof. + let seg_prove_ops = ops::SegmentProof { + save_inputs_on_error, + }; + + // Aggregate multiple segment proofs to resulting segment proof. + let seg_agg_ops = ops::SegmentAggProof { + save_inputs_on_error, + }; + + // Aggregate batch proofs to a single proof. + let batch_agg_ops = ops::BatchAggProof { + save_inputs_on_error, + }; + + // Segment the batches, prove segments and aggregate them to resulting batch + // proofs. + let batch_proof_futs: FuturesUnordered<_> = block_generation_inputs + .iter() + .enumerate() + .map(|(idx, txn_batch)| { + let segment_data_iterator = SegmentDataIterator::::new( + txn_batch, + Some(max_cpu_len_log), + ); + + Directive::map(IndexedStream::from(segment_data_iterator), &seg_prove_ops) + .fold(&seg_agg_ops) + .run(runtime) + .map(move |e| { + e.map(|p| (idx, proof_gen::proof_types::BatchAggregatableProof::from(p))) + }) }) - .run(runtime) - .await?; + .collect(); - if let proof_gen::proof_types::AggregatableProof::Agg(proof) = agg_proof { + // Fold the batch aggregated proof stream into a single proof. + let final_batch_proof = + Directive::fold(IndexedStream::new(batch_proof_futs), &batch_agg_ops) + .run(runtime) + .await?; + + if let proof_gen::proof_types::BatchAggregatableProof::Agg(proof) = final_batch_proof { let block_number = block_number .to_u64() .context("block number overflows u64")?; @@ -83,33 +130,58 @@ impl BlockProverInput { .await?; info!("Successfully proved block {block_number}"); + Ok(block_proof.0) } else { anyhow::bail!("AggProof is is not GeneratedAggProof") } } - #[cfg(feature = "test_only")] - pub async fn prove( + pub async fn prove_test( self, runtime: &Runtime, previous: Option>>, - save_inputs_on_error: bool, + prover_config: ProverConfig, ) -> Result { + use std::iter::repeat; + + use futures::future; + use paladin::directive::{Directive, IndexedStream}; + + let ProverConfig { + max_cpu_len_log, + batch_size, + save_inputs_on_error, + test_only: _, + } = prover_config; + let block_number = self.get_block_number(); info!("Testing witness generation for block {block_number}."); - let txs = trace_decoder::entrypoint(self.block_trace, self.other_data)?; + let block_generation_inputs = + trace_decoder::entrypoint(self.block_trace, self.other_data, batch_size)?; - IndexedStream::from(txs) - .map(&TxProof { - save_inputs_on_error, - }) + let seg_ops = ops::SegmentProofTestOnly { + save_inputs_on_error, + }; + + let simulation = Directive::map( + IndexedStream::from( + block_generation_inputs + .into_iter() + .zip(repeat(max_cpu_len_log)), + ), + &seg_ops, + ); + + simulation .run(runtime) .await? - .try_collect::>() + .try_for_each(|_| future::ok(())) .await?; + info!("Successfully generated witness for block {block_number}."); + // Wait for previous block proof let _prev = match previous { Some(it) => Some(it.await?), @@ -126,14 +198,15 @@ impl BlockProverInput { } } -/// Prove all the blocks in the input. -/// Return the list of block numbers that are proved and if the proof data -/// is not saved to disk, return the generated block proofs as well. +/// Prove all the blocks in the input, or simulate their execution depending on +/// the selected prover configuration. Return the list of block numbers that are +/// proved and if the proof data is not saved to disk, return the generated +/// block proofs as well. pub async fn prove( block_prover_inputs: Vec, runtime: &Runtime, previous_proof: Option, - save_inputs_on_error: bool, + prover_config: ProverConfig, proof_output_dir: Option, ) -> Result)>> { let mut prev: Option>> = @@ -143,7 +216,7 @@ pub async fn prove( for block_prover_input in block_prover_inputs { let (tx, rx) = oneshot::channel::(); let proof_output_dir = proof_output_dir.clone(); - let previos_block_proof = prev.take(); + let previous_block_proof = prev.take(); let fut = async move { // Get the prover input data from the external source (e.g. Erigon node). let block = block_prover_input.await?; @@ -151,29 +224,55 @@ pub async fn prove( info!("Proving block {block_number}"); // Prove the block - let block_proof = block - .prove(runtime, previos_block_proof, save_inputs_on_error) - .then(move |proof| async move { - let proof = proof?; - let block_number = proof.b_height; - - // Write latest generated proof to disk if proof_output_dir is provided - // or alternatively return proof as function result. - let return_proof: Option = - if let Some(output_dir) = proof_output_dir { - write_proof_to_dir(output_dir, &proof).await?; - None - } else { - Some(proof.clone()) - }; - - if tx.send(proof).is_err() { - anyhow::bail!("Failed to send proof"); - } - - Ok((block_number, return_proof)) - }) - .await?; + let block_proof = if prover_config.test_only { + block + .prove_test(runtime, previous_block_proof, prover_config) + .then(move |proof| async move { + let proof = proof?; + let block_number = proof.b_height; + + // Write latest generated proof to disk if proof_output_dir is provided + // or alternatively return proof as function result. + let return_proof: Option = + if let Some(output_dir) = proof_output_dir { + write_proof_to_dir(output_dir, &proof).await?; + None + } else { + Some(proof.clone()) + }; + + if tx.send(proof).is_err() { + anyhow::bail!("Failed to send proof"); + } + + Ok((block_number, return_proof)) + }) + .await? + } else { + block + .prove(runtime, previous_block_proof, prover_config) + .then(move |proof| async move { + let proof = proof?; + let block_number = proof.b_height; + + // Write latest generated proof to disk if proof_output_dir is provided + // or alternatively return proof as function result. + let return_proof: Option = + if let Some(output_dir) = proof_output_dir { + write_proof_to_dir(output_dir, &proof).await?; + None + } else { + Some(proof.clone()) + }; + + if tx.send(proof).is_err() { + anyhow::bail!("Failed to send proof"); + } + + Ok((block_number, return_proof)) + }) + .await? + }; Ok(block_proof) } diff --git a/zero_bin/tools/prove_rpc.sh b/zero_bin/tools/prove_rpc.sh index d7651b65a..9c0f4bad9 100755 --- a/zero_bin/tools/prove_rpc.sh +++ b/zero_bin/tools/prove_rpc.sh @@ -17,23 +17,17 @@ export RUST_LOG=info # See also .cargo/config.toml. export RUSTFLAGS='-C target-cpu=native -Zlinker-features=-lld' -if [[ $8 == "test_only" ]]; then - # Circuit sizes don't matter in test_only mode, so we keep them minimal. - export ARITHMETIC_CIRCUIT_SIZE="16..17" - export BYTE_PACKING_CIRCUIT_SIZE="9..10" - export CPU_CIRCUIT_SIZE="12..13" - export KECCAK_CIRCUIT_SIZE="14..15" - export KECCAK_SPONGE_CIRCUIT_SIZE="9..10" - export LOGIC_CIRCUIT_SIZE="12..13" - export MEMORY_CIRCUIT_SIZE="17..18" -else - export ARITHMETIC_CIRCUIT_SIZE="16..23" - export BYTE_PACKING_CIRCUIT_SIZE="8..21" - export CPU_CIRCUIT_SIZE="12..25" - export KECCAK_CIRCUIT_SIZE="14..20" - export KECCAK_SPONGE_CIRCUIT_SIZE="9..15" - export LOGIC_CIRCUIT_SIZE="12..18" - export MEMORY_CIRCUIT_SIZE="17..28" +# Circuit sizes only matter in non test_only mode. +if ! [[ $8 == "test_only" ]]; then + export ARITHMETIC_CIRCUIT_SIZE="16..21" + export BYTE_PACKING_CIRCUIT_SIZE="8..21" + export CPU_CIRCUIT_SIZE="8..21" + export KECCAK_CIRCUIT_SIZE="4..20" + export KECCAK_SPONGE_CIRCUIT_SIZE="8..17" + export LOGIC_CIRCUIT_SIZE="4..21" + export MEMORY_CIRCUIT_SIZE="17..24" + export MEMORY_BEFORE_CIRCUIT_SIZE="16..23" + export MEMORY_AFTER_CIRCUIT_SIZE="7..23" fi # Force the working directory to always be the `tools/` directory. @@ -108,7 +102,7 @@ fi if [[ $8 == "test_only" ]]; then # test only run echo "Proving blocks ${BLOCK_INTERVAL} in a test_only mode now... (Total: ${TOT_BLOCKS})" - command='cargo r --release --features test_only --bin leader -- --runtime in-memory --load-strategy on-demand rpc --rpc-type "$NODE_RPC_TYPE" --rpc-url "$NODE_RPC_URL" --block-interval $BLOCK_INTERVAL --proof-output-dir $PROOF_OUTPUT_DIR $PREV_PROOF_EXTRA_ARG --backoff "$BACKOFF" --max-retries "$RETRIES" ' + command='cargo r --release --bin leader -- --test-only --runtime in-memory --load-strategy on-demand rpc --rpc-type "$NODE_RPC_TYPE" --rpc-url "$NODE_RPC_URL" --block-interval $BLOCK_INTERVAL --proof-output-dir $PROOF_OUTPUT_DIR $PREV_PROOF_EXTRA_ARG --backoff "$BACKOFF" --max-retries "$RETRIES" ' if [ "$OUTPUT_TO_TERMINAL" = true ]; then eval $command retVal=$? diff --git a/zero_bin/tools/prove_stdio.sh b/zero_bin/tools/prove_stdio.sh index 5e4792f4c..64c140023 100755 --- a/zero_bin/tools/prove_stdio.sh +++ b/zero_bin/tools/prove_stdio.sh @@ -44,44 +44,42 @@ if [[ $INPUT_FILE == "" ]]; then exit 1 fi -if [[ $TEST_ONLY == "test_only" ]]; then - # Circuit sizes don't matter in test_only mode, so we keep them minimal. - export ARITHMETIC_CIRCUIT_SIZE="16..17" - export BYTE_PACKING_CIRCUIT_SIZE="9..10" - export CPU_CIRCUIT_SIZE="12..13" - export KECCAK_CIRCUIT_SIZE="14..15" - export KECCAK_SPONGE_CIRCUIT_SIZE="9..10" - export LOGIC_CIRCUIT_SIZE="12..13" - export MEMORY_CIRCUIT_SIZE="17..18" -else +# Circuit sizes only matter in non test_only mode. +if ! [[ $TEST_ONLY == "test_only" ]]; then if [[ $INPUT_FILE == *"witness_b19807080"* ]]; then # These sizes are configured specifically for block 19807080. Don't use this in other scenarios echo "Using specific circuit sizes for witness_b19807080.json" export ARITHMETIC_CIRCUIT_SIZE="16..18" - export BYTE_PACKING_CIRCUIT_SIZE="11..15" - export CPU_CIRCUIT_SIZE="17..21" - export KECCAK_CIRCUIT_SIZE="14..17" - export KECCAK_SPONGE_CIRCUIT_SIZE="10..13" - export LOGIC_CIRCUIT_SIZE="13..16" - export MEMORY_CIRCUIT_SIZE="19..23" + export BYTE_PACKING_CIRCUIT_SIZE="10..15" + export CPU_CIRCUIT_SIZE="16..20" + export KECCAK_CIRCUIT_SIZE="12..18" + export KECCAK_SPONGE_CIRCUIT_SIZE="8..14" + export LOGIC_CIRCUIT_SIZE="8..17" + export MEMORY_CIRCUIT_SIZE="18..22" + export MEMORY_BEFORE_CIRCUIT_SIZE="16..20" + export MEMORY_AFTER_CIRCUIT_SIZE="7..20" elif [[ $INPUT_FILE == *"witness_b3_b6"* ]]; then # These sizes are configured specifically for custom blocks 3 to 6. Don't use this in other scenarios echo "Using specific circuit sizes for witness_b3_b6.json" - export ARITHMETIC_CIRCUIT_SIZE="16..17" - export BYTE_PACKING_CIRCUIT_SIZE="8..14" - export CPU_CIRCUIT_SIZE="14..19" - export KECCAK_CIRCUIT_SIZE="14..15" - export KECCAK_SPONGE_CIRCUIT_SIZE="10..11" - export LOGIC_CIRCUIT_SIZE="12..13" - export MEMORY_CIRCUIT_SIZE="17..21" + export ARITHMETIC_CIRCUIT_SIZE="16..18" + export BYTE_PACKING_CIRCUIT_SIZE="8..15" + export CPU_CIRCUIT_SIZE="10..20" + export KECCAK_CIRCUIT_SIZE="4..13" + export KECCAK_SPONGE_CIRCUIT_SIZE="8..9" + export LOGIC_CIRCUIT_SIZE="4..14" + export MEMORY_CIRCUIT_SIZE="17..22" + export MEMORY_BEFORE_CIRCUIT_SIZE="17..18" + export MEMORY_AFTER_CIRCUIT_SIZE="7..8" else - export ARITHMETIC_CIRCUIT_SIZE="16..23" + export ARITHMETIC_CIRCUIT_SIZE="16..21" export BYTE_PACKING_CIRCUIT_SIZE="8..21" - export CPU_CIRCUIT_SIZE="12..25" - export KECCAK_CIRCUIT_SIZE="14..20" - export KECCAK_SPONGE_CIRCUIT_SIZE="9..15" - export LOGIC_CIRCUIT_SIZE="12..18" - export MEMORY_CIRCUIT_SIZE="17..28" + export CPU_CIRCUIT_SIZE="8..21" + export KECCAK_CIRCUIT_SIZE="4..20" + export KECCAK_SPONGE_CIRCUIT_SIZE="8..17" + export LOGIC_CIRCUIT_SIZE="4..21" + export MEMORY_CIRCUIT_SIZE="17..24" + export MEMORY_BEFORE_CIRCUIT_SIZE="16..23" + export MEMORY_AFTER_CIRCUIT_SIZE="7..23" fi fi @@ -90,7 +88,7 @@ fi # proof. This is useful for quickly testing decoding and all of the # other non-proving code. if [[ $TEST_ONLY == "test_only" ]]; then - cargo run --release --features test_only --bin leader -- --runtime in-memory --load-strategy on-demand stdio < $INPUT_FILE &> $TEST_OUT_PATH + cargo run --release --bin leader -- --test-only --runtime in-memory --load-strategy on-demand stdio < $INPUT_FILE &> $TEST_OUT_PATH if grep -q 'All proof witnesses have been generated successfully.' $TEST_OUT_PATH; then echo -e "\n\nSuccess - Note this was just a test, not a proof" rm $TEST_OUT_PATH