Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add STARK batching #388

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 139 additions & 153 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ rpc = { path = "zero_bin/rpc" }
zero_bin_common = { path = "zero_bin/common" }

# plonky2-related dependencies
plonky2 = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" }
plonky2_maybe_rayon = "0.2.0"
plonky2_util = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" }
starky = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" }
plonky2 = { path = "../plonky2/plonky2"}
plonky2_maybe_rayon = { path = "../plonky2/maybe_rayon"}
plonky2_util = { path = "../plonky2/util"}
starky = { path = "../plonky2/starky"}

# proc macro related dependencies
proc-macro2 = "1.0"
Expand Down
94 changes: 94 additions & 0 deletions evm_arithmetization/src/all_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ impl Deref for Table {
/// Number of STARK tables.
pub(crate) const NUM_TABLES: usize = Table::MemAfter as usize + 1;

pub(crate) const TABLE_DEGREES: [usize; NUM_TABLES] = [
20, // Arithmetic
20, // BytePacking,
20, // Cpu,
18, // Keccak,
16, // KeccakSponge,
20, // Logic,
23, // Memory,
22, // MemBefore,
22, // MemAfter,
];

impl Table {
/// Returns all STARK table indices.
pub(crate) const fn all() -> [Self; NUM_TABLES] {
Expand All @@ -120,6 +132,88 @@ impl Table {
Self::MemAfter,
]
}

/// Returns all STARK table indices in descending order of their padded
/// trace degrees.
pub(crate) const fn all_sorted() -> [Self; NUM_TABLES] {
hratoanina marked this conversation as resolved.
Show resolved Hide resolved
let mut sorted_pairs = [(0, Table::Arithmetic); NUM_TABLES];
let mut i = 0;
while i < NUM_TABLES {
sorted_pairs[i] = (TABLE_DEGREES[i], Self::all()[i]);
i += 1;
}

// Simple bubble sort.
let mut i = 0;
while i < NUM_TABLES - 1 {
let mut j = 0;
while j < NUM_TABLES - i - 1 {
let (pair_a, pair_b) = (sorted_pairs[j], sorted_pairs[j + 1]);
if pair_a.0 < pair_b.0 {
sorted_pairs[j] = pair_b;
sorted_pairs[j + 1] = pair_a;
}
j += 1;
}
i += 1;
}

let mut sorted_tables = [Table::Arithmetic; NUM_TABLES];
let mut i = 0;
while i < NUM_TABLES {
sorted_tables[i] = sorted_pairs[i].1;
i += 1;
}

sorted_tables
}

/// Returns the ordered position of the tables. This is the inverse of
/// `all_sorted()`.
pub(crate) const fn table_to_sorted_index() -> [usize; NUM_TABLES] {
let mut res = [0; NUM_TABLES];
let mut i = 0;
while i < NUM_TABLES {
res[Self::all_sorted()[i] as usize] = i;
i += 1;
}

res
}

/// Returns the ordered position of the tables in a batch Merkle tree. Each
/// entry is a couple to account for duplicate sizes.
pub(crate) const fn sorted_index_pair() -> [(usize, usize); NUM_TABLES] {
let mut pairs = [(0, 0); NUM_TABLES];

let mut outer = 0;
let mut inner = 0;
let mut i = 1;
while i < NUM_TABLES {
if Self::all_degree_logs()[i] < Self::all_degree_logs()[i - 1] {
outer += 1;
inner = 0;
} else {
inner += 1;
}
pairs[i] = (outer, inner);
i += 1;
}

pairs
}

/// Returns all STARK padded trace degrees in descending order.
pub(crate) const fn all_degree_logs() -> [usize; NUM_TABLES] {
let mut res = [0; NUM_TABLES];
let mut i = 0;
while i < NUM_TABLES {
res[i] = TABLE_DEGREES[Self::all_sorted()[i] as usize];
i += 1;
}

res
}
}

/// Returns all the `CrossTableLookups` used for proving the EVM.
Expand Down
3 changes: 2 additions & 1 deletion evm_arithmetization/src/arithmetic/arithmetic_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
// Pad the trace with zero rows if it doesn't have enough rows
// to accommodate the range check columns. Also make sure the
// trace length is a power of two.
let padded_len = trace_rows.len().next_power_of_two();
let padded_len =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::Arithmetic]];
for _ in trace_rows.len()..std::cmp::max(padded_len, RANGE_MAX) {
trace_rows.push(vec![F::ZERO; columns::NUM_ARITH_COLUMNS]);
}
Expand Down
6 changes: 4 additions & 2 deletions evm_arithmetization/src/byte_packing/byte_packing_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use starky::lookup::{Column, Filter, Lookup};
use starky::stark::Stark;

use super::NUM_BYTES;
use crate::all_stark::EvmStarkFrame;
use crate::all_stark::{EvmStarkFrame, Table};
use crate::byte_packing::columns::{
index_len, value_bytes, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, IS_READ, LEN_INDICES_COLS,
NUM_COLUMNS, RANGE_COUNTER, RC_FREQUENCIES, TIMESTAMP,
Expand Down Expand Up @@ -175,7 +175,9 @@ impl<F: RichField + Extendable<D>, const D: usize> BytePackingStark<F, D> {
ops: Vec<BytePackingOp>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = core::cmp::max(ops.len().max(BYTE_RANGE_MAX), min_rows).next_power_of_two();
let num_rows =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::BytePacking]];

let mut rows = Vec::with_capacity(num_rows);

