Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore(hook): Add RSA hook #1860

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1594,8 +1594,9 @@ impl<'a> Executor<'a> {
many proofs in or forget to call verify_sp1_proof?"
);
}
if self.state.input_stream_ptr != self.state.input_stream.len() {
tracing::warn!("Not all input bytes were read.");

if !self.state.input_stream.is_empty() {
tracing::warn!("Not all input bytes were read");
}

if self.emit_global_memory_events
Expand Down
39 changes: 39 additions & 0 deletions crates/core/executor/src/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use sp1_curves::{
ecdsa::RecoveryId as ecdsaRecoveryId,
k256::{Invert, RecoveryId, Signature, VerifyingKey},
p256::{Signature as p256Signature, VerifyingKey as p256VerifyingKey},
BigUint, Integer
};

use crate::Executor;
Expand All @@ -18,9 +19,15 @@ pub type BoxedHook<'a> = Arc<RwLock<dyn Hook + Send + Sync + 'a>>;
pub const FD_K1_ECRECOVER_HOOK: u32 = 5;
/// The file descriptor through which to access `hook_r1_ecrecover`.
pub const FD_R1_ECRECOVER_HOOK: u32 = 6;

// 7 is used in main

/// The file descriptor through which to access `hook_ed_decompress`.
pub const FD_EDDECOMPRESS: u32 = 8;

/// The file descriptor through which to access `hook_rsa_mul_mod`.
pub const FD_RSA_MUL_MOD: u32 = 9;

/// A runtime hook. May be called during execution by writing to a specified file descriptor,
/// accepting and returning arbitrary data.
pub trait Hook {
Expand Down Expand Up @@ -86,6 +93,7 @@ impl<'a> Default for HookRegistry<'a> {
(FD_K1_ECRECOVER_HOOK, hookify(hook_k1_ecrecover)),
(FD_R1_ECRECOVER_HOOK, hookify(hook_r1_ecrecover)),
(FD_EDDECOMPRESS, hookify(hook_ed_decompress)),
(FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)),
]);

Self { table }
Expand Down Expand Up @@ -223,6 +231,37 @@ pub fn hook_r1_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
vec![vec![1], bytes.to_vec(), s_inverse.to_bytes().to_vec()]
}

/// Given the product of some 256-bit numbers and the modulus, this function does a modular
/// reduction and hints back to the vm in order to constrain it.
///
/// # Arguments
///
/// * `env` - The environment in which the hook is invoked.
/// * `buf` - The buffer containing the product and the modulus.
///
/// WANRING: This function is used to perform a modular reduction outside of the zkVM context.
/// These values must be constrained by the zkVM for correctness.
#[must_use]
pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
assert_eq!(buf.len(), 256 + 256 + 256, "rsa_mul_mod input should have length 256 + 256 + 256, this is a bug.");

let prod: &[u8; 512] = buf[..512].try_into().unwrap();
let m: &[u8; 256] = buf[512..].try_into().unwrap();

let prod = BigUint::from_bytes_le(prod);
let m = BigUint::from_bytes_le(m);

let (q, rem) = prod.div_rem(&m);

let mut rem = rem.to_bytes_le();
rem.resize(512, 0);

let mut q = q.to_bytes_le();
q.resize(256, 0);

vec![rem, q]
}

