Skip to content

Commit

Permalink
clean up + refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
roynalnaruto committed Nov 5, 2024
1 parent f9d59f2 commit f339984
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 83 deletions.
40 changes: 40 additions & 0 deletions aggregator/src/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,43 @@ pub(crate) use rlc::{RlcConfig, POWS_OF_256};

pub use circuit::BatchCircuit;
pub use config::BatchCircuitConfig;
use halo2_base::halo2_proofs::halo2curves::bn256::{Fr, G1Affine};
use snark_verifier::Protocol;

/// Alias for a list of G1 points.
pub type PreprocessedPolyCommits = Vec<G1Affine>;
/// Alias for the transcript's initial state.
pub type TranscriptInitState = Fr;

/// Alias for the fixed part of the protocol which consists of the commitments to the preprocessed
/// polynomials and the initial state of the transcript.
#[derive(Clone)]
pub struct FixedProtocol {
/// The commitments to the preprocessed polynomials.
pub preprocessed: PreprocessedPolyCommits,
/// The initial state of the transcript.
pub init_state: TranscriptInitState,
}

impl From<Protocol<G1Affine>> for FixedProtocol {
fn from(protocol: Protocol<G1Affine>) -> Self {
Self {
preprocessed: protocol.preprocessed,
init_state: protocol
.transcript_initial_state
.expect("protocol transcript init state None"),
}
}
}

