From d11e8665fbf0f24c024f9a820d9bf3605724068d Mon Sep 17 00:00:00 2001 From: Sagar Dhawan Date: Thu, 31 Oct 2024 11:29:30 -0700 Subject: [PATCH] fix: add memory_layout to preprocessing so verifier doesn't rely on the prover --- jolt-core/src/benches/bench.rs | 18 ++++++- jolt-core/src/host/mod.rs | 4 +- jolt-core/src/jolt/vm/mod.rs | 41 +++++++++++---- jolt-core/src/jolt/vm/read_write_memory.rs | 6 +-- jolt-core/src/jolt/vm/rv32i_vm.rs | 60 +++++++++++++++++----- jolt-sdk/macros/src/lib.rs | 2 + 6 files changed, 103 insertions(+), 28 deletions(-) diff --git a/jolt-core/src/benches/bench.rs b/jolt-core/src/benches/bench.rs index f4f6da289..d9596b984 100644 --- a/jolt-core/src/benches/bench.rs +++ b/jolt-core/src/benches/bench.rs @@ -131,7 +131,14 @@ where let (io_device, trace) = program.trace(); let preprocessing: crate::jolt::vm::JoltPreprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22); + RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 22, + ); let (jolt_proof, jolt_commitments, _) = >::prove( @@ -187,7 +194,14 @@ where let (io_device, trace) = program.trace(); let preprocessing: crate::jolt::vm::JoltPreprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22); + RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 22, + ); let (jolt_proof, jolt_commitments, _) = >::prove( diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs index 42739549b..2cd534e25 100644 --- a/jolt-core/src/host/mod.rs +++ b/jolt-core/src/host/mod.rs @@ -181,9 +181,9 @@ impl Program { // TODO(moodlezoup): Make this generic over InstructionSet #[tracing::instrument(skip_all, name = "Program::trace")] - pub fn trace(mut self) -> (JoltDevice, Vec>) { + pub fn trace(&mut self) -> (JoltDevice, Vec>) { self.build(); - let elf = self.elf.unwrap(); + let elf = self.elf.clone().unwrap(); let (raw_trace, io_device) = tracer::trace(&elf, &self.input, self.max_input_size, self.max_output_size); diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs index 2216ecf5c..2a1f6ea8d 100644 --- a/jolt-core/src/jolt/vm/mod.rs +++ b/jolt-core/src/jolt/vm/mod.rs @@ -9,7 +9,7 @@ use crate::r1cs::constraints::R1CSConstraints; use crate::r1cs::spartan::{self, UniformSpartanProof}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use common::constants::RAM_START_ADDRESS; -use common::rv_trace::NUM_CIRCUIT_FLAGS; +use common::rv_trace::{MemoryLayout, NUM_CIRCUIT_FLAGS}; use serde::{Deserialize, Serialize}; use std::marker::PhantomData; use strum::EnumCount; @@ -60,6 +60,7 @@ where pub instruction_lookups: InstructionLookupsPreprocessing, pub bytecode: BytecodePreprocessing, pub read_write_memory: ReadWriteMemoryPreprocessing, + pub memory_layout: MemoryLayout, } #[derive(Clone, Serialize, Deserialize, Debug)] @@ -276,6 +277,7 @@ where #[tracing::instrument(skip_all, name = "Jolt::preprocess")] fn preprocess( bytecode: Vec, + memory_layout: MemoryLayout, memory_init: Vec<(u64, u8)>, max_bytecode_size: usize, max_memory_address: usize, @@ -336,6 +338,7 @@ where JoltPreprocessing { generators, + memory_layout, instruction_lookups: instruction_lookups_preprocessing, bytecode: bytecode_preprocessing, read_write_memory: read_write_memory_preprocessing, @@ -368,7 +371,12 @@ where JoltTraceStep::pad(&mut trace); let mut transcript = ProofTranscript::new(b"Jolt transcript"); - Self::fiat_shamir_preamble(&mut transcript, &program_io, trace_length); + Self::fiat_shamir_preamble( + &mut transcript, + &program_io, + &program_io.memory_layout, + trace_length, + ); let instruction_polynomials = InstructionLookupsProof::< @@ -539,11 +547,16 @@ where opening_accumulator .compare_to(debug_info.opening_accumulator, &preprocessing.generators); } - Self::fiat_shamir_preamble(&mut transcript, &proof.program_io, proof.trace_length); + Self::fiat_shamir_preamble( + &mut transcript, + &proof.program_io, + &preprocessing.memory_layout, + proof.trace_length, + ); // Regenerate the uniform Spartan key let padded_trace_length = proof.trace_length.next_power_of_two(); - let memory_start = RAM_START_ADDRESS - proof.program_io.memory_layout.ram_witness_offset; + let memory_start = RAM_START_ADDRESS - preprocessing.memory_layout.ram_witness_offset; let r1cs_builder = Self::Constraints::construct_constraints(padded_trace_length, memory_start); let spartan_key = spartan::UniformSpartanProof::::setup( @@ -586,6 +599,7 @@ where Self::verify_memory( &mut preprocessing.read_write_memory, &preprocessing.generators, + &preprocessing.memory_layout, proof.read_write_memory, &commitments, proof.program_io, @@ -657,19 +671,27 @@ where ) } + #[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all)] fn verify_memory<'a>( preprocessing: &mut ReadWriteMemoryPreprocessing, generators: &PCS::Setup, + memory_layout: &MemoryLayout, proof: ReadWriteMemoryProof, commitment: &'a JoltCommitments, program_io: JoltDevice, opening_accumulator: &mut VerifierOpeningAccumulator, transcript: &mut ProofTranscript, ) -> Result<(), ProofVerifyError> { - assert!(program_io.inputs.len() <= program_io.memory_layout.max_input_size as usize); - assert!(program_io.outputs.len() <= program_io.memory_layout.max_output_size as usize); - preprocessing.program_io = Some(program_io); + assert!(program_io.inputs.len() <= memory_layout.max_input_size as usize); + assert!(program_io.outputs.len() <= memory_layout.max_output_size as usize); + // pair the memory layout with the program io from the proof + preprocessing.program_io = Some(JoltDevice { + inputs: program_io.inputs, + outputs: program_io.outputs, + panic: program_io.panic, + memory_layout: memory_layout.clone(), + }); ReadWriteMemoryProof::verify( proof, @@ -701,6 +723,7 @@ where fn fiat_shamir_preamble( transcript: &mut ProofTranscript, program_io: &JoltDevice, + memory_layout: &MemoryLayout, trace_length: usize, ) { transcript.append_u64(trace_length as u64); @@ -708,8 +731,8 @@ where transcript.append_u64(M as u64); transcript.append_u64(Self::InstructionSet::COUNT as u64); transcript.append_u64(Self::Subtables::COUNT as u64); - transcript.append_u64(program_io.memory_layout.max_input_size); - transcript.append_u64(program_io.memory_layout.max_output_size); + transcript.append_u64(memory_layout.max_input_size); + transcript.append_u64(memory_layout.max_output_size); transcript.append_bytes(&program_io.inputs); transcript.append_bytes(&program_io.outputs); transcript.append_u64(program_io.panic as u64); diff --git a/jolt-core/src/jolt/vm/read_write_memory.rs b/jolt-core/src/jolt/vm/read_write_memory.rs index ca41b5540..423103174 100644 --- a/jolt-core/src/jolt/vm/read_write_memory.rs +++ b/jolt-core/src/jolt/vm/read_write_memory.rs @@ -39,9 +39,9 @@ use super::{JoltPolynomials, JoltStuff, JoltTraceStep}; pub struct ReadWriteMemoryPreprocessing { min_bytecode_address: u64, pub bytecode_bytes: Vec, - // HACK: The verifier will populate this field by copying it - // over from the `ReadWriteMemoryProof`. Having `program_io` in - // this preprocessing struct allows the verifier to access it + // HACK: The verifier will populate this field by copying inputs/outputs from the + // `ReadWriteMemoryProof` and the memory layout from preprocessing. + // Having `program_io` in this preprocessing struct allows the verifier to access it // to compute the v_init and v_final openings, with no impact // on existing function signatures. pub program_io: Option, diff --git a/jolt-core/src/jolt/vm/rv32i_vm.rs b/jolt-core/src/jolt/vm/rv32i_vm.rs index ba09b6787..f27cb0f8c 100644 --- a/jolt-core/src/jolt/vm/rv32i_vm.rs +++ b/jolt-core/src/jolt/vm/rv32i_vm.rs @@ -312,8 +312,14 @@ mod tests { let (io_device, trace) = program.trace(); drop(artifact_guard); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (proof, commitments, debug_info) = >::prove( io_device, @@ -371,8 +377,14 @@ mod tests { let (bytecode, memory_init) = program.decode(); let (io_device, trace) = program.trace(); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (jolt_proof, jolt_commitments, debug_info) = , @@ -462,8 +486,14 @@ mod tests { let (io_device, trace) = program.trace(); drop(guard); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (jolt_proof, jolt_commitments, debug_info) = , @@ -495,8 +525,14 @@ mod tests { io_device.outputs[0] = 0; // change the output to 0 drop(artifact_guard); - let preprocessing = - RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20); + let preprocessing = RV32IJoltVM::preprocess( + bytecode.clone(), + io_device.memory_layout.clone(), + memory_init, + 1 << 20, + 1 << 20, + 1 << 20, + ); let (proof, commitments, debug_info) = , diff --git a/jolt-sdk/macros/src/lib.rs b/jolt-sdk/macros/src/lib.rs index 3dba83691..cad24fa9c 100644 --- a/jolt-sdk/macros/src/lib.rs +++ b/jolt-sdk/macros/src/lib.rs @@ -199,11 +199,13 @@ impl MacroBuilder { #set_std #set_mem_size let (bytecode, memory_init) = program.decode(); + let (io_device, _trace) = program.trace(); // TODO(moodlezoup): Feed in size parameters via macro let preprocessing: JoltPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> = RV32IJoltVM::preprocess( bytecode, + io_device.memory_layout.clone(), memory_init, 1 << 20, 1 << 20,