for op in ops {
Expand Down
26 changes: 14 additions & 12 deletions evm_arithmetization/src/fixed_recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,26 +755,28 @@ where

// Extra sums to add to the looked last value.
// Only necessary for the Memory values.
let mut extra_looking_sums =
vec![vec![builder.zero(); stark_config.num_challenges]; NUM_TABLES];
let mut extra_looking_sums = HashMap::new();

// Memory
extra_looking_sums[*Table::Memory] = (0..stark_config.num_challenges)
.map(|c| {
get_memory_extra_looking_sum_circuit(
&mut builder,
&public_values,
ctl_challenges.challenges[c],
)
})
.collect_vec();
extra_looking_sums.insert(
Table::Memory as usize,
(0..stark_config.num_challenges)
.map(|c| {
get_memory_extra_looking_sum_circuit(
&mut builder,
&public_values,
ctl_challenges.challenges[c],
)
})
.collect_vec(),
);

// Verify the CTL checks.
verify_cross_table_lookups_circuit::<F, D, NUM_TABLES>(
&mut builder,
all_cross_table_lookups(),
pis.map(|p| p.ctl_zs_first),
Some(&extra_looking_sums),
&extra_looking_sums,
stark_config,
);

Expand Down
8 changes: 5 additions & 3 deletions evm_arithmetization/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use GlobalMetadata::{
StateTrieRootDigestBefore, TransactionTrieRootDigestAfter, TransactionTrieRootDigestBefore,
};

use crate::all_stark::{AllStark, NUM_TABLES};
use crate::all_stark::{AllStark, Table, NUM_TABLES};
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
Expand Down Expand Up @@ -487,10 +487,12 @@ fn simulate_cpu<F: Field>(
row.stack_len = F::from_canonical_usize(state.registers.stack_len);

loop {
// Padding to a power of 2.
// Padding.
state.push_cpu(row);
row.clock += F::ONE;
if state.traces.clock().is_power_of_two() {
if state.traces.clock()
== 1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::Cpu]]
{
break;
}
}
Expand Down
7 changes: 3 additions & 4 deletions evm_arithmetization/src/keccak/keccak_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use starky::stark::Stark;
use starky::util::trace_rows_to_poly_values;

use super::columns::reg_input_limb;
use crate::all_stark::EvmStarkFrame;
use crate::all_stark::{EvmStarkFrame, Table};
use crate::keccak::columns::{
reg_a, reg_a_prime, reg_a_prime_prime, reg_a_prime_prime_0_0_bit, reg_a_prime_prime_prime,
reg_b, reg_c, reg_c_prime, reg_output_limb, reg_step, NUM_COLUMNS, TIMESTAMP,
Expand Down Expand Up @@ -72,9 +72,8 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
inputs_and_timestamps: Vec<([u64; NUM_INPUTS], usize)>,
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = (inputs_and_timestamps.len() * NUM_ROUNDS)
.max(min_rows)
.next_power_of_two();
let num_rows =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::Keccak]];

let mut rows = Vec::with_capacity(num_rows);
for input_and_timestamp in inputs_and_timestamps.iter() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use core::marker::PhantomData;
use core::mem::size_of;

use itertools::Itertools;
use num::integer::div_ceil;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
Expand All @@ -18,7 +19,7 @@ use starky::evaluation_frame::StarkEvaluationFrame;
use starky::lookup::{Column, Filter, Lookup};
use starky::stark::Stark;

use crate::all_stark::EvmStarkFrame;
use crate::all_stark::{EvmStarkFrame, Table};
use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::keccak_sponge::columns::*;
use crate::witness::memory::MemoryAddress;
Expand Down Expand Up @@ -289,7 +290,8 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakSpongeStark<F, D> {
rows.extend(self.generate_rows_for_op(op));
}
// Pad the trace.
let padded_rows = rows.len().max(min_num_rows).next_power_of_two();
let padded_rows =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::KeccakSponge]];
for _ in rows.len()..padded_rows {
rows.push(self.generate_padding_row());
}
Expand Down
5 changes: 3 additions & 2 deletions evm_arithmetization/src/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use starky::lookup::{Column, Filter};
use starky::stark::Stark;
use starky::util::trace_rows_to_poly_values;

use crate::all_stark::EvmStarkFrame;
use crate::all_stark::{EvmStarkFrame, Table};
use crate::logic::columns::{LogicColumnsView, LOGIC_COL_MAP, NUM_COLUMNS};
use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive};

Expand Down Expand Up @@ -220,7 +220,8 @@ impl<F: RichField, const D: usize> LogicStark<F, D> {
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let len = operations.len();
let padded_len = len.max(min_rows).next_power_of_two();
let padded_len =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::Logic]];

let mut rows = Vec::with_capacity(padded_len);
for op in operations {
Expand Down
3 changes: 2 additions & 1 deletion evm_arithmetization/src/memory/memory_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
let num_ops = memory_ops.len();
// We want at least one padding row, so that the last real operation can have
// its flags set correctly.
let num_ops_padded = (num_ops + 1).next_power_of_two();
let num_ops_padded =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::Memory]];
for _ in num_ops..num_ops_padded {
memory_ops.push(padding_op);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use starky::lookup::{Column, Filter, Lookup};
use starky::stark::Stark;

use super::columns::{value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, FILTER, NUM_COLUMNS};
use crate::all_stark::EvmStarkFrame;
use crate::all_stark::{EvmStarkFrame, Table};
use crate::generation::MemBeforeValues;
use crate::memory::VALUE_LIMBS;

Expand Down Expand Up @@ -85,7 +85,8 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryContinuationStark<F, D>
let mut rows = propagated_values;

let num_rows = rows.len();
let num_rows_padded = max(128, num_rows.next_power_of_two());
let num_rows_padded =
1 << Table::all_degree_logs()[Table::table_to_sorted_index()[*Table::MemBefore]];
for _ in num_rows..num_rows_padded {
rows.push(vec![F::ZERO; NUM_COLUMNS]);
}
Expand Down
Loading
Loading