diff --git a/crates/core/executor/src/executor.rs b/crates/core/executor/src/executor.rs index c2cc1915a..11eb33e91 100644 --- a/crates/core/executor/src/executor.rs +++ b/crates/core/executor/src/executor.rs @@ -1594,8 +1594,9 @@ impl<'a> Executor<'a> { many proofs in or forget to call verify_sp1_proof?" ); } - if self.state.input_stream_ptr != self.state.input_stream.len() { - tracing::warn!("Not all input bytes were read."); + + if !self.state.input_stream.is_empty() { + tracing::warn!("Not all input bytes were read"); } if self.emit_global_memory_events diff --git a/crates/core/executor/src/hook.rs b/crates/core/executor/src/hook.rs index 63645d84d..69d887a7c 100644 --- a/crates/core/executor/src/hook.rs +++ b/crates/core/executor/src/hook.rs @@ -7,6 +7,7 @@ use sp1_curves::{ ecdsa::RecoveryId as ecdsaRecoveryId, k256::{Invert, RecoveryId, Signature, VerifyingKey}, p256::{Signature as p256Signature, VerifyingKey as p256VerifyingKey}, + BigUint, Integer }; use crate::Executor; @@ -18,9 +19,15 @@ pub type BoxedHook<'a> = Arc>; pub const FD_K1_ECRECOVER_HOOK: u32 = 5; /// The file descriptor through which to access `hook_r1_ecrecover`. pub const FD_R1_ECRECOVER_HOOK: u32 = 6; + +// 7 is used in main + /// The file descriptor through which to access `hook_ed_decompress`. pub const FD_EDDECOMPRESS: u32 = 8; +/// The file descriptor through which to access `hook_rsa_mul_mod`. +pub const FD_RSA_MUL_MOD: u32 = 9; + /// A runtime hook. May be called during execution by writing to a specified file descriptor, /// accepting and returning arbitrary data. pub trait Hook { @@ -86,6 +93,7 @@ impl<'a> Default for HookRegistry<'a> { (FD_K1_ECRECOVER_HOOK, hookify(hook_k1_ecrecover)), (FD_R1_ECRECOVER_HOOK, hookify(hook_r1_ecrecover)), (FD_EDDECOMPRESS, hookify(hook_ed_decompress)), + (FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)), ]); Self { table } @@ -223,6 +231,37 @@ pub fn hook_r1_ecrecover(_: HookEnv, buf: &[u8]) -> Vec> { vec![vec![1], bytes.to_vec(), s_inverse.to_bytes().to_vec()] } +/// Given the product of some 256-bit numbers and the modulus, this function does a modular +/// reduction and hints back to the vm in order to constrain it. +/// +/// # Arguments +/// +/// * `env` - The environment in which the hook is invoked. +/// * `buf` - The buffer containing the product and the modulus. +/// +/// WANRING: This function is used to perform a modular reduction outside of the zkVM context. +/// These values must be constrained by the zkVM for correctness. +#[must_use] +pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec> { + assert_eq!(buf.len(), 256 + 256 + 256, "rsa_mul_mod input should have length 256 + 256 + 256, this is a bug."); + + let prod: &[u8; 512] = buf[..512].try_into().unwrap(); + let m: &[u8; 256] = buf[512..].try_into().unwrap(); + + let prod = BigUint::from_bytes_le(prod); + let m = BigUint::from_bytes_le(m); + + let (q, rem) = prod.div_rem(&m); + + let mut rem = rem.to_bytes_le(); + rem.resize(512, 0); + + let mut q = q.to_bytes_le(); + q.resize(256, 0); + + vec![rem, q] +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/crates/core/executor/src/io.rs b/crates/core/executor/src/io.rs index 767697c60..67a8f7983 100644 --- a/crates/core/executor/src/io.rs +++ b/crates/core/executor/src/io.rs @@ -18,18 +18,18 @@ impl<'a> Executor<'a> { pub fn write_stdin(&mut self, input: &T) { let mut buf = Vec::new(); bincode::serialize_into(&mut buf, input).expect("serialization failed"); - self.state.input_stream.push(buf); + self.state.input_stream.push_back(buf); } /// Write a slice of bytes to the standard input stream. pub fn write_stdin_slice(&mut self, input: &[u8]) { - self.state.input_stream.push(input.to_vec()); + self.state.input_stream.push_back(input.to_vec()); } /// Write a slice of vecs to the standard input stream. pub fn write_vecs(&mut self, inputs: &[Vec]) { for input in inputs { - self.state.input_stream.push(input.clone()); + self.state.input_stream.push_back(input.clone()); } } diff --git a/crates/core/executor/src/state.rs b/crates/core/executor/src/state.rs index 55ba22e32..2bc29a57e 100644 --- a/crates/core/executor/src/state.rs +++ b/crates/core/executor/src/state.rs @@ -1,6 +1,5 @@ use std::{ - fs::File, - io::{Seek, Write}, + collections::VecDeque, fs::File, io::{Seek, Write} }; use hashbrown::HashMap; @@ -42,10 +41,7 @@ pub struct ExecutionState { pub uninitialized_memory: PagedMemory, /// A stream of input values (global to the entire program). - pub input_stream: Vec>, - - /// A ptr to the current position in the input stream incremented by `HINT_READ` opcode. - pub input_stream_ptr: usize, + pub input_stream: VecDeque>, /// A stream of proofs (reduce vk, proof, verifying key) inputted to the program. pub proof_stream: @@ -77,8 +73,7 @@ impl ExecutionState { pc: pc_start, memory: PagedMemory::new_preallocated(), uninitialized_memory: PagedMemory::default(), - input_stream: Vec::new(), - input_stream_ptr: 0, + input_stream: VecDeque::new(), public_values_stream: Vec::new(), public_values_stream_ptr: 0, proof_stream: Vec::new(), diff --git a/crates/core/executor/src/syscalls/hint.rs b/crates/core/executor/src/syscalls/hint.rs index a740250fe..4946a118e 100644 --- a/crates/core/executor/src/syscalls/hint.rs +++ b/crates/core/executor/src/syscalls/hint.rs @@ -1,4 +1,5 @@ use super::{Syscall, SyscallCode, SyscallContext}; +use std::collections::VecDeque; pub(crate) struct HintLenSyscall; @@ -10,14 +11,7 @@ impl Syscall for HintLenSyscall { _arg1: u32, _arg2: u32, ) -> Option { - if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() { - panic!( - "failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}", - ctx.rt.state.input_stream_ptr, - ctx.rt.state.input_stream.len() - ); - } - Some(ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr].len() as u32) + Some(next_len_or_panic(&ctx.rt.state.input_stream)) } } @@ -25,18 +19,15 @@ pub(crate) struct HintReadSyscall; impl Syscall for HintReadSyscall { fn execute(&self, ctx: &mut SyscallContext, _: SyscallCode, ptr: u32, len: u32) -> Option { - if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() { - panic!( - "failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}", - ctx.rt.state.input_stream_ptr, - ctx.rt.state.input_stream.len() - ); - } - let vec = &ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr]; - ctx.rt.state.input_stream_ptr += 1; + let _ = next_len_or_panic(&ctx.rt.state.input_stream); + + // SAFTEY: We check if we have a vec in the input stream in the previous line. + let vec = unsafe { ctx.rt.state.input_stream.pop_front().unwrap_unchecked() }; + assert!(!ctx.rt.unconstrained, "hint read should not be used in a unconstrained block"); assert_eq!(vec.len() as u32, len, "hint input stream read length mismatch"); assert_eq!(ptr % 4, 0, "hint read address not aligned to 4 bytes"); + // Iterate through the vec in 4-byte chunks for i in (0..len).step_by(4) { // Get each byte in the chunk @@ -61,3 +52,9 @@ impl Syscall for HintReadSyscall { None } } + +fn next_len_or_panic(queue: &VecDeque>) -> u32 { + queue.front().map(|vec| vec.len() as u32).unwrap_or_else(|| { + panic!("Syscall Hint Len failed becasue the input stream is exhausted"); + }) +} diff --git a/crates/core/executor/src/syscalls/write.rs b/crates/core/executor/src/syscalls/write.rs index d0db44cae..398f665de 100644 --- a/crates/core/executor/src/syscalls/write.rs +++ b/crates/core/executor/src/syscalls/write.rs @@ -63,12 +63,15 @@ impl Syscall for WriteSyscall { } else if fd == 3 { rt.state.public_values_stream.extend_from_slice(slice); } else if fd == 4 { - rt.state.input_stream.push(slice.to_vec()); + rt.state.input_stream.push_back(slice.to_vec()); } else if let Some(mut hook) = rt.hook_registry.get(fd) { let res = hook.invoke_hook(rt.hook_env(), slice); - // Add result vectors to the beginning of the stream. - let ptr = rt.state.input_stream_ptr; - rt.state.input_stream.splice(ptr..ptr, res); + + // Write the items in reverse order to the input stream + // to preserve the expected order when reading + for item in res.into_iter().rev() { + rt.state.input_stream.push_front(item); + } } else { tracing::warn!("tried to write to unknown file descriptor {fd}"); } diff --git a/crates/curves/src/lib.rs b/crates/curves/src/lib.rs index 73682a74d..29b813c2d 100644 --- a/crates/curves/src/lib.rs +++ b/crates/curves/src/lib.rs @@ -68,7 +68,7 @@ use std::{ }; use typenum::Unsigned; -use num::BigUint; +pub use num::{Integer, BigUint}; use serde::{de::DeserializeOwned, Serialize}; pub const NUM_WORDS_FIELD_ELEMENT: usize = 8; diff --git a/crates/zkvm/lib/src/io.rs b/crates/zkvm/lib/src/io.rs index 6ab651fe0..34a06170c 100644 --- a/crates/zkvm/lib/src/io.rs +++ b/crates/zkvm/lib/src/io.rs @@ -18,6 +18,8 @@ pub const FD_K1_ECRECOVER_HOOK: u32 = 5; pub const FD_R1_ECRECOVER_HOOK: u32 = 6; /// The file descriptor through which to access `hook_ed_decompress`. pub const FD_EDDECOMPRESS: u32 = 8; +/// The file descriptor through which to access `hook_rsa_mul_mod`. +pub const FD_RSA_MUL_MOD: u32 = 9; /// A writer that writes to a file descriptor inside the zkVM. struct SyscallWriter {