diff --git a/zero/src/prover.rs b/zero/src/prover.rs index 3916f98d2..794c61878 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -31,9 +31,8 @@ use crate::ops; // // While proving a block interval, we will output proofs corresponding to block // batches as soon as they are generated. -const PARALLEL_BLOCK_PROVING_PERMIT_POOL_SIZE: usize = 16; -static PARALLEL_BLOCK_PROVING_PERMIT_POOL: Semaphore = - Semaphore::const_new(PARALLEL_BLOCK_PROVING_PERMIT_POOL_SIZE); +const DEFAULT_PARALLEL_BLOCK_PROVING_PERMIT_POOL_SIZE: usize = 16; +static PARALLEL_BLOCK_PROVING_PERMIT_POOL: Semaphore = Semaphore::const_new(0); #[derive(Debug, Clone)] pub struct ProverConfig { @@ -44,6 +43,7 @@ pub struct ProverConfig { pub proof_output_dir: PathBuf, pub keep_intermediate_proofs: bool, pub block_batch_size: usize, + pub block_pool_size: usize, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -242,6 +242,16 @@ pub async fn prove( let mut task_set: JoinSet< std::result::Result, anyhow::Error>, > = JoinSet::new(); + + if prover_config.block_pool_size > 0 { + PARALLEL_BLOCK_PROVING_PERMIT_POOL.add_permits(prover_config.block_pool_size); + } else { + anyhow::bail!( + "block_pool_size should be greater than 0, value passed from cli is {}", + prover_config.block_pool_size + ); + } + while let Some((block_prover_input, is_last_block)) = block_receiver.recv().await { block_counter += 1; let (tx, rx) = oneshot::channel::(); diff --git a/zero/src/prover/cli.rs b/zero/src/prover/cli.rs index e55141b7a..94ca03ed4 100644 --- a/zero/src/prover/cli.rs +++ b/zero/src/prover/cli.rs @@ -43,6 +43,9 @@ pub struct CliProverConfig { /// generate one proof file. #[arg(long, default_value_t = 8)] block_batch_size: usize, + /// The maximum number of block proving tasks that can run in parallel. + #[arg(long, default_value_t = 16)] + block_pool_size: usize, } impl From for super::ProverConfig { @@ -55,6 +58,7 @@ impl From for super::ProverConfig { proof_output_dir: cli.proof_output_dir, keep_intermediate_proofs: cli.keep_intermediate_proofs, block_batch_size: cli.block_batch_size, + block_pool_size: cli.block_pool_size, } } }