diff --git a/recursion/compiler/src/circuit/mod.rs b/recursion/compiler/src/circuit/mod.rs index 6098320b30..eb1a350d86 100644 --- a/recursion/compiler/src/circuit/mod.rs +++ b/recursion/compiler/src/circuit/mod.rs @@ -7,7 +7,7 @@ pub use compiler::*; #[cfg(test)] mod tests { use p3_baby_bear::DiffusionMatrixBabyBear; - use p3_field::AbstractExtensionField; + use p3_field::{AbstractExtensionField, AbstractField}; use rand::{rngs::StdRng, Rng, SeedableRng}; use sp1_core::{ stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS}, @@ -23,10 +23,14 @@ mod tests { poseidon2_wide::Poseidon2WideChip, }, machine::RecursionAir, - Runtime, + Runtime, RuntimeError, }; - use crate::{asm::AsmBuilder, circuit::AsmCompiler, ir::*}; + use crate::{ + asm::AsmBuilder, + circuit::{AsmCompiler, CircuitV2Builder}, + ir::*, + }; const DEGREE: usize = 3; @@ -119,4 +123,100 @@ mod tests { tracing::info!("num shard proofs: {}", result.shard_proofs.len()); } + + #[test] + fn test_io() { + let mut builder = AsmBuilder::::default(); + + let felts = builder.hint_felts_v2(3); + assert_eq!(felts.len(), 3); + let sum: Felt<_> = builder.eval(felts[0] + felts[1]); + builder.assert_felt_eq(sum, felts[2]); + + let exts = builder.hint_exts_v2(3); + assert_eq!(exts.len(), 3); + let sum: Ext<_, _> = builder.eval(exts[0] + exts[1]); + builder.assert_ext_ne(sum, exts[2]); + + let x = builder.hint_ext_v2(); + builder.assert_ext_eq(x, exts[0] + felts[0]); + + let y = builder.hint_felt_v2(); + let zero: Felt<_> = builder.constant(F::zero()); + builder.assert_felt_eq(y, zero); + + let operations = builder.operations; + let mut compiler = AsmCompiler::default(); + let program = compiler.compile(operations); + let mut runtime = Runtime::::new(&program, SC::new().perm); + runtime.witness_stream = [ + vec![F::one().into(), F::one().into(), F::two().into()], + vec![F::zero().into(), F::one().into(), F::two().into()], + vec![F::one().into()], + vec![F::zero().into()], + ] + .into(); + runtime.run().unwrap(); + + let machine = A::machine(SC::new()); + + let (pk, vk) = machine.setup(&program); + let result = + run_test_machine(vec![runtime.record], machine, pk, vk.clone()).expect("should verify"); + + tracing::info!("num shard proofs: {}", result.shard_proofs.len()); + } + + #[test] + fn test_empty_witness_stream() { + let mut builder = AsmBuilder::::default(); + + let felts = builder.hint_felts_v2(3); + assert_eq!(felts.len(), 3); + let sum: Felt<_> = builder.eval(felts[0] + felts[1]); + builder.assert_felt_eq(sum, felts[2]); + + let exts = builder.hint_exts_v2(3); + assert_eq!(exts.len(), 3); + let sum: Ext<_, _> = builder.eval(exts[0] + exts[1]); + builder.assert_ext_ne(sum, exts[2]); + + let operations = builder.operations; + let mut compiler = AsmCompiler::default(); + let program = compiler.compile(operations); + let mut runtime = Runtime::::new(&program, SC::new().perm); + runtime.witness_stream = [vec![F::one().into(), F::one().into(), F::two().into()]].into(); + + match runtime.run() { + Err(RuntimeError::EmptyWitnessStream) => (), + Ok(_) => panic!("should not succeed"), + Err(x) => panic!("should not yield error variant: {}", x), + } + } + + #[test] + fn test_mismatched_witness_size() { + const MEM_VEC_LEN: usize = 3; + const WITNESS_LEN: usize = 5; + + let mut builder = AsmBuilder::::default(); + + let felts = builder.hint_felts_v2(MEM_VEC_LEN); + assert_eq!(felts.len(), MEM_VEC_LEN); + + let operations = builder.operations; + let mut compiler = AsmCompiler::default(); + let program = compiler.compile(operations); + let mut runtime = Runtime::::new(&program, SC::new().perm); + runtime.witness_stream = [vec![F::zero().into(); WITNESS_LEN]].into(); + + match runtime.run() { + Err(RuntimeError::WitnessLenMismatch { + mem_vec_len: MEM_VEC_LEN, + witness_len: WITNESS_LEN, + }) => (), + Ok(_) => panic!("should not succeed"), + Err(x) => panic!("should not yield error variant: {}", x), + } + } } diff --git a/recursion/core-v2/src/runtime/mod.rs b/recursion/core-v2/src/runtime/mod.rs index cd2ffc6be3..5dca4059a7 100644 --- a/recursion/core-v2/src/runtime/mod.rs +++ b/recursion/core-v2/src/runtime/mod.rs @@ -158,8 +158,8 @@ pub enum RuntimeError { }, #[error("failed to print to `debug_stdout`: {0}")] DebugPrint(#[from] std::io::Error), - #[error("attempted to read vec of {0:?} from empty witness stream")] - EmptyWitnessStream(FieldEltType), + #[error("attempted to read from empty witness stream")] + EmptyWitnessStream, #[error("attempted to write to memory vec of len {mem_vec_len} witness of size {witness_len}")] WitnessLenMismatch { mem_vec_len: usize, @@ -572,7 +572,7 @@ where let witness = self .witness_stream .pop_front() - .ok_or(RuntimeError::EmptyWitnessStream(FieldEltType::Base))?; + .ok_or(RuntimeError::EmptyWitnessStream)?; // Check the lengths are the same. if output_addrs_mults.len() != witness.len() { return Err(RuntimeError::WitnessLenMismatch {