Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update main.rs #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 87 additions & 56 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ use risc0_build::{get_package, guest_methods, setup_guest_build_env, GuestOption

use risc0_zkvm::{
host::{Prover, Receipt},
serde::{
from_slice,
to_vec,
},
serde::{from_slice, to_vec},
};

const METHODS_DIR: &'static str = env!("METHODS_DIR");

fn main() {
// Define the command-line arguments using Clap
let matches = App::new("ZK eBPF tool")
.author("Eclipse Labs")
.arg(
Expand All @@ -58,92 +56,109 @@ fn main() {
.required_unless_present("assembler"),
)
.arg(
Arg::new("build directory")
Arg::new("build_directory")
.short('d')
.long("build-directory")
.value_name("DIR")
.takes_value(true)
.default_value("build"),
)
.arg(
Arg::new("no execute")
Arg::new("no_execute")
.short('n')
.long("no-execute")
.takes_value(false),
)
.arg(
Arg::new("input data")
Arg::new("input_data")
.short('i')
.long("input-data")
.value_name("FILE")
.takes_value(true),
)
.get_matches();

let target_dir_relative = Path::new(matches.value_of("build directory").unwrap());
let target_dir = fs::canonicalize(target_dir_relative).unwrap(); // should be replaced by std::path::absolute once it's out of nightly
fs::create_dir_all(target_dir.clone()).unwrap();
// Determine the build directory
let target_dir_relative = Path::new(matches.value_of("build_directory").unwrap());
let target_dir = fs::canonicalize(target_dir_relative).unwrap();
fs::create_dir_all(&target_dir).unwrap();

// Determine if the input is assembly code or an ELF file
let (input_filename, needs_assembly) = if let Some(filename) = matches.value_of("assembler") {
(filename, true)
} else {
(matches.value_of("elf").unwrap(), false)
};

let bpf_dir = compile_bpf(
fs::read(input_filename).unwrap(),
target_dir.clone(),
needs_assembly,
);
// Compile BPF input to RISC-V object archive
let bpf_dir = compile_bpf(fs::read(input_filename).unwrap(), &target_dir, needs_assembly);

let (method_path, method_id_vec) = compile_methods(target_dir, bpf_dir);
// Compile the guest methods using the produced BPF code
let (method_path, method_id_vec) = compile_methods(&target_dir, &bpf_dir);
let method_id = method_id_vec.as_slice();

if !matches.contains_id("no execute") {
// If not instructed to skip execution, run the prover
if !matches.is_present("no_execute") {
eprintln!("Executing program...");

let input_data = if let Some(filename) = matches.value_of("input data") {
let input_data = if let Some(filename) = matches.value_of("input_data") {
Some(fs::read(filename).unwrap())
} else {
None
};

let (output, receipt) = execute_prover(method_path, method_id, input_data);
let (output, receipt) = execute_prover(&method_path, method_id, input_data);

println!("The final BPF register values were:");
for i in 0..10 {
println!(" r{}: {:#018x}", i, output[i]);
}

// Verify the receipt
receipt.verify(method_id).unwrap();
}
}

fn execute_prover<P: AsRef<Path>>(method_path: P, method_id: &[u8], input_data: Option<Vec<u8>>) -> ([u64; 10], Receipt) {
let mut prover = Prover::new(&std::fs::read(method_path).unwrap(), method_id).unwrap();

prover.add_input(&to_vec(&input_data.unwrap_or(vec![])).unwrap()).unwrap();
// Executes the prover with the given method and optional input data.
// Returns the final register values and the generated receipt.
fn execute_prover<P: AsRef<Path>>(
method_path: P,
method_id: &[u8],
input_data: Option<Vec<u8>>,
) -> ([u64; 10], Receipt) {
let method_bytes = fs::read(method_path).unwrap();
let mut prover = Prover::new(&method_bytes, method_id).unwrap();

// Add input data to the prover
prover
.add_input(&to_vec(&input_data.unwrap_or(vec![])).unwrap())
.unwrap();

let receipt = prover.run().unwrap();
let output = from_slice(&receipt.get_journal_vec().unwrap()).unwrap();
// Run the prover and obtain the receipt
let receipt = prover.run().unwrap();
let output: [u64; 10] = from_slice(&receipt.get_journal_vec().unwrap()).unwrap();

(output, receipt)
(output, receipt)
}

// Compiles BPF code (either from assembly or ELF) into a RISC-V compatible static library (libbpf.a).
fn compile_bpf<P: AsRef<Path>>(
input: Vec<u8>,
target_dir: P,
needs_assembly: bool,
) -> PathBuf {
// Configure the BPF VM and syscall registry
let config = Config {
encrypt_environment_registers: false,
noop_instruction_rate: 0,
..Config::default()
};
let syscall_registry = SyscallRegistry::default();

// Create an executable from either assembly code or ELF input
let executable = if needs_assembly {
assemble::<UserError, TestInstructionMeter>(
std::str::from_utf8(input.as_slice()).unwrap(),
std::str::from_utf8(&input).unwrap(),
config,
syscall_registry,
)
Expand All @@ -155,18 +170,22 @@ fn compile_bpf<P: AsRef<Path>>(

let (_, text_bytes) = executable.get_text_bytes();
let mut compiler = Compiler::new::<UserError>(text_bytes, &config).unwrap();

compiler.compile(&executable).unwrap();

// Extract relevant data sections from the compiled program
let bpf_elf_bytes = executable.get_ro_section();
let pc_offsets_bytes = unsafe { std::slice::from_raw_parts(
compiler.pc_offsets.as_ptr() as *const u8,
compiler.pc_offsets.len() * std::mem::size_of::<i32>(),
) };
let pc_offsets_bytes = unsafe {
std::slice::from_raw_parts(
compiler.pc_offsets.as_ptr() as *const u8,
compiler.pc_offsets.len() * std::mem::size_of::<i32>(),
)
};
let riscv_bytes = compiler.result.text_section;

// Create a new ELF object file
let mut obj = Object::new(BinaryFormat::Elf, Architecture::Riscv32, Endianness::Little);

// Add the .rodata section and symbols
let rodata_section = obj.add_section(
obj.segment_name(StandardSegment::Data).to_vec(),
b".rodata".to_vec(),
Expand All @@ -188,6 +207,7 @@ fn compile_bpf<P: AsRef<Path>>(
&(bpf_elf_bytes.len() as u32).to_le_bytes(),
0x10,
);

let bpf_ro_section_symbol = obj.add_symbol(Symbol {
name: b"bpf_ro_section".to_vec(),
value: 0,
Expand All @@ -199,6 +219,7 @@ fn compile_bpf<P: AsRef<Path>>(
flags: SymbolFlags::None,
});
obj.add_symbol_data(bpf_ro_section_symbol, rodata_section, bpf_elf_bytes, 0x10);

let pc_offsets_symbol = obj.add_symbol(Symbol {
name: b"pc_offsets".to_vec(),
value: 0,
Expand All @@ -211,6 +232,7 @@ fn compile_bpf<P: AsRef<Path>>(
});
obj.add_symbol_data(pc_offsets_symbol, rodata_section, pc_offsets_bytes, 0x10);

// Add the .text section and main symbol
let text_section = obj.add_section(
obj.segment_name(StandardSegment::Text).to_vec(),
b".text".to_vec(),
Expand All @@ -226,13 +248,14 @@ fn compile_bpf<P: AsRef<Path>>(
section: SymbolSection::Section(text_section),
flags: SymbolFlags::None,
});
obj.add_symbol_data(program_main_symbol, text_section, riscv_bytes, 0x1000); // TODO determine what alignment is necessary
obj.add_symbol_data(program_main_symbol, text_section, riscv_bytes, 0x1000);

// Add RISC-V relocations for call and data references
for reloc in compiler.relocations.iter() {
let (symbol_name, offset, relocation_code) = match reloc {
RiscVRelocation::Call{offset, symbol} => (symbol, offset, 18), // R_RISCV_CALL
RiscVRelocation::Hi20{offset, symbol} => (symbol, offset, 26), // R_RISCV_HI20
RiscVRelocation::Lo12I{offset, symbol} => (symbol, offset, 27), // R_RISCV_LO12_I
RiscVRelocation::Call { offset, symbol } => (symbol, offset, 18), // R_RISCV_CALL
RiscVRelocation::Hi20 { offset, symbol } => (symbol, offset, 26), // R_RISCV_HI20
RiscVRelocation::Lo12I { offset, symbol } => (symbol, offset, 27), // R_RISCV_LO12_I
};
let symbol = obj.add_symbol(Symbol {
name: symbol_name.as_bytes().to_vec(),
Expand All @@ -255,32 +278,34 @@ fn compile_bpf<P: AsRef<Path>>(
obj.add_relocation(text_section, obj_reloc).unwrap();
}

let bpf_target_dir = PathBuf::new().join(target_dir).join("bpf-riscv");
fs::create_dir_all(bpf_target_dir.clone()).unwrap();
// Create a directory for BPF-generated objects and archive them
let bpf_target_dir = target_dir.as_ref().join("bpf-riscv");
fs::create_dir_all(&bpf_target_dir).unwrap();

let obj_path = bpf_target_dir.join("bpf.o");
let obj_file = File::create(&obj_path).unwrap();
obj.write_stream(BufWriter::new(obj_file)).unwrap();

let ar_path = bpf_target_dir.join("libbpf.a");
let ar_file = File::create(ar_path).unwrap();
let ar_file = File::create(&ar_path).unwrap();
let mut ar_builder = Builder::new(ar_file);
ar_builder.append_path(obj_path).unwrap();
ar_builder.append_path(&obj_path).unwrap();

return bpf_target_dir;
bpf_target_dir
}

// Compiles the guest methods using the previously compiled BPF code.
// Returns the path to the compiled method and the corresponding method ID.
fn compile_methods<P: AsRef<Path>, Q: AsRef<Path>>(
target_dir: P,
bpf_target_dir: Q,
) -> (PathBuf, Vec<u8>) {
let pkg = get_package(METHODS_DIR);
let guest_build_env = setup_guest_build_env(target_dir.as_ref());

let target_dir_guest = target_dir.as_ref().join("riscv-guest");

// mostly taken from risc0-build
let args = vec![
// Prepare cargo build arguments for the guest code
let args = &[
"build",
"--release",
"--target",
Expand All @@ -294,10 +319,9 @@ fn compile_methods<P: AsRef<Path>, Q: AsRef<Path>>(
"--target-dir",
target_dir_guest.to_str().unwrap(),
];

eprintln!("Building guest package: cargo {}", args.join(" "));
// The RISC0_STANDARD_LIB variable can be set for testing purposes
// to override the downloaded standard library. It should point
// to the root of the rust repository.

let risc0_standard_lib: String = if let Ok(path) = env::var("RISC0_STANDARD_LIB") {
path
} else {
Expand All @@ -309,20 +333,27 @@ fn compile_methods<P: AsRef<Path>, Q: AsRef<Path>>(
let mut cmd = Command::new("cargo");
cmd.env("BPF_LIB_DIR", bpf_target_dir.as_ref().as_os_str())
.env("CARGO_ENCODED_RUSTFLAGS", "-C\x1fpasses=loweratomic")
.env("__CARGO_TESTS_ONLY_SRC_ROOT", risc0_standard_lib)
.args(args)
.spawn()
.unwrap();
.env("__CARGO_TESTS_ONLY_SRC_ROOT", &risc0_standard_lib)
.args(args);

let status = cmd.status().unwrap();
// Spawn the cargo process and wait for completion
let child = cmd.spawn().expect("Failed to spawn cargo build");
let status = child.wait().expect("Failed to wait for cargo build");

if !status.success() {
std::process::exit(status.code().unwrap());
std::process::exit(status.code().unwrap_or(1));
}

// Extract the compiled guest method
let mut methods = guest_methods(&pkg, target_dir);
if methods.is_empty() {
eprintln!("No methods found.");
std::process::exit(1);
}

let method = guest_methods(&pkg, target_dir).remove(0);
return (
let method = methods.remove(0);
(
method.elf_path.clone(),
method.make_method_id(GuestOptions::default().code_limit),
);
)
}