diff --git a/examples/recursive/Cargo.toml b/examples/recursive/Cargo.toml new file mode 100644 index 0000000000..4373614e63 --- /dev/null +++ b/examples/recursive/Cargo.toml @@ -0,0 +1,12 @@ +[workspace] +members = [ + "lib", + "program", +] +resolver = "2" + +[workspace.dependencies] +alloy-sol-types = "0.7.7" +sha2 = "0.10.8" +sp1-zkvm = { version = "3.3.0", features = ["verify"] } +serde = { version = "1.0.209", features = ["derive"] } diff --git a/examples/recursive/lib/Cargo.toml b/examples/recursive/lib/Cargo.toml new file mode 100644 index 0000000000..cd606c99bb --- /dev/null +++ b/examples/recursive/lib/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "recursive-lib" +version = "0.1.0" +edition = "2021" + +[dependencies] +alloy-sol-types = { workspace = true } +sha2 = { workspace = true } +sp1-zkvm = { workspace = true } +serde = { workspace = true } diff --git a/examples/recursive/lib/src/lib.rs b/examples/recursive/lib/src/lib.rs new file mode 100644 index 0000000000..ff9fd2ae8c --- /dev/null +++ b/examples/recursive/lib/src/lib.rs @@ -0,0 +1,34 @@ +pub mod utils; + +use serde::{Deserialize, Serialize}; +use crate::utils::sha256_hash; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CircuitInput { + pub public_input_merkle_root: [u8; 32], + pub public_value: u32, + pub witness: Vec, +} + +/// A toy example of cubic computation +pub fn cubic(n: u32) -> u32 { + n.wrapping_mul(n).wrapping_mul(n) +} + +/// Verify last prover's proof +pub fn verify_proof(vkey_hash: &[u32; 8], public_input_merkle_root: &[u8; 32]) { + sp1_zkvm::lib::verify::verify_sp1_proof(vkey_hash, &sha256_hash(public_input_merkle_root)); +} + +/// Construct a merkle tree for all public inputs avoiding commit these public inputs directly +pub fn merkle_tree_public_input( + public_input: Vec, + public_value: u32, +) -> [u8; 32] { + let public_input_hashes = public_input + .iter() + .chain([public_value].iter()) + .map(|pi| sha256_hash(&pi.to_le_bytes())) + .collect::>(); + utils::get_merkle_root(public_input_hashes) +} diff --git a/examples/recursive/lib/src/utils.rs b/examples/recursive/lib/src/utils.rs new file mode 100644 index 0000000000..55abf53bc1 --- /dev/null +++ b/examples/recursive/lib/src/utils.rs @@ -0,0 +1,57 @@ +use sha2::{Digest, Sha256}; + +pub trait AsLittleEndianBytes { + fn to_little_endian(self) -> Self; +} + +impl AsLittleEndianBytes for [u8; N] { + fn to_little_endian(mut self) -> Self { + self.reverse(); + self + } +} + +pub fn sha256_hash(bytes: &[u8]) -> [u8; 32] { + Sha256::digest(bytes).into() +} + +pub fn hash_pairs(hash_1: [u8; 32], hash_2: [u8; 32]) -> [u8; 32] { + // [0] & [1] Combine hashes into one 64 byte array, reversing byte order + let combined_hashes: [u8; 64] = hash_1 + .into_iter() + .rev() + .chain(hash_2.into_iter().rev()) + .collect::>() + .try_into() + .unwrap(); + + // [2] Double sha256 combined hashes + let new_hash_be = sha256_hash(&sha256_hash(&combined_hashes)); + + // [3] Convert new hash to little-endian + new_hash_be.to_little_endian() +} + +pub fn get_merkle_root(leaves: Vec<[u8; 32]>) -> [u8; 32] { + let mut current_level = leaves; + while current_level.len() > 1 { + let mut next_level = Vec::new(); + let mut i = 0; + + while i < current_level.len() { + let left = current_level[i]; + let right = if i + 1 < current_level.len() { + current_level[i + 1] + } else { + left + }; + + let parent_hash = hash_pairs(left, right); + next_level.push(parent_hash); + + i += 2; + } + current_level = next_level; + } + current_level[0] +} diff --git a/examples/recursive/program/Cargo.toml b/examples/recursive/program/Cargo.toml new file mode 100644 index 0000000000..279d510070 --- /dev/null +++ b/examples/recursive/program/Cargo.toml @@ -0,0 +1,9 @@ +[package] +version = "0.1.0" +name = "recursive-program" +edition = "2021" + +[dependencies] +alloy-sol-types = { workspace = true } +sp1-zkvm = "3.0.0-rc4" +recursive-lib = { path = "../lib" } diff --git a/examples/recursive/program/src/main.rs b/examples/recursive/program/src/main.rs new file mode 100644 index 0000000000..9923930854 --- /dev/null +++ b/examples/recursive/program/src/main.rs @@ -0,0 +1,35 @@ +//! A simple program that takes a sequence of numbers as input, cubic all of them, and then sum up. + +// These two lines are necessary for the program to properly compile. +// +// Under the hood, we wrap your main function with some extra code so that it behaves properly +// inside the zkVM. +#![no_main] +sp1_zkvm::entrypoint!(main); + +use recursive_lib::{cubic, verify_proof, merkle_tree_public_input, CircuitInput}; + +pub fn main() { + // Read prover's sequence number + let seq = sp1_zkvm::io::read::(); + // Read hash of vkey for verifying last prover's proof + let vkey_u32_hash = sp1_zkvm::io::read::<[u32; 8]>(); + // Read circuit input + let circuit_input = sp1_zkvm::io::read::(); + + // Do cubic computation + let result = cubic(circuit_input.public_value); + // Verify proof output by last prover + if seq != 0 { + verify_proof(&vkey_u32_hash, &circuit_input.public_input_merkle_root); + } + // Construct a merkle root of all public inputs + let merkle_root = merkle_tree_public_input( + circuit_input.witness, + circuit_input.public_value, + ); + + // Commit this merkle root and cubic result + sp1_zkvm::io::commit(&merkle_root); + sp1_zkvm::io::commit(&result); +} diff --git a/examples/recursive/rust-toolchain b/examples/recursive/rust-toolchain new file mode 100644 index 0000000000..d9143e67a5 --- /dev/null +++ b/examples/recursive/rust-toolchain @@ -0,0 +1,3 @@ +[toolchain] +channel = "1.81.0" +components = ["llvm-tools", "rustc-dev"] \ No newline at end of file diff --git a/examples/recursive/script/Cargo.toml b/examples/recursive/script/Cargo.toml new file mode 100644 index 0000000000..4d8a9d10c9 --- /dev/null +++ b/examples/recursive/script/Cargo.toml @@ -0,0 +1,22 @@ +[package] +version = "0.1.0" +name = "recursive-script" +edition = "2021" +default-run = "recursive_prover" + +[[bin]] +name = "recursive_prover" +path = "src/bin/recursive_prover.rs" + +[dependencies] +sp1-sdk = "3.0.0" +serde_json = { version = "1.0", default-features = false, features = ["alloc"] } +serde = { version = "1.0.200", default-features = false, features = ["derive"] } +clap = { version = "4.0", features = ["derive", "env"] } +tracing = "0.1.40" +hex = "0.4.3" +alloy-sol-types = { workspace = true } +recursive-lib = { path = "../lib" } + +[build-dependencies] +sp1-helper = "3.0.0" diff --git a/examples/recursive/script/build.rs b/examples/recursive/script/build.rs new file mode 100644 index 0000000000..bc5f025978 --- /dev/null +++ b/examples/recursive/script/build.rs @@ -0,0 +1,5 @@ +use sp1_helper::build_program_with_args; + +fn main() { + build_program_with_args("../program", Default::default()) +} diff --git a/examples/recursive/script/src/bin/recursive_prover.rs b/examples/recursive/script/src/bin/recursive_prover.rs new file mode 100644 index 0000000000..97d03b7eac --- /dev/null +++ b/examples/recursive/script/src/bin/recursive_prover.rs @@ -0,0 +1,91 @@ +//! An end-to-end example of using the SP1 SDK to generate a proof of a program that can be executed +//! or have a core proof generated. +//! +//! You can run this script using the following command: +//! ```shell +//! RUST_LOG=info cargo run --release -- --execute +//! ``` +//! or +//! ```shell +//! RUST_LOG=info cargo run --release -- --prove +//! ``` + +use alloy_sol_types::SolType; +use clap::Parser; +use fibonacci_lib::PublicValuesStruct; +use sp1_sdk::{include_elf, ProverClient, SP1Stdin}; + +/// The ELF (executable and linkable format) file for the Succinct RISC-V zkVM. +pub const FIBONACCI_ELF: &[u8] = include_elf!("fibonacci-program"); + +/// The arguments for the command. +#[derive(Parser, Debug)] +#[clap(author, version, about, long_about = None)] +struct Args { + #[clap(long)] + execute: bool, + + #[clap(long)] + prove: bool, + + #[clap(long, default_value = "20")] + n: u32, +} + +fn main() { + // Setup the logger. + sp1_sdk::utils::setup_logger(); + + // Parse the command line arguments. + let args = Args::parse(); + + if args.execute == args.prove { + eprintln!("Error: You must specify either --execute or --prove"); + std::process::exit(1); + } + + // Setup the prover client. + let client = ProverClient::new(); + + // Setup the inputs. + let mut stdin = SP1Stdin::new(); + stdin.write(&args.n); + + println!("n: {}", args.n); + + if args.execute { + // Execute the program + let (output, report) = client.execute(FIBONACCI_ELF, stdin).run().unwrap(); + println!("Program executed successfully."); + + // Read the output. + let decoded = PublicValuesStruct::abi_decode(output.as_slice(), true).unwrap(); + let PublicValuesStruct { n, a, b } = decoded; + println!("n: {}", n); + println!("a: {}", a); + println!("b: {}", b); + + let (expected_a, expected_b) = fibonacci_lib::fibonacci(n); + assert_eq!(a, expected_a); + assert_eq!(b, expected_b); + println!("Values are correct!"); + + // Record the number of cycles executed. + println!("Number of cycles: {}", report.total_instruction_count()); + } else { + // Setup the program for proving. + let (pk, vk) = client.setup(FIBONACCI_ELF); + + // Generate the proof + let proof = client + .prove(&pk, stdin) + .run() + .expect("failed to generate proof"); + + println!("Successfully generated proof!"); + + // Verify the proof. + client.verify(&proof, &vk).expect("failed to verify proof"); + println!("Successfully verified proof!"); + } +}