Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
ratankaliani committed Dec 20, 2024
1 parent 29ac0d5 commit 217437d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
42 changes: 36 additions & 6 deletions crates/sdk/src/network/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct NetworkProveBuilder<'a> {
pub(crate) timeout: Option<Duration>,
pub(crate) strategy: FulfillmentStrategy,
pub(crate) skip_simulation: bool,
pub(crate) cycle_limit: Option<u64>,
}

impl<'a> NetworkProveBuilder<'a> {
Expand Down Expand Up @@ -231,6 +232,35 @@ impl<'a> NetworkProveBuilder<'a> {
self
}

/// Sets the cycle limit for the proof request.
///
/// # Details
/// The cycle limit determines the maximum number of cycles that the program can execute.
/// By default, the cycle limit is determined by simulating the program locally. However,
/// you can manually set it if you know the exact cycle count needed and want to skip the
/// simulation step locally.
///
/// # Example
/// ```rust,no_run
/// use sp1_sdk::{ProverClient, SP1Stdin, Prover};
///
/// let elf = &[1, 2, 3];
/// let stdin = SP1Stdin::new();
///
/// let client = ProverClient::builder().network().build();
/// let (pk, vk) = client.setup(elf);
/// let proof = client.prove(&pk, &stdin)
/// .cycle_limit(1_000_000) // Set 1M cycle limit
/// .skip_simulation(true) // Skip simulation since we set limit manually
/// .run()
/// .unwrap();
/// ```
#[must_use]
pub fn cycle_limit(mut self, cycle_limit: u64) -> Self {
self.cycle_limit = Some(cycle_limit);
self
}

/// Request a proof from the prover network.
///
/// # Details
Expand All @@ -251,8 +281,8 @@ impl<'a> NetworkProveBuilder<'a> {
/// .unwrap();
/// ```
pub async fn request(self) -> Result<Vec<u8>> {
let Self { prover, mode, pk, stdin, timeout, strategy, skip_simulation } = self;
prover.request_proof_impl(pk, &stdin, mode, strategy, timeout, skip_simulation).await
let Self { prover, mode, pk, stdin, timeout, strategy, skip_simulation, cycle_limit } = self;
prover.request_proof_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit).await
}

/// Run the prover with the built arguments.
Expand All @@ -275,7 +305,7 @@ impl<'a> NetworkProveBuilder<'a> {
/// .unwrap();
/// ```
pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation } = self;
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } = self;

// Check for deprecated environment variable
if let Ok(val) = std::env::var("SKIP_SIMULATION") {
Expand All @@ -287,7 +317,7 @@ impl<'a> NetworkProveBuilder<'a> {

sp1_dump(&pk.elf, &stdin);

block_on(prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation))
block_on(prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit))
}

/// Run the prover with the built arguments asynchronously.
Expand All @@ -308,7 +338,7 @@ impl<'a> NetworkProveBuilder<'a> {
/// .run_async();
/// ```
pub async fn run_async(self) -> Result<SP1ProofWithPublicValues> {
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation } = self;
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } = self;

// Check for deprecated environment variable
if let Ok(val) = std::env::var("SKIP_SIMULATION") {
Expand All @@ -320,6 +350,6 @@ impl<'a> NetworkProveBuilder<'a> {

sp1_dump(&pk.elf, &stdin);

prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation).await
prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit).await
}
}
31 changes: 26 additions & 5 deletions crates/sdk/src/network/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl NetworkProver {
timeout: None,
strategy: FulfillmentStrategy::Hosted,
skip_simulation: false,
cycle_limit: None,
}
}

Expand Down Expand Up @@ -281,9 +282,10 @@ impl NetworkProver {
strategy: FulfillmentStrategy,
timeout: Option<Duration>,
skip_simulation: bool,
cycle_limit: Option<u64>,
) -> Result<Vec<u8>> {
let vk_hash = self.register_program(&pk.vk, &pk.elf).await?;
let cycle_limit = self.get_cycle_limit(&pk.elf, stdin, skip_simulation)?;
let cycle_limit = self.get_cycle_limit(cycle_limit, &pk.elf, stdin, skip_simulation)?;
self.request_proof(&vk_hash, stdin, mode.into(), strategy, cycle_limit, timeout).await
}

Expand All @@ -296,13 +298,32 @@ impl NetworkProver {
strategy: FulfillmentStrategy,
timeout: Option<Duration>,
skip_simulation: bool,
cycle_limit: Option<u64>,
) -> Result<SP1ProofWithPublicValues> {
let request_id =
self.request_proof_impl(pk, stdin, mode, strategy, timeout, skip_simulation).await?;
let request_id = self
.request_proof_impl(pk, stdin, mode, strategy, timeout, skip_simulation, cycle_limit)
.await?;
self.wait_proof(&request_id, timeout).await
}

fn get_cycle_limit(&self, elf: &[u8], stdin: &SP1Stdin, skip_simulation: bool) -> Result<u64> {
/// The cycle limit is determined according to the following priority:
///
/// # Details
/// 1. If a cycle limit was explicitly set, use the specified value.
/// 2. If simulation is enabled (it is by default), calculate the limit by simulating the
/// execution of the program.
/// 3. Otherwise, use the default cycle limit ([`DEFAULT_CYCLE_LIMIT`]).
fn get_cycle_limit(
&self,
cycle_limit: Option<u64>,
elf: &[u8],
stdin: &SP1Stdin,
skip_simulation: bool,
) -> Result<u64> {
if let Some(cycle_limit) = cycle_limit {
return Ok(cycle_limit);
}

if skip_simulation {
Ok(DEFAULT_CYCLE_LIMIT)
} else {
Expand All @@ -328,7 +349,7 @@ impl Prover<CpuProverComponents> for NetworkProver {
stdin: &SP1Stdin,
mode: SP1ProofMode,
) -> Result<SP1ProofWithPublicValues> {
block_on(self.prove_impl(pk, stdin, mode, FulfillmentStrategy::Hosted, None, false))
block_on(self.prove_impl(pk, stdin, mode, FulfillmentStrategy::Hosted, None, false, None))
}
}

Expand Down

0 comments on commit 217437d

Please sign in to comment.