Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add memory_layout to preprocessing so verifier doesn't rely on the prover #494

Merged
merged 6 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,14 @@ where
let (io_device, trace) = program.trace();

let preprocessing: crate::jolt::vm::JoltPreprocessing<C, F, PCS, ProofTranscript> =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22);
RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 22,
);

let (jolt_proof, jolt_commitments, _) =
<RV32IJoltVM as Jolt<_, PCS, C, M, ProofTranscript>>::prove(
Expand Down Expand Up @@ -187,7 +194,14 @@ where
let (io_device, trace) = program.trace();

let preprocessing: crate::jolt::vm::JoltPreprocessing<C, F, PCS, ProofTranscript> =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22);
RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 22,
);

let (jolt_proof, jolt_commitments, _) =
<RV32IJoltVM as Jolt<_, PCS, C, M, ProofTranscript>>::prove(
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ impl Program {

// TODO(moodlezoup): Make this generic over InstructionSet
#[tracing::instrument(skip_all, name = "Program::trace")]
pub fn trace(mut self) -> (JoltDevice, Vec<JoltTraceStep<RV32I>>) {
pub fn trace(&mut self) -> (JoltDevice, Vec<JoltTraceStep<RV32I>>) {
self.build();
let elf = self.elf.unwrap();
let elf = self.elf.clone().unwrap();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't need this clone, we pass in elf by reference below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self.elf is an Option<PathBuf> the clone is required to unwrap it. Previously this API would consume itself trace(mut self), which seemed weird to me so I changed it to trace(&mut self) and that means we can't just pull elf out of Self anymore since we just have a ref to Self.

Does that make sense?

let (raw_trace, io_device) =
tracer::trace(&elf, &self.input, self.max_input_size, self.max_output_size);

Expand Down
41 changes: 32 additions & 9 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::r1cs::constraints::R1CSConstraints;
use crate::r1cs::spartan::{self, UniformSpartanProof};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use common::constants::RAM_START_ADDRESS;
use common::rv_trace::NUM_CIRCUIT_FLAGS;
use common::rv_trace::{MemoryLayout, NUM_CIRCUIT_FLAGS};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use strum::EnumCount;
Expand Down Expand Up @@ -60,6 +60,7 @@ where
pub instruction_lookups: InstructionLookupsPreprocessing<C, F>,
pub bytecode: BytecodePreprocessing<F>,
pub read_write_memory: ReadWriteMemoryPreprocessing,
pub memory_layout: MemoryLayout,
}

#[derive(Clone, Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -276,6 +277,7 @@ where
#[tracing::instrument(skip_all, name = "Jolt::preprocess")]
fn preprocess(
bytecode: Vec<ELFInstruction>,
memory_layout: MemoryLayout,
memory_init: Vec<(u64, u8)>,
max_bytecode_size: usize,
max_memory_address: usize,
Expand Down Expand Up @@ -336,6 +338,7 @@ where

JoltPreprocessing {
generators,
memory_layout,
instruction_lookups: instruction_lookups_preprocessing,
bytecode: bytecode_preprocessing,
read_write_memory: read_write_memory_preprocessing,
Expand Down Expand Up @@ -368,7 +371,12 @@ where
JoltTraceStep::pad(&mut trace);

let mut transcript = ProofTranscript::new(b"Jolt transcript");
Self::fiat_shamir_preamble(&mut transcript, &program_io, trace_length);
Self::fiat_shamir_preamble(
&mut transcript,
&program_io,
&program_io.memory_layout,
trace_length,
);

let instruction_polynomials =
InstructionLookupsProof::<
Expand Down Expand Up @@ -539,11 +547,16 @@ where
opening_accumulator
.compare_to(debug_info.opening_accumulator, &preprocessing.generators);
}
Self::fiat_shamir_preamble(&mut transcript, &proof.program_io, proof.trace_length);
Self::fiat_shamir_preamble(
&mut transcript,
&proof.program_io,
&preprocessing.memory_layout,
proof.trace_length,
);

// Regenerate the uniform Spartan key
let padded_trace_length = proof.trace_length.next_power_of_two();
let memory_start = RAM_START_ADDRESS - proof.program_io.memory_layout.ram_witness_offset;
let memory_start = RAM_START_ADDRESS - preprocessing.memory_layout.ram_witness_offset;
let r1cs_builder =
Self::Constraints::construct_constraints(padded_trace_length, memory_start);
let spartan_key = spartan::UniformSpartanProof::<C, _, F, ProofTranscript>::setup(
Expand Down Expand Up @@ -586,6 +599,7 @@ where
Self::verify_memory(
&mut preprocessing.read_write_memory,
&preprocessing.generators,
&preprocessing.memory_layout,
proof.read_write_memory,
&commitments,
proof.program_io,
Expand Down Expand Up @@ -657,19 +671,27 @@ where
)
}

#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all)]
fn verify_memory<'a>(
preprocessing: &mut ReadWriteMemoryPreprocessing,
generators: &PCS::Setup,
memory_layout: &MemoryLayout,
proof: ReadWriteMemoryProof<F, PCS, ProofTranscript>,
commitment: &'a JoltCommitments<PCS, ProofTranscript>,
program_io: JoltDevice,
opening_accumulator: &mut VerifierOpeningAccumulator<F, PCS, ProofTranscript>,
transcript: &mut ProofTranscript,
) -> Result<(), ProofVerifyError> {
assert!(program_io.inputs.len() <= program_io.memory_layout.max_input_size as usize);
assert!(program_io.outputs.len() <= program_io.memory_layout.max_output_size as usize);
preprocessing.program_io = Some(program_io);
assert!(program_io.inputs.len() <= memory_layout.max_input_size as usize);
assert!(program_io.outputs.len() <= memory_layout.max_output_size as usize);
// pair the memory layout with the program io from the proof
preprocessing.program_io = Some(JoltDevice {
inputs: program_io.inputs,
outputs: program_io.outputs,
panic: program_io.panic,
memory_layout: memory_layout.clone(),
});

ReadWriteMemoryProof::verify(
proof,
Expand Down Expand Up @@ -701,15 +723,16 @@ where
fn fiat_shamir_preamble(
transcript: &mut ProofTranscript,
program_io: &JoltDevice,
memory_layout: &MemoryLayout,
trace_length: usize,
) {
transcript.append_u64(trace_length as u64);
transcript.append_u64(C as u64);
transcript.append_u64(M as u64);
transcript.append_u64(Self::InstructionSet::COUNT as u64);
transcript.append_u64(Self::Subtables::COUNT as u64);
transcript.append_u64(program_io.memory_layout.max_input_size);
transcript.append_u64(program_io.memory_layout.max_output_size);
transcript.append_u64(memory_layout.max_input_size);
transcript.append_u64(memory_layout.max_output_size);
transcript.append_bytes(&program_io.inputs);
transcript.append_bytes(&program_io.outputs);
transcript.append_u64(program_io.panic as u64);
Expand Down
6 changes: 3 additions & 3 deletions jolt-core/src/jolt/vm/read_write_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ use super::{JoltPolynomials, JoltStuff, JoltTraceStep};
pub struct ReadWriteMemoryPreprocessing {
min_bytecode_address: u64,
pub bytecode_bytes: Vec<u8>,
// HACK: The verifier will populate this field by copying it
// over from the `ReadWriteMemoryProof`. Having `program_io` in
// this preprocessing struct allows the verifier to access it
// HACK: The verifier will populate this field by copying inputs/outputs from the
// `ReadWriteMemoryProof` and the memory layout from preprocessing.
// Having `program_io` in this preprocessing struct allows the verifier to access it
// to compute the v_init and v_final openings, with no impact
// on existing function signatures.
pub program_io: Option<JoltDevice>,
Expand Down
104 changes: 87 additions & 17 deletions jolt-core/src/jolt/vm/rv32i_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(artifact_guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (proof, commitments, debug_info) =
<RV32IJoltVM as Jolt<F, PCS, C, M, ProofTranscript>>::prove(
io_device,
Expand Down Expand Up @@ -371,8 +377,14 @@ mod tests {
let (bytecode, memory_init) = program.decode();
let (io_device, trace) = program.trace();

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) =
<RV32IJoltVM as Jolt<
_,
Expand Down Expand Up @@ -401,8 +413,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) =
<RV32IJoltVM as Jolt<
_,
Expand Down Expand Up @@ -431,8 +449,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) = <RV32IJoltVM as Jolt<
_,
Zeromorph<Bn254, KeccakTranscript>,
Expand Down Expand Up @@ -462,8 +486,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) = <RV32IJoltVM as Jolt<
_,
HyperKZG<Bn254, KeccakTranscript>,
Expand Down Expand Up @@ -495,8 +525,14 @@ mod tests {
io_device.outputs[0] = 0; // change the output to 0
drop(artifact_guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (proof, commitments, debug_info) = <RV32IJoltVM as Jolt<
Fr,
HyperKZG<Bn254, KeccakTranscript>,
Expand All @@ -506,12 +542,46 @@ mod tests {
>>::prove(
io_device, trace, preprocessing.clone()
);
let verification_result =
let _verification_result =
RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info);
assert!(
verification_result.is_ok(),
"Verification failed with error: {:?}",
verification_result.err()
}

#[test]
#[should_panic]
sagar-a16z marked this conversation as resolved.
Show resolved Hide resolved
fn malicious_trace() {
let artifact_guard = FIB_FILE_LOCK.lock().unwrap();
let mut program = host::Program::new("fibonacci-guest");
program.set_input(&1u8); // change input to 1 so that termination bit equal true
let (bytecode, memory_init) = program.decode();
let (mut io_device, trace) = program.trace();
let memory_layout = io_device.memory_layout.clone();
drop(artifact_guard);

// change memory address of output & termination bit to the same address as input
// changes here should not be able to spoof the verifier result
io_device.memory_layout.output_start = io_device.memory_layout.input_start;
io_device.memory_layout.output_end = io_device.memory_layout.input_end;
io_device.memory_layout.termination = io_device.memory_layout.input_start;

// Since the preprocessing is done with the original memory layout, the verifier should fail
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
memory_layout,
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (proof, commitments, debug_info) = <RV32IJoltVM as Jolt<
Fr,
HyperKZG<Bn254, KeccakTranscript>,
C,
M,
KeccakTranscript,
>>::prove(
io_device, trace, preprocessing.clone()
);
let _verification_result =
RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info);
}
}
6 changes: 6 additions & 0 deletions jolt-sdk/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ impl MacroBuilder {
}

fn make_preprocess_func(&self) -> TokenStream2 {
let attributes = parse_attributes(&self.attr);
let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size);
let max_output_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_output_size);
let set_mem_size = self.make_set_linker_parameters();
let guest_name = self.get_guest_name();
let imports = self.make_imports();
Expand All @@ -199,11 +202,13 @@ impl MacroBuilder {
#set_std
#set_mem_size
let (bytecode, memory_init) = program.decode();
let memory_layout = MemoryLayout::new(#max_input_size, #max_output_size);

// TODO(moodlezoup): Feed in size parameters via macro
let preprocessing: JoltPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> =
RV32IJoltVM::preprocess(
bytecode,
memory_layout,
memory_init,
1 << 20,
1 << 20,
Expand Down Expand Up @@ -409,6 +414,7 @@ impl MacroBuilder {
RV32IJoltProof,
BytecodeRow,
MemoryOp,
MemoryLayout,
MEMORY_OPS_PER_INSTRUCTION,
instruction::add::ADDInstruction,
tracer,
Expand Down
2 changes: 1 addition & 1 deletion jolt-sdk/src/host_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub use jolt_core::{field::JoltField, poly::commitment::hyperkzg::HyperKZG};

pub use common::{
constants::MEMORY_OPS_PER_INSTRUCTION,
rv_trace::{MemoryOp, RV32IM},
rv_trace::{MemoryLayout, MemoryOp, RV32IM},
};
pub use jolt_core::host;
pub use jolt_core::jolt::instruction;
Expand Down
Loading