Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Example of recursive prover #1833

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/recursive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[workspace]
members = [
"lib",
"program",
]
resolver = "2"

[workspace.dependencies]
alloy-sol-types = "0.7.7"
sha2 = "0.10.8"
sp1-zkvm = { version = "3.3.0", features = ["verify"] }
serde = { version = "1.0.209", features = ["derive"] }
10 changes: 10 additions & 0 deletions examples/recursive/lib/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "recursive-lib"
version = "0.1.0"
edition = "2021"

[dependencies]
alloy-sol-types = { workspace = true }
sha2 = { workspace = true }
sp1-zkvm = { workspace = true }
serde = { workspace = true }
Binary file added examples/recursive/lib/src/.lib.rs.swp
Binary file not shown.
49 changes: 49 additions & 0 deletions examples/recursive/lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
pub mod utils;

use serde::{Deserialize, Serialize};
use crate::utils::sha256_hash;

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CircuitInput {
pub public_input_merkle_root: [u8; 32],
pub public_value: u32,
pub private_value: u32,
pub witness: Vec<u32>,
}

impl CircuitInput {
pub fn new(public_input_merkle_root: [u8; 32], public_value: u32, private_value: u32, witness: Vec<u32>) -> Self {
Self {
public_input_merkle_root,
public_value,
private_value,
witness,
}
}
}

/// A toy example of accumulation of cubic
pub fn acc_cubic(public_value: u32, private_value: u32) -> u32 {
private_value.wrapping_add(public_value.wrapping_mul(public_value).wrapping_mul(public_value))
}

/// Verify last prover's proof
pub fn verify_proof(vkey_hash: &[u32; 8], public_input_merkle_root: &[u8; 32], private_value: u32) {
let mut bytes = Vec::with_capacity(36);
bytes.extend_from_slice(public_input_merkle_root);
bytes.extend_from_slice(&private_value.to_le_bytes());
sp1_zkvm::lib::verify::verify_sp1_proof(vkey_hash, &sha256_hash(&bytes));
}

/// Construct a merkle tree for all public inputs avoiding commit these public inputs directly
pub fn merkle_tree_public_input(
public_input: Vec<u32>,
public_value: u32,
) -> [u8; 32] {
let public_input_hashes = public_input
.iter()
.chain([public_value].iter())
.map(|pi| sha256_hash(&pi.to_le_bytes()))
.collect::<Vec<_>>();
utils::get_merkle_root(public_input_hashes)
}
57 changes: 57 additions & 0 deletions examples/recursive/lib/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use sha2::{Digest, Sha256};

pub trait AsLittleEndianBytes {
fn to_little_endian(self) -> Self;
}

impl<const N: usize> AsLittleEndianBytes for [u8; N] {
fn to_little_endian(mut self) -> Self {
self.reverse();
self
}
}

pub fn sha256_hash(bytes: &[u8]) -> [u8; 32] {
Sha256::digest(bytes).into()
}

pub fn hash_pairs(hash_1: [u8; 32], hash_2: [u8; 32]) -> [u8; 32] {
// [0] & [1] Combine hashes into one 64 byte array, reversing byte order
let combined_hashes: [u8; 64] = hash_1
.into_iter()
.rev()
.chain(hash_2.into_iter().rev())
.collect::<Vec<u8>>()
.try_into()
.unwrap();

// [2] Double sha256 combined hashes
let new_hash_be = sha256_hash(&sha256_hash(&combined_hashes));

// [3] Convert new hash to little-endian
new_hash_be.to_little_endian()
}

pub fn get_merkle_root(leaves: Vec<[u8; 32]>) -> [u8; 32] {
let mut current_level = leaves;
while current_level.len() > 1 {
let mut next_level = Vec::new();
let mut i = 0;

while i < current_level.len() {
let left = current_level[i];
let right = if i + 1 < current_level.len() {
current_level[i + 1]
} else {
left
};

let parent_hash = hash_pairs(left, right);
next_level.push(parent_hash);

i += 2;
}
current_level = next_level;
}
current_level[0]
}
9 changes: 9 additions & 0 deletions examples/recursive/program/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
version = "0.1.0"
name = "recursive-program"
edition = "2021"

[dependencies]
alloy-sol-types = { workspace = true }
sp1-zkvm = "3.0.0-rc4"
recursive-lib = { path = "../lib" }
35 changes: 35 additions & 0 deletions examples/recursive/program/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//! A simple program that takes a sequence of numbers as input, cubic all of them, and then sum up.

// These two lines are necessary for the program to properly compile.
//
// Under the hood, we wrap your main function with some extra code so that it behaves properly
// inside the zkVM.
#![no_main]
sp1_zkvm::entrypoint!(main);

use recursive_lib::{acc_cubic, verify_proof, merkle_tree_public_input, CircuitInput};

