diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index fec7e0181..5d7d669e1 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -1,4 +1,4 @@ -use std::ops::Range; +use std::{collections::HashSet, ops::Range}; use crate::addr::{Addr, RegIdx}; @@ -10,20 +10,23 @@ use crate::addr::{Addr, RegIdx}; #[derive(Clone, Debug)] pub struct Platform { pub rom: Range, - pub ram: Range, + // This is an `Option` to allow `const` here. + pub prog_data: Option>, + pub stack: Range, + pub heap: Range, pub public_io: Range, pub hints: Range, - pub stack_top: Addr, /// If true, ecall instructions are no-op instead of trap. Testing only. pub unsafe_ecall_nop: bool, } pub const CENO_PLATFORM: Platform = Platform { rom: 0x2000_0000..0x3000_0000, - ram: 0x8000_0000..0xFFFF_0000, + prog_data: None, + stack: 0xB0000000..0xC0000000, + heap: 0x8000_0000..0xFFFF_0000, public_io: 0x3000_1000..0x3000_2000, hints: 0x4000_0000..0x5000_0000, - stack_top: 0xC0000000, unsafe_ecall_nop: false, }; @@ -34,8 +37,15 @@ impl Platform { self.rom.contains(&addr) } + pub fn is_prog_data(&self, addr: Addr) -> bool { + self.prog_data + .as_ref() + .map(|set| set.contains(&(addr & !0x3))) + .unwrap_or(false) + } + pub fn is_ram(&self, addr: Addr) -> bool { - self.ram.contains(&addr) + self.stack.contains(&addr) || self.heap.contains(&addr) || self.is_prog_data(addr) } pub fn is_pub_io(&self, addr: Addr) -> bool { @@ -66,11 +76,11 @@ impl Platform { // Permissions. pub fn can_read(&self, addr: Addr) -> bool { - self.is_rom(addr) || self.is_ram(addr) || self.is_pub_io(addr) || self.is_hints(addr) + self.can_write(addr) } pub fn can_write(&self, addr: Addr) -> bool { - self.is_ram(addr) + self.is_ram(addr) || self.is_pub_io(addr) || self.is_hints(addr) } // Environment calls. @@ -110,8 +120,8 @@ mod tests { fn test_no_overlap() { let p = CENO_PLATFORM; // ROM and RAM do not overlap. - assert!(!p.is_rom(p.ram.start)); - assert!(!p.is_rom(p.ram.end - WORD_SIZE as Addr)); + assert!(!p.is_rom(p.heap.start)); + assert!(!p.is_rom(p.heap.end - WORD_SIZE as Addr)); assert!(!p.is_ram(p.rom.start)); assert!(!p.is_ram(p.rom.end - WORD_SIZE as Addr)); // Registers do not overlap with ROM or RAM. diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 01ad33e01..0cd87783c 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -198,7 +198,7 @@ impl StepRecord { Some(value), Some(Change::new(value, value)), Some(WriteOp { - addr: CENO_PLATFORM.ram.start.into(), + addr: CENO_PLATFORM.heap.start.into(), value: Change { before: value, after: value, diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 838779979..c5e685239 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -122,7 +122,7 @@ impl EmuContext for VMState { tracing::warn!("ecall ignored: syscall_id={}", function); self.store_register(Instruction::RD_NULL as RegIdx, 0)?; // Example ecall effect - any writable address will do. - let addr = (self.platform.stack_top - WORD_SIZE as u32).into(); + let addr = (self.platform.stack.end - WORD_SIZE as u32).into(); self.store_memory(addr, self.peek_memory(addr))?; self.set_pc(ByteAddr(self.pc) + PC_STEP_SIZE); Ok(true) diff --git a/ceno_emul/tests/test_elf.rs b/ceno_emul/tests/test_elf.rs index 7448d4508..c7a52685f 100644 --- a/ceno_emul/tests/test_elf.rs +++ b/ceno_emul/tests/test_elf.rs @@ -1,5 +1,9 @@ +use std::{collections::HashSet, sync::Arc}; + use anyhow::Result; -use ceno_emul::{ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, Platform, StepRecord, VMState}; +use ceno_emul::{ + ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, Platform, Program, StepRecord, VMState, +}; #[test] fn test_ceno_rt_mini() -> Result<()> { @@ -27,7 +31,7 @@ fn test_ceno_rt_mem() -> Result<()> { let mut state = VMState::new_from_elf(CENO_PLATFORM, program_elf)?; let _steps = run(&mut state)?; - let value = state.peek_memory(CENO_PLATFORM.ram.start.into()); + let value = state.peek_memory(CENO_PLATFORM.heap.start.into()); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); Ok(()) } @@ -60,7 +64,12 @@ fn test_ceno_rt_alloc() -> Result<()> { #[test] fn test_ceno_rt_io() -> Result<()> { let program_elf = ceno_examples::ceno_rt_io; - let mut state = VMState::new_from_elf(CENO_PLATFORM, program_elf)?; + let program = Program::load_elf(program_elf, u32::MAX)?; + let platform = Platform { + prog_data: Some(program.image.keys().copied().collect::>()), + ..CENO_PLATFORM + }; + let mut state = VMState::new(platform, Arc::new(program)); let _steps = run(&mut state)?; let all_messages = read_all_messages(&state); diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index f75c4ddd1..aa402b9bd 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -4,10 +4,10 @@ use std::{ time::{Duration, Instant}, }; -use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE}; +use ceno_emul::{Platform, Program}; use ceno_zkvm::{ self, - e2e::{Checkpoint, run_e2e_with_checkpoint}, + e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, }; use criterion::*; @@ -28,29 +28,20 @@ type Pcs = BasefoldDefault; type E = GoldilocksExt2; // Relevant init data for fibonacci run -fn setup() -> (Program, Platform, u32, u32) { +fn setup() -> (Program, Platform) { let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); file_path.push("examples/fibonacci.elf"); let stack_size = 32768; let heap_size = 2097152; + let pub_io_size = 16; let elf_bytes = fs::read(&file_path).expect("read elf file"); let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); - - let platform = Platform { - // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. - stack_top: 0x0020_0400, - rom: program.base_address - ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, - ram: 0x0010_0000..0xFFFF_0000, - unsafe_ecall_nop: true, - ..CENO_PLATFORM - }; - - (program, platform, stack_size, heap_size) + let platform = setup_platform(Preset::Sp1, &program, stack_size, heap_size, pub_io_size); + (program, platform) } fn fibonacci_prove(c: &mut Criterion) { - let (program, platform, stack_size, heap_size) = setup(); + let (program, platform) = setup(); for max_steps in [1usize << 20, 1usize << 21, 1usize << 22] { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("fibonacci_max_steps_{}", max_steps)); @@ -68,8 +59,6 @@ fn fibonacci_prove(c: &mut Criterion) { run_e2e_with_checkpoint::( program.clone(), platform.clone(), - stack_size, - heap_size, vec![], max_steps, Checkpoint::PrepE2EProving, diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index 2f09adaee..e2a46be70 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -1,9 +1,9 @@ use std::{fs, path::PathBuf, time::Duration}; -use ceno_emul::{CENO_PLATFORM, Platform, Program, WORD_SIZE}; +use ceno_emul::{Platform, Program}; use ceno_zkvm::{ self, - e2e::{Checkpoint, run_e2e_with_checkpoint}, + e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, }; use criterion::*; @@ -23,29 +23,20 @@ type Pcs = BasefoldDefault; type E = GoldilocksExt2; // Relevant init data for fibonacci run -fn setup() -> (Program, Platform, u32, u32) { +fn setup() -> (Program, Platform) { let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); file_path.push("examples/fibonacci.elf"); let stack_size = 32768; let heap_size = 2097152; + let pub_io_size = 16; let elf_bytes = fs::read(&file_path).expect("read elf file"); let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); - - let platform = Platform { - // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. - stack_top: 0x0020_0400, - rom: program.base_address - ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, - ram: 0x0010_0000..0xFFFF_0000, - unsafe_ecall_nop: true, - ..CENO_PLATFORM - }; - - (program, platform, stack_size, heap_size) + let platform = setup_platform(Preset::Sp1, &program, stack_size, heap_size, pub_io_size); + (program, platform) } fn fibonacci_witness(c: &mut Criterion) { - let (program, platform, stack_size, heap_size) = setup(); + let (program, platform) = setup(); let max_steps = usize::MAX; let mut group = c.benchmark_group(format!("fib_wit_max_steps_{}", max_steps)); @@ -63,8 +54,6 @@ fn fibonacci_witness(c: &mut Criterion) { run_e2e_with_checkpoint::( program.clone(), platform.clone(), - stack_size, - heap_size, vec![], max_steps, Checkpoint::PrepWitnessGen, diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index 81d1f6eb3..54813659e 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -76,7 +76,7 @@ fn main() { program_code, Default::default(), ); - let mem_addresses = CENO_PLATFORM.ram.clone(); + let mem_addresses = CENO_PLATFORM.heap.clone(); let io_addresses = CENO_PLATFORM.public_io.clone(); let mut fmt_layer = fmt::layer() diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 07baf8998..d37784e6e 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,9 +1,9 @@ -use ceno_emul::{CENO_PLATFORM, IterAddresses, Platform, Program, WORD_SIZE, Word}; +use ceno_emul::{IterAddresses, Program, WORD_SIZE, Word}; use ceno_zkvm::{ - e2e::{Checkpoint, run_e2e_with_checkpoint}, + e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, with_panic_hook, }; -use clap::{Parser, ValueEnum}; +use clap::Parser; use ff_ext::ff::Field; use goldilocks::{Goldilocks, GoldilocksExt2}; use itertools::Itertools; @@ -51,12 +51,6 @@ struct Args { heap_size: u32, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] -enum Preset { - Ceno, - Sp1, -} - fn main() { let args = { let mut args = Args::parse(); @@ -64,6 +58,7 @@ fn main() { args.heap_size = args.heap_size.next_multiple_of(WORD_SIZE as u32); args }; + let pub_io_size = 16; // TODO: configure. // default filter let default_filter = EnvFilter::builder() @@ -99,29 +94,16 @@ fn main() { .with(args.profiling.is_none().then_some(default_filter)) .init(); - let args = { - let mut args = Args::parse(); - args.stack_size = args.stack_size.next_multiple_of(WORD_SIZE as u32); - args.heap_size = args.heap_size.next_multiple_of(WORD_SIZE as u32); - args - }; - tracing::info!("Loading ELF file: {}", &args.elf); let elf_bytes = fs::read(&args.elf).expect("read elf file"); let program = Program::load_elf(&elf_bytes, u32::MAX).unwrap(); - - let platform = match args.platform { - Preset::Ceno => CENO_PLATFORM, - Preset::Sp1 => Platform { - // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. - stack_top: 0x0020_0400, - rom: program.base_address - ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, - ram: 0x0010_0000..0xFFFF_0000, - unsafe_ecall_nop: true, - ..CENO_PLATFORM - }, - }; + let platform = setup_platform( + args.platform, + &program, + args.stack_size, + args.heap_size, + pub_io_size, + ); tracing::info!("Running on platform {:?} {:?}", args.platform, platform); tracing::info!( "Stack: {} bytes. Heap: {} bytes.", @@ -146,8 +128,6 @@ fn main() { let (state, _) = run_e2e_with_checkpoint::( program, platform, - args.stack_size, - args.heap_size, hints, max_steps, Checkpoint::PrepSanityCheck, diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 95bcb73be..ec09c0717 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -11,9 +11,10 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - ByteAddr, EmuContext, InsnKind, IterAddresses, Platform, Program, StepRecord, Tracer, VMState, - WORD_SIZE, WordAddr, + Addr, ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, IterAddresses, Platform, Program, + StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, }; +use clap::ValueEnum; use ff_ext::ExtensionField; use itertools::{Itertools, MinMaxResult, chain}; use mpcs::PolynomialCommitmentScheme; @@ -159,33 +160,79 @@ fn emulate_program( } } -fn init_mem( +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +pub enum Preset { + Ceno, + Sp1, +} + +pub fn setup_platform( + preset: Preset, program: &Program, - platform: &Platform, - mem_padder: &mut MemPadder, stack_size: u32, heap_size: u32, -) -> Vec { - let stack_addrs = platform.stack_top - stack_size..platform.stack_top; - // Detect heap as starting after program data. - let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; - let heap_addrs = heap_start..heap_start + heap_size; + pub_io_size: u32, +) -> Platform { + let preset = match preset { + Preset::Ceno => CENO_PLATFORM, + Preset::Sp1 => Platform { + // The stack section is not mentioned in ELF headers, so we repeat the constant STACK_TOP here. + stack: 0x0020_0400..0x0020_0400, + unsafe_ecall_nop: true, + ..CENO_PLATFORM + }, + }; + + let prog_data = program.image.keys().copied().collect::>(); + let stack = preset.stack.end - stack_size..preset.stack.end; + let heap = { + // Detect heap as starting after program data. + let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; + let heap = heap_start..heap_start + heap_size; + // Pad the total size to the next power of two. + let mem_size = prog_data.len() + stack.iter_addresses().len() + heap.iter_addresses().len(); + let pad_size = mem_size.next_power_of_two() - mem_size; + let heap_end = heap.end as usize + pad_size * WORD_SIZE; + assert!( + heap_end <= u32::MAX as usize, + "not enough space for padding; reduce heap size" + ); + heap.start..heap_end as u32 + }; + + Platform { + rom: program.base_address + ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, + prog_data: Some(prog_data), + stack, + heap, + public_io: preset.public_io.start..preset.public_io.start + pub_io_size.next_power_of_two(), + ..preset + } +} + +fn init_mem(program: &Program, platform: &Platform) -> Vec { let program_addrs = program.image.iter().map(|(addr, value)| MemInitRecord { addr: *addr, value: *value, }); - let stack = stack_addrs + let stack = platform + .stack .iter_addresses() .map(|addr| MemInitRecord { addr, value: 0 }); - let heap = heap_addrs + let heap = platform + .heap .iter_addresses() .map(|addr| MemInitRecord { addr, value: 0 }); - let mem_init = chain!(program_addrs, stack, heap).collect_vec(); + let mem_init = chain!(program_addrs, stack, heap) + .sorted_by_key(|record| record.addr) + .collect_vec(); - mem_padder.padded_sorted(mem_init.len().next_power_of_two(), mem_init) + assert!(mem_init.len().is_power_of_two()); + mem_init } pub struct ConstraintSystemConfig { @@ -326,31 +373,26 @@ pub type IntermediateState = (ZKVMProof, ZKVMVerifier); pub fn run_e2e_with_checkpoint + 'static>( program: Program, platform: Platform, - stack_size: u32, - heap_size: u32, hints: Vec, max_steps: usize, checkpoint: Checkpoint, ) -> (Option>, Box) { - // Detect heap as starting after program data. - let heap_start = program.image.keys().max().unwrap() + WORD_SIZE as u32; - let heap_addrs = heap_start..heap_start + heap_size; - let mut mem_padder = MemPadder::new(heap_addrs.end..platform.ram.end); - let mem_init = init_mem(&program, &platform, &mut mem_padder, stack_size, heap_size); + let mem_init = init_mem(&program, &platform); + let pub_io_len = platform.public_io.iter_addresses().len(); let program_params = ProgramParams { platform: platform.clone(), program_size: program.instructions.len(), static_memory_len: mem_init.len(), - ..ProgramParams::default() + pub_io_len, }; let program = Arc::new(program); let system_config = construct_configs::(program_params); + let reg_init = system_config.mmu_config.initial_registers(); // IO is not used in this program, but it must have a particular size at the moment. - let io_init = mem_padder.padded_sorted(system_config.mmu_config.public_io_len(), vec![]); - let reg_init = system_config.mmu_config.initial_registers(); + let io_init = MemPadder::init_mem(platform.public_io.clone(), pub_io_len, &[]); let init_full_mem = InitMemState { mem: mem_init, diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 2c45294e4..7bdf6c32e 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -19,11 +19,11 @@ impl DynVolatileRamTable for DynMemTable { const ZERO_INIT: bool = true; fn offset_addr(params: &ProgramParams) -> Addr { - params.platform.ram.start + params.platform.heap.start } fn end_addr(params: &ProgramParams) -> Addr { - params.platform.ram.end + params.platform.heap.end } fn name() -> &'static str {