impl From<&Protocol<G1Affine>> for FixedProtocol {
fn from(protocol: &Protocol<G1Affine>) -> Self {
Self {
preprocessed: protocol.preprocessed.clone(),
init_state: protocol
.transcript_initial_state
.clone()
.expect("protocol transcript init state None"),
}
}
}
36 changes: 20 additions & 16 deletions aggregator/src/aggregation/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use snark_verifier::{
fields::{fp::FpConfig, FieldChip},
halo2_base::{
gates::{GateInstructions, RangeInstructions},
utils::fe_to_biguint,
AssignedValue, Context, ContextParams,
QuantumCell::Existing,
},
Expand All @@ -34,12 +33,12 @@ use crate::{
aggregation::{decoder::WORKED_EXAMPLE, witgen::process, BatchCircuitConfig, BatchData},
batch::BatchHash,
blob_consistency::BlobConsistencyConfig,
constants::{ACC_LEN, DIGEST_LEN, FIXED_PROTOCOL_HALO2, FIXED_PROTOCOL_SP1},
constants::{ACC_LEN, DIGEST_LEN},
core::{assign_batch_hashes, extract_proof_and_instances_with_pairing_check},
util::parse_hash_digest_cells,
witgen::{zstd_encode, MultiBlockProcessResult},
ConfigParams, LOG_DEGREE, PI_CHAIN_ID, PI_CURRENT_BATCH_HASH, PI_CURRENT_STATE_ROOT,
PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH, PI_PARENT_STATE_ROOT,
ConfigParams, FixedProtocol, LOG_DEGREE, PI_CHAIN_ID, PI_CURRENT_BATCH_HASH,
PI_CURRENT_STATE_ROOT, PI_CURRENT_WITHDRAW_ROOT, PI_PARENT_BATCH_HASH, PI_PARENT_STATE_ROOT,
};

/// Batch circuit, the chunk aggregation routine below recursion circuit
Expand All @@ -63,14 +62,21 @@ pub struct BatchCircuit<const N_SNARKS: usize> {
// batch hash circuit for which the snarks are generated
// the chunks in this batch are also padded already
pub batch_hash: BatchHash<N_SNARKS>,

/// The SNARK protocol from the halo2-based inner circuit route.
pub halo2_protocol: FixedProtocol,
/// The SNARK protocol from the sp1-based inner circuit route.
pub sp1_protocol: FixedProtocol,
}

impl<const N_SNARKS: usize> BatchCircuit<N_SNARKS> {
pub fn new(
pub fn new<P: Into<FixedProtocol>>(
params: &ParamsKZG<Bn256>,
snarks_with_padding: &[Snark],
rng: impl Rng + Send,
batch_hash: BatchHash<N_SNARKS>,
halo2_protocol: P,
sp1_protocol: P,
) -> Result<Self, snark_verifier::Error> {
let timer = start_timer!(|| "generate aggregation circuit");

Expand Down Expand Up @@ -128,6 +134,8 @@ impl<const N_SNARKS: usize> BatchCircuit<N_SNARKS> {
flattened_instances,
as_proof: Value::known(as_proof),
batch_hash,
halo2_protocol: halo2_protocol.into(),
sp1_protocol: sp1_protocol.into(),
})
}

Expand Down Expand Up @@ -252,9 +260,7 @@ impl<const N_SNARKS: usize> Circuit<Fr> for BatchCircuit<N_SNARKS> {
log::trace!("{}-th instance: {:?}", i, e.value)
}

loader
.ctx_mut()
.print_stats(&["snark aggregation"]);
loader.ctx_mut().print_stats(&["snark aggregation"]);

let mut ctx = Rc::into_inner(loader).unwrap().into_ctx();

Expand All @@ -266,11 +272,8 @@ impl<const N_SNARKS: usize> Circuit<Fr> for BatchCircuit<N_SNARKS> {
log::info!("populating constants");
let mut preprocessed_polys_halo2 = Vec::with_capacity(7);
let mut preprocessed_polys_sp1 = Vec::with_capacity(7);
let (fixed_preprocessed_polys_halo2, fixed_transcript_init_state_halo2) =
FIXED_PROTOCOL_HALO2.clone();
let (fixed_preprocessed_polys_sp1, fixed_transcript_init_state_sp1) =
FIXED_PROTOCOL_SP1.clone();
for (i, &preprocessed_poly) in fixed_preprocessed_polys_halo2.iter().enumerate()
for (i, &preprocessed_poly) in
self.halo2_protocol.preprocessed.iter().enumerate()
{
log::debug!("load const {i}");
preprocessed_polys_halo2.push(
Expand All @@ -280,7 +283,8 @@ impl<const N_SNARKS: usize> Circuit<Fr> for BatchCircuit<N_SNARKS> {
);
log::debug!("load const {i} OK");
}
for (i, &preprocessed_poly) in fixed_preprocessed_polys_sp1.iter().enumerate() {
for (i, &preprocessed_poly) in self.sp1_protocol.preprocessed.iter().enumerate()
{
log::debug!("load const (sp1) {i}");
preprocessed_polys_sp1.push(
config
Expand All @@ -294,15 +298,15 @@ impl<const N_SNARKS: usize> Circuit<Fr> for BatchCircuit<N_SNARKS> {
.field_chip()
.range()
.gate()
.assign_constant(&mut ctx, fixed_transcript_init_state_halo2)
.assign_constant(&mut ctx, self.halo2_protocol.init_state)
.expect("IntegerInstructions::assign_constant infallible");
log::debug!("load transcript OK");
let transcript_init_state_sp1 = config
.ecc_chip()
.field_chip()
.range()
.gate()
.assign_constant(&mut ctx, fixed_transcript_init_state_sp1)
.assign_constant(&mut ctx, self.sp1_protocol.init_state)
.expect("IntegerInstructions::assign_constant infallible");
log::info!("populating constants OK");

Expand Down
61 changes: 0 additions & 61 deletions aggregator/src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use halo2_proofs::halo2curves::bn256::{Fr, G1Affine};
use std::sync::LazyLock;

// A chain_id is u64 and uses 8 bytes
pub(crate) const CHAIN_ID_LEN: usize = 8;

Expand Down Expand Up @@ -91,61 +88,3 @@ pub const MAX_AGG_SNARKS: usize = 45;

// Number of bytes in a u256.
pub const N_BYTES_U256: usize = 32;

/// Alias for a list of G1 points.
type PreprocessedPolyCommits = Vec<G1Affine>;
/// Alias for the transcript's initial state.
type TranscriptInitState = Fr;
/// Alias for the fixed part of the protocol which consists of the commitments to the preprocessed
/// polynomials and the initial state of the transcript.
type FixedProtocol = (PreprocessedPolyCommits, TranscriptInitState);

/// The [`Batch Circuit`] supports aggregation of up to [`MAX_AGG_SNARKS`] SNARKs, where either
/// SNARK is of 2 kinds, namely:
///
/// 1. halo2-based [`SuperCircuit`] -> [`CompressionCircuit`] (wide) -> `CompressionCircuit` (thin)
/// 2. sp1-based STARK -> halo2-based backend -> `CompressionCircuit` (thin)
///
/// For each SNARK witness provided for aggregation, we require that the commitments to the
/// preprocessed polynomials and the transcript's initial state belong to a fixed set, one
/// belonging to each of the above SNARK kinds.
///
/// Represents the fixed commitments to the preprocessed polynomials and the initial state of the
/// transcript for [`ChunkKind::Halo2`].
pub static FIXED_PROTOCOL_HALO2: LazyLock<FixedProtocol> = LazyLock::new(|| {
let name =
std::env::var("HALO2_CHUNK_PROTOCOL").unwrap_or("chunk_chunk_halo2.protocol".to_string());
let dir =
std::env::var("SCROLL_PROVER_ASSETS_DIR").unwrap_or("./tests/test_assets".to_string());
let path = std::path::Path::new(&dir).join(name);
let file = std::fs::File::open(&path).expect("could not open file");
let reader = std::io::BufReader::new(file);
let protocol: snark_verifier::Protocol<G1Affine> =
serde_json::from_reader(reader).expect("could not deserialise protocol");
(
protocol.preprocessed,
protocol
.transcript_initial_state
.expect("transcript initial state is None"),
)
});

/// Represents the fixed commitments to the preprocessed polynomials and the initial state of the
/// transcript for [`ChunkKind::Sp1`].
pub static FIXED_PROTOCOL_SP1: LazyLock<FixedProtocol> = LazyLock::new(|| {
let name =
std::env::var("SP1_CHUNK_PROTOCOL").unwrap_or("chunk_chunk_sp1.protocol".to_string());
let dir =
std::env::var("SCROLL_PROVER_ASSETS_DIR").unwrap_or("./tests/test_assets".to_string());
let path = std::path::Path::new(&dir).join(name);
let file = std::fs::File::open(&path).expect("could not open file");
let reader = std::io::BufReader::new(file);
let protocol: snark_verifier::Protocol<G1Affine> =
serde_json::from_reader(reader).expect("could not deserialise protocol");
(
protocol.preprocessed,
protocol
.transcript_initial_state
.expect("transcript initial state is None"),
)
});
7 changes: 7 additions & 0 deletions aggregator/src/tests/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ fn build_new_batch_circuit<const N_SNARKS: usize>(
})
.collect_vec()
};
let snark_protocol = real_snarks[0].protocol.clone();

// ==========================
// padded chunks
Expand All @@ -225,6 +226,8 @@ fn build_new_batch_circuit<const N_SNARKS: usize>(
[real_snarks, padded_snarks].concat().as_ref(),
rng,
batch_hash,
&snark_protocol,
&snark_protocol,
)
.unwrap()
}
Expand Down Expand Up @@ -293,6 +296,8 @@ fn build_batch_circuit_skip_encoding<const N_SNARKS: usize>() -> BatchCircuit<N_
})
.collect_vec()
};
let snark_protocol = real_snarks[0].protocol.clone();

// ==========================
// padded chunks
// ==========================
Expand All @@ -302,6 +307,8 @@ fn build_batch_circuit_skip_encoding<const N_SNARKS: usize>() -> BatchCircuit<N_
[real_snarks, padded_snarks].concat().as_ref(),
rng,
batch_hash,
&snark_protocol,
&snark_protocol,
)
.unwrap()
}
2 changes: 2 additions & 0 deletions prover/src/aggregator/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ impl<'params> Prover<'params> {
LayerId::Layer3.id(),
LayerId::Layer3.degree(),
batch_info,
&self.halo2_protocol,
&self.sp1_protocol,
&layer2_snarks,
output_dir,
)?;
Expand Down
32 changes: 28 additions & 4 deletions prover/src/common/prover/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
};
use aggregator::{BatchCircuit, BatchHash};
use anyhow::{anyhow, Result};
use halo2_proofs::halo2curves::bn256::G1Affine;
use rand::Rng;
use snark_verifier_sdk::Snark;
use std::env;
Expand All @@ -17,13 +18,26 @@ impl<'params> Prover<'params> {
degree: u32,
mut rng: impl Rng + Send,
batch_info: BatchHash<N_SNARKS>,
halo2_protocol: &[u8],
sp1_protocol: &[u8],
previous_snarks: &[Snark],
) -> Result<Snark> {
env::set_var("AGGREGATION_CONFIG", layer_config_path(id));

let circuit: BatchCircuit<N_SNARKS> =
BatchCircuit::new(self.params(degree), previous_snarks, &mut rng, batch_info)
.map_err(|err| anyhow!("Failed to construct aggregation circuit: {err:?}"))?;
let halo2_protocol =
serde_json::from_slice::<snark_verifier::Protocol<G1Affine>>(halo2_protocol)?;
let sp1_protocol =
serde_json::from_slice::<snark_verifier::Protocol<G1Affine>>(sp1_protocol)?;

let circuit: BatchCircuit<N_SNARKS> = BatchCircuit::new(
self.params(degree),
previous_snarks,
&mut rng,
batch_info,
halo2_protocol,
sp1_protocol,
)
.map_err(|err| anyhow!("Failed to construct aggregation circuit: {err:?}"))?;

self.gen_snark(id, degree, &mut rng, circuit, "gen_agg_snark")
}
Expand All @@ -34,6 +48,8 @@ impl<'params> Prover<'params> {
id: &str,
degree: u32,
batch_info: BatchHash<N_SNARKS>,
halo2_protocol: &[u8],
sp1_protocol: &[u8],
previous_snarks: &[Snark],
output_dir: Option<&str>,
) -> Result<Snark> {
Expand All @@ -48,7 +64,15 @@ impl<'params> Prover<'params> {
Some(snark) => Ok(snark),
None => {
let rng = gen_rng();
let result = self.gen_agg_snark(id, degree, rng, batch_info, previous_snarks);
let result = self.gen_agg_snark(
id,
degree,
rng,
batch_info,
halo2_protocol,
sp1_protocol,
previous_snarks,
);
if let (Some(_), Ok(snark)) = (output_dir, &result) {
write_snark(&file_path, snark);
}
Expand Down
20 changes: 18 additions & 2 deletions prover/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,24 @@ pub fn chunk_vk_filename() -> String {
read_env_var("CHUNK_VK_FILENAME", "vk_chunk.vkey".to_string())
}

pub static CHUNK_PROTOCOL_FILENAME: LazyLock<String> =
LazyLock::new(|| read_env_var("CHUNK_PROTOCOL_FILENAME", "chunk.protocol".to_string()));
/// The file descriptor for the JSON serialised SNARK [`protocol`][protocol] that
/// defines the [`CompressionCircuit`][compr_circuit] SNARK that uses halo2-based
/// [`SuperCircuit`][super_circuit].
///
/// [protocol]: snark_verifier::Protocol
/// [compr_circuit]: aggregator::CompressionCircuit
/// [super_circuit]: zkevm_circuits::super_circuit::SuperCircuit
pub static FD_HALO2_CHUNK_PROTOCOL: LazyLock<String> =
LazyLock::new(|| read_env_var("HALO2_CHUNK_PROTOCOL", "chunk_halo2.protocol".to_string()));

/// The file descriptor for the JSON serialised SNARK [`protocol`][protocol] that
/// defines the [`CompressionCircuit`][compr_circuit] SNARK that uses sp1-based
/// STARK that is SNARKified using a halo2-backend.
///
/// [protocol]: snark_verifier::Protocol
/// [compr_circuit]: aggregator::CompressionCircuit
pub static FD_SP1_CHUNK_PROTOCOL: LazyLock<String> =
LazyLock::new(|| read_env_var("SP1_CHUNK_PROTOCOL", "chunk_sp1.protocol".to_string()));

pub static CHUNK_VK_FILENAME: LazyLock<String> = LazyLock::new(chunk_vk_filename);
pub static BATCH_VK_FILENAME: LazyLock<String> = LazyLock::new(batch_vk_filename);
Expand Down

0 comments on commit f339984

Please sign in to comment.