pub fn main() {
// Read prover's sequence number
let seq = sp1_zkvm::io::read::<u32>();
// Read hash of vkey for verifying last prover's proof
let vkey_u32_hash = sp1_zkvm::io::read::<[u32; 8]>();
// Read circuit input
let circuit_input = sp1_zkvm::io::read::<CircuitInput>();

// Do cubic computation
let result = acc_cubic(circuit_input.public_value, circuit_input.private_value);
// Verify proof output by last prover
if seq != 0 {
verify_proof(&vkey_u32_hash, &circuit_input.public_input_merkle_root, circuit_input.private_value);
}
// Construct a merkle root of all public inputs
let merkle_root = merkle_tree_public_input(
circuit_input.witness,
circuit_input.public_value,
);

// Commit this merkle root and cubic result
sp1_zkvm::io::commit(&merkle_root);
sp1_zkvm::io::commit(&result);
}
3 changes: 3 additions & 0 deletions examples/recursive/rust-toolchain
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[toolchain]
channel = "1.81.0"
components = ["llvm-tools", "rustc-dev"]
22 changes: 22 additions & 0 deletions examples/recursive/script/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
version = "0.1.0"
name = "recursive-script"
edition = "2021"
default-run = "recursive_prover"

[[bin]]
name = "recursive_prover"
path = "src/bin/recursive_prover.rs"

[dependencies]
sp1-sdk = "3.0.0"
serde_json = { version = "1.0", default-features = false, features = ["alloc"] }
serde = { version = "1.0.200", default-features = false, features = ["derive"] }
clap = { version = "4.0", features = ["derive", "env"] }
tracing = "0.1.40"
hex = "0.4.3"
alloy-sol-types = { workspace = true }
recursive-lib = { path = "../lib" }

[build-dependencies]
sp1-helper = "3.0.0"
5 changes: 5 additions & 0 deletions examples/recursive/script/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use sp1_helper::build_program_with_args;

fn main() {
build_program_with_args("../program", Default::default())
}
111 changes: 111 additions & 0 deletions examples/recursive/script/src/bin/recursive_prover.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//! An end-to-end example of using the SP1 SDK to generate a proof of a program that can be executed
//! or have a core proof generated.
//!
//! You can run this script using the following command:
//! ```shell
//! RUST_LOG=info cargo run --release -- --execute
//! ```
//! or
//! ```shell
//! RUST_LOG=info cargo run --release -- --prove
//! ```

use clap::Parser;
use sp1_sdk::{include_elf, ProverClient, SP1Stdin};
use recursive_lib::CircuitInput;
use sp1_sdk::SP1Proof;
use sp1_sdk::HashableKey;

/// The ELF (executable and linkable format) file for the Succinct RISC-V zkVM.
pub const RECURSIVE_ELF: &[u8] = include_elf!("recursive-program");

/// The arguments for the command.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(long)]
execute: bool,

#[clap(long)]
prove: bool,

#[clap(long, default_value = "20")]
n: u32,
}

fn main() {
// Setup the logger.
sp1_sdk::utils::setup_logger();

// Parse the command line arguments.
let args = Args::parse();

if args.execute == args.prove {
eprintln!("Error: You must specify either --execute or --prove");
std::process::exit(1);
}
assert_eq!(args.n > 0, true, "n must be greater than 0");
let test_public_values = (0..args.n).map(|i| 100 + i).collect::<Vec<_>>();

if args.execute {
let client = ProverClient::new();
let (recursive_prover_pk, recursive_prover_vk) = client.setup(RECURSIVE_ELF);

// For the very first prover
// initialized public and private values for the very first prover
let mut vkey_hash = [0u32; 8];
let mut public_input_merkle_root = [0u8; 32];
let mut public_value = test_public_values[0];
let mut private_value = 0u32;
let mut witness: Vec<u32> = vec![];
let mut circuit_input = CircuitInput::new(public_input_merkle_root, public_value, private_value, witness);

// just fill in STDIN
let mut stdin = SP1Stdin::new();
// write sequence number
stdin.write(&(0 as u32));
// write vkey u32 hash
stdin.write(&vkey_hash);
// write circuit input
stdin.write(&circuit_input);
// generate proof for the very first prover
let mut last_prover_proof = client
.prove(&recursive_prover_pk, stdin)
.compressed()
.run()
.expect("proving failed");
println!("## Generating proof for the very first prover succeeds!");

// For the rest of the provers
for seq in 1..args.n {
// public and private values for the rest of provers
vkey_hash = recursive_prover_vk.hash_u32();
public_input_merkle_root = last_prover_proof.public_values.read::<[u8; 32]>();
private_value = last_prover_proof.public_values.read::<u32>();
public_value = test_public_values[seq as usize];
witness = test_public_values[..seq as usize].to_vec();
circuit_input = CircuitInput::new(public_input_merkle_root, public_value, private_value, witness);

// just fill in STDIN
stdin = SP1Stdin::new();
stdin.write(&(seq as u32));
stdin.write(&vkey_hash);
stdin.write(&circuit_input);
let SP1Proof::Compressed(proof) = last_prover_proof.proof else {
panic!()
};
// write proof and vkey as private value
stdin.write_proof(*proof, recursive_prover_vk.vk.clone());
last_prover_proof = client
.prove(&recursive_prover_pk, stdin)
.compressed()
.run()
.expect("proving failed");
println!("## Generating proof for one of the rest provers succeeds!");
}


} else {
unimplemented!();
}
}
Loading