From dbb76f33aefebf5ca20b902aef8db0079d7c0b84 Mon Sep 17 00:00:00 2001 From: Sagar Dhawan <107961892+sagar-a16z@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:01:03 -0400 Subject: [PATCH] fix: add memory_layout to preprocessing so verifier doesn't rely on the prover (#494) * fix: add memory_layout to preprocessing so verifier doesn't rely on the prover * add a test to show that the proof's memory_layout is not used * fmt * setup memory_layout based on attributes not program.trace() * remove unnecessary assertions in tests * build warnings --- jolt-core/src/benches/bench.rs | 18 +++- jolt-core/src/host/mod.rs | 4 +- jolt-core/src/jolt/vm/mod.rs | 41 ++++++-- jolt-core/src/jolt/vm/read_write_memory.rs | 6 +- jolt-core/src/jolt/vm/rv32i_vm.rs | 104 +++++++++++++++++---- jolt-sdk/macros/src/lib.rs | 6 ++ jolt-sdk/src/host_utils.rs | 2 +- 7 files changed, 147 insertions(+), 34 deletions(-) diff --git a/jolt-core/src/benches/bench.rs b/jolt-core/src/benches/bench.rs index f4f6da289..d9596b984 100644 --- a/jolt-core/src/benches/bench.rs +++ b/jolt-core/src/benches/bench.rs @@ -131,7 +131,14 @@ where let (io_device, trace) = program.trace(); let preprocessing: crate::jolt::vm::JoltPreprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22); + RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 22, + ); let (jolt_proof, jolt_commitments, _) = >::prove( @@ -187,7 +194,14 @@ where let (io_device, trace) = program.trace(); let preprocessing: crate::jolt::vm::JoltPreprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22); + RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 22, + ); let (jolt_proof, jolt_commitments, _) = >::prove( diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs index 42739549b..2cd534e25 100644 --- a/jolt-core/src/host/mod.rs +++ b/jolt-core/src/host/mod.rs @@ -181,9 +181,9 @@ impl Program { // TODO(moodlezoup): Make this generic over InstructionSet #[tracing::instrument(skip_all, name = "Program::trace")] - pub fn trace(mut self) -> (JoltDevice, Vec>) { + pub fn trace(&mut self) -> (JoltDevice, Vec>) { self.build(); - let elf = self.elf.unwrap(); + let elf = self.elf.clone().unwrap(); let (raw_trace, io_device) = tracer::trace(&elf, &self.input, self.max_input_size, self.max_output_size); diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 2216ecf5c..2a1f6ea8d 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -9,7 +9,7 @@ use crate::r1cs::constraints::R1CSConstraints; use crate::r1cs::spartan::{self, UniformSpartanProof}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use common::constants::RAM_START_ADDRESS; -use common::rv_trace::NUM_CIRCUIT_FLAGS; +use common::rv_trace::{MemoryLayout, NUM_CIRCUIT_FLAGS}; use serde::{Deserialize, Serialize}; use std::marker::PhantomData; use strum::EnumCount; @@ -60,6 +60,7 @@ where pub instruction_lookups: InstructionLookupsPreprocessing, pub bytecode: BytecodePreprocessing, pub read_write_memory: ReadWriteMemoryPreprocessing, + pub memory_layout: MemoryLayout, } #[derive(Clone, Serialize, Deserialize, Debug)] @@ -276,6 +277,7 @@ where #[tracing::instrument(skip_all, name = "Jolt::preprocess")] fn preprocess( bytecode: Vec, + memory_layout: MemoryLayout, memory_init: Vec<(u64, u8)>, max_bytecode_size: usize, max_memory_address: usize, @@ -336,6 +338,7 @@ where JoltPreprocessing { generators, + memory_layout, instruction_lookups: instruction_lookups_preprocessing, bytecode: bytecode_preprocessing, read_write_memory: read_write_memory_preprocessing, @@ -368,7 +371,12 @@ where JoltTraceStep::pad(&mut trace); let mut transcript = ProofTranscript::new(b"Jolt transcript"); - Self::fiat_shamir_preamble(&mut transcript, &program_io, trace_length); + Self::fiat_shamir_preamble( + &mut transcript, + &program_io, + &program_io.memory_layout, + trace_length, + ); let instruction_polynomials = InstructionLookupsProof::< @@ -539,11 +547,16 @@ where opening_accumulator .compare_to(debug_info.opening_accumulator, &preprocessing.generators); } - Self::fiat_shamir_preamble(&mut transcript, &proof.program_io, proof.trace_length); + Self::fiat_shamir_preamble( + &mut transcript, + &proof.program_io, + &preprocessing.memory_layout, + proof.trace_length, + ); // Regenerate the uniform Spartan key let padded_trace_length = proof.trace_length.next_power_of_two(); - let memory_start = RAM_START_ADDRESS - proof.program_io.memory_layout.ram_witness_offset; + let memory_start = RAM_START_ADDRESS - preprocessing.memory_layout.ram_witness_offset; let r1cs_builder = Self::Constraints::construct_constraints(padded_trace_length, memory_start); let spartan_key = spartan::UniformSpartanProof::::setup( @@ -586,6 +599,7 @@ where Self::verify_memory( &mut preprocessing.read_write_memory, &preprocessing.generators, + &preprocessing.memory_layout, proof.read_write_memory, &commitments, proof.program_io, @@ -657,19 +671,27 @@ where ) } + #[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all)] fn verify_memory<'a>( preprocessing: &mut ReadWriteMemoryPreprocessing, generators: &PCS::Setup, + memory_layout: &MemoryLayout, proof: ReadWriteMemoryProof, commitment: &'a JoltCommitments, program_io: JoltDevice, opening_accumulator: &mut VerifierOpeningAccumulator, transcript: &mut ProofTranscript, ) -> Result<(), ProofVerifyError> { - assert!(program_io.inputs.len() <= program_io.memory_layout.max_input_size as usize); - assert!(program_io.outputs.len() <= program_io.memory_layout.max_output_size as usize); - preprocessing.program_io = Some(program_io); + assert!(program_io.inputs.len() <= memory_layout.max_input_size as usize); + assert!(program_io.outputs.len() <= memory_layout.max_output_size as usize); + // pair the memory layout with the program io from the proof + preprocessing.program_io = Some(JoltDevice { + inputs: program_io.inputs, + outputs: program_io.outputs, + panic: program_io.panic, + memory_layout: memory_layout.clone(), + }); ReadWriteMemoryProof::verify( proof, @@ -701,6 +723,7 @@ where fn fiat_shamir_preamble( transcript: &mut ProofTranscript, program_io: &JoltDevice, + memory_layout: &MemoryLayout, trace_length: usize, ) { transcript.append_u64(trace_length as u64); @@ -708,8 +731,8 @@ where transcript.append_u64(M as u64); transcript.append_u64(Self::InstructionSet::COUNT as u64); transcript.append_u64(Self::Subtables::COUNT as u64); - transcript.append_u64(program_io.memory_layout.max_input_size); - transcript.append_u64(program_io.memory_layout.max_output_size); + transcript.append_u64(memory_layout.max_input_size); + transcript.append_u64(memory_layout.max_output_size); transcript.append_bytes(&program_io.inputs); transcript.append_bytes(&program_io.outputs); transcript.append_u64(program_io.panic as u64); diff --git a/jolt-core/src/jolt/vm/read_write_memory.rs b/jolt-core/src/jolt/vm/read_write_memory.rs index ca41b5540..423103174 100644 --- a/jolt-core/src/jolt/vm/read_write_memory.rs +++ b/jolt-core/src/jolt/vm/read_write_memory.rs @@ -39,9 +39,9 @@ use super::{JoltPolynomials, JoltStuff, JoltTraceStep}; pub struct ReadWriteMemoryPreprocessing { min_bytecode_address: u64, pub bytecode_bytes: Vec, - // HACK: The verifier will populate this field by copying it - // over from the `ReadWriteMemoryProof`. Having `program_io` in - // this preprocessing struct allows the verifier to access it + // HACK: The verifier will populate this field by copying inputs/outputs from the + // `ReadWriteMemoryProof` and the memory layout from preprocessing. + // Having `program_io` in this preprocessing struct allows the verifier to access it // to compute the v_init and v_final openings, with no impact // on existing function signatures. pub program_io: Option, diff --git a/jolt-core/src/jolt/vm/rv32i_vm.rs b/jolt-core/src/jolt/vm/rv32i_vm.rs index ba09b6787..ca6f7ebff 100644 --- a/jolt-core/src/jolt/vm/rv32i_vm.rs +++ b/jolt-core/src/jolt/vm/rv32i_vm.rs @@ -312,8 +312,14 @@ mod tests { let (io_device, trace) = program.trace(); drop(artifact_guard); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (proof, commitments, debug_info) = >::prove( io_device, @@ -371,8 +377,14 @@ mod tests { let (bytecode, memory_init) = program.decode(); let (io_device, trace) = program.trace(); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (jolt_proof, jolt_commitments, debug_info) = , @@ -462,8 +486,14 @@ mod tests { let (io_device, trace) = program.trace(); drop(guard); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (jolt_proof, jolt_commitments, debug_info) = , @@ -495,8 +525,14 @@ mod tests { io_device.outputs[0] = 0; // change the output to 0 drop(artifact_guard); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (proof, commitments, debug_info) = , @@ -506,12 +542,46 @@ mod tests { >>::prove( io_device, trace, preprocessing.clone() ); - let verification_result = + let _verification_result = RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info); - assert!( - verification_result.is_ok(), - "Verification failed with error: {:?}", - verification_result.err() + } + + #[test] + #[should_panic] + fn malicious_trace() { + let artifact_guard = FIB_FILE_LOCK.lock().unwrap(); + let mut program = host::Program::new("fibonacci-guest"); + program.set_input(&1u8); // change input to 1 so that termination bit equal true + let (bytecode, memory_init) = program.decode(); + let (mut io_device, trace) = program.trace(); + let memory_layout = io_device.memory_layout.clone(); + drop(artifact_guard); + + // change memory address of output & termination bit to the same address as input + // changes here should not be able to spoof the verifier result + io_device.memory_layout.output_start = io_device.memory_layout.input_start; + io_device.memory_layout.output_end = io_device.memory_layout.input_end; + io_device.memory_layout.termination = io_device.memory_layout.input_start; + + // Since the preprocessing is done with the original memory layout, the verifier should fail + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + memory_layout, + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, ); + let (proof, commitments, debug_info) = , + C, + M, + KeccakTranscript, + >>::prove( + io_device, trace, preprocessing.clone() + ); + let _verification_result = + RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info); } } diff --git a/jolt-sdk/macros/src/lib.rs b/jolt-sdk/macros/src/lib.rs index 3dba83691..e47552b87 100644 --- a/jolt-sdk/macros/src/lib.rs +++ b/jolt-sdk/macros/src/lib.rs @@ -178,6 +178,9 @@ impl MacroBuilder { } fn make_preprocess_func(&self) -> TokenStream2 { + let attributes = parse_attributes(&self.attr); + let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size); + let max_output_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_output_size); let set_mem_size = self.make_set_linker_parameters(); let guest_name = self.get_guest_name(); let imports = self.make_imports(); @@ -199,11 +202,13 @@ impl MacroBuilder { #set_std #set_mem_size let (bytecode, memory_init) = program.decode(); + let memory_layout = MemoryLayout::new(#max_input_size, #max_output_size); // TODO(moodlezoup): Feed in size parameters via macro let preprocessing: JoltPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> = RV32IJoltVM::preprocess( bytecode, + memory_layout, memory_init, 1 << 20, 1 << 20, @@ -409,6 +414,7 @@ impl MacroBuilder { RV32IJoltProof, BytecodeRow, MemoryOp, + MemoryLayout, MEMORY_OPS_PER_INSTRUCTION, instruction::add::ADDInstruction, tracer, diff --git a/jolt-sdk/src/host_utils.rs b/jolt-sdk/src/host_utils.rs index deadb4304..7d73c3aab 100644 --- a/jolt-sdk/src/host_utils.rs +++ b/jolt-sdk/src/host_utils.rs @@ -4,7 +4,7 @@ pub use jolt_core::{field::JoltField, poly::commitment::hyperkzg::HyperKZG}; pub use common::{ constants::MEMORY_OPS_PER_INSTRUCTION, - rv_trace::{MemoryOp, RV32IM}, + rv_trace::{MemoryLayout, MemoryOp, RV32IM}, }; pub use jolt_core::host; pub use jolt_core::jolt::instruction;