#[cfg(test)]
pub mod tests {
use super::*;
Expand Down
6 changes: 3 additions & 3 deletions crates/core/executor/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ impl<'a> Executor<'a> {
pub fn write_stdin<T: Serialize>(&mut self, input: &T) {
let mut buf = Vec::new();
bincode::serialize_into(&mut buf, input).expect("serialization failed");
self.state.input_stream.push(buf);
self.state.input_stream.push_back(buf);
}

/// Write a slice of bytes to the standard input stream.
pub fn write_stdin_slice(&mut self, input: &[u8]) {
self.state.input_stream.push(input.to_vec());
self.state.input_stream.push_back(input.to_vec());
}

/// Write a slice of vecs to the standard input stream.
pub fn write_vecs(&mut self, inputs: &[Vec<u8>]) {
for input in inputs {
self.state.input_stream.push(input.clone());
self.state.input_stream.push_back(input.clone());
}
}

Expand Down
11 changes: 3 additions & 8 deletions crates/core/executor/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{
fs::File,
io::{Seek, Write},
collections::VecDeque, fs::File, io::{Seek, Write}
};

use hashbrown::HashMap;
Expand Down Expand Up @@ -42,10 +41,7 @@ pub struct ExecutionState {
pub uninitialized_memory: PagedMemory<u32>,

/// A stream of input values (global to the entire program).
pub input_stream: Vec<Vec<u8>>,

/// A ptr to the current position in the input stream incremented by `HINT_READ` opcode.
pub input_stream_ptr: usize,
pub input_stream: VecDeque<Vec<u8>>,

/// A stream of proofs (reduce vk, proof, verifying key) inputted to the program.
pub proof_stream:
Expand Down Expand Up @@ -77,8 +73,7 @@ impl ExecutionState {
pc: pc_start,
memory: PagedMemory::new_preallocated(),
uninitialized_memory: PagedMemory::default(),
input_stream: Vec::new(),
input_stream_ptr: 0,
input_stream: VecDeque::new(),
public_values_stream: Vec::new(),
public_values_stream_ptr: 0,
proof_stream: Vec::new(),
Expand Down
31 changes: 14 additions & 17 deletions crates/core/executor/src/syscalls/hint.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{Syscall, SyscallCode, SyscallContext};
use std::collections::VecDeque;

pub(crate) struct HintLenSyscall;

Expand All @@ -10,33 +11,23 @@ impl Syscall for HintLenSyscall {
_arg1: u32,
_arg2: u32,
) -> Option<u32> {
if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
panic!(
"failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}",
ctx.rt.state.input_stream_ptr,
ctx.rt.state.input_stream.len()
);
}
Some(ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr].len() as u32)
Some(next_len_or_panic(&ctx.rt.state.input_stream))
}
}

pub(crate) struct HintReadSyscall;

impl Syscall for HintReadSyscall {
fn execute(&self, ctx: &mut SyscallContext, _: SyscallCode, ptr: u32, len: u32) -> Option<u32> {
if ctx.rt.state.input_stream_ptr >= ctx.rt.state.input_stream.len() {
panic!(
"failed reading stdin due to insufficient input data: input_stream_ptr={}, input_stream_len={}",
ctx.rt.state.input_stream_ptr,
ctx.rt.state.input_stream.len()
);
}
let vec = &ctx.rt.state.input_stream[ctx.rt.state.input_stream_ptr];
ctx.rt.state.input_stream_ptr += 1;
let _ = next_len_or_panic(&ctx.rt.state.input_stream);

// SAFTEY: We check if we have a vec in the input stream in the previous line.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typos:

Suggested change
// SAFTEY: We check if we have a vec in the input stream in the previous line.
// SAFETY: We check if we have a vec in the input stream in the previous line.

let vec = unsafe { ctx.rt.state.input_stream.pop_front().unwrap_unchecked() };

assert!(!ctx.rt.unconstrained, "hint read should not be used in a unconstrained block");
assert_eq!(vec.len() as u32, len, "hint input stream read length mismatch");
assert_eq!(ptr % 4, 0, "hint read address not aligned to 4 bytes");

// Iterate through the vec in 4-byte chunks
for i in (0..len).step_by(4) {
// Get each byte in the chunk
Expand All @@ -61,3 +52,9 @@ impl Syscall for HintReadSyscall {
None
}
}

fn next_len_or_panic(queue: &VecDeque<Vec<u8>>) -> u32 {
queue.front().map(|vec| vec.len() as u32).unwrap_or_else(|| {
panic!("Syscall Hint Len failed becasue the input stream is exhausted");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typos:

Suggested change
panic!("Syscall Hint Len failed becasue the input stream is exhausted");
panic!("Syscall Hint Len failed because the input stream is exhausted");
++ b/patch-testing/keccak/src/lib.rs

})
}
11 changes: 7 additions & 4 deletions crates/core/executor/src/syscalls/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ impl Syscall for WriteSyscall {
} else if fd == 3 {
rt.state.public_values_stream.extend_from_slice(slice);
} else if fd == 4 {
rt.state.input_stream.push(slice.to_vec());
rt.state.input_stream.push_back(slice.to_vec());
} else if let Some(mut hook) = rt.hook_registry.get(fd) {
let res = hook.invoke_hook(rt.hook_env(), slice);
// Add result vectors to the beginning of the stream.
let ptr = rt.state.input_stream_ptr;
rt.state.input_stream.splice(ptr..ptr, res);

// Write the items in reverse order to the input stream
// to preserve the expected order when reading
for item in res.into_iter().rev() {
rt.state.input_stream.push_front(item);
}
} else {
tracing::warn!("tried to write to unknown file descriptor {fd}");
}
Expand Down
2 changes: 1 addition & 1 deletion crates/curves/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ use std::{
};
use typenum::Unsigned;

use num::BigUint;
pub use num::{Integer, BigUint};
use serde::{de::DeserializeOwned, Serialize};

pub const NUM_WORDS_FIELD_ELEMENT: usize = 8;
Expand Down
2 changes: 2 additions & 0 deletions crates/zkvm/lib/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub const FD_K1_ECRECOVER_HOOK: u32 = 5;
pub const FD_R1_ECRECOVER_HOOK: u32 = 6;
/// The file descriptor through which to access `hook_ed_decompress`.
pub const FD_EDDECOMPRESS: u32 = 8;
/// The file descriptor through which to access `hook_rsa_mul_mod`.
pub const FD_RSA_MUL_MOD: u32 = 9;

/// A writer that writes to a file descriptor inside the zkVM.
struct SyscallWriter {
Expand Down
Loading