Skip to content

Commit

Permalink
Use arith-memory in RISCV (#2199)
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardoalt authored Dec 11, 2024
1 parent 0180542 commit 4f1aa4a
Show file tree
Hide file tree
Showing 10 changed files with 1,034 additions and 1,689 deletions.
17 changes: 15 additions & 2 deletions powdr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::time::Instant;
pub struct SessionBuilder {
guest_path: String,
out_path: String,
asm_file: Option<String>,
chunk_size_log2: Option<u8>,
precompiles: RuntimeLibs,
}
Expand All @@ -47,14 +48,20 @@ const DEFAULT_MIN_MAX_DEGREE_LOG: u8 = 18;
impl SessionBuilder {
/// Builds a session with the given parameters.
pub fn build(self) -> Session {
Session {
pipeline: pipeline_from_guest(
let pipeline = match self.asm_file {
Some(asm_file) => Pipeline::<GoldilocksField>::default()
.from_asm_file(asm_file.into())
.with_output(Path::new(&self.out_path).to_path_buf(), true),
None => pipeline_from_guest(
&self.guest_path,
Path::new(&self.out_path),
DEFAULT_MIN_DEGREE_LOG,
self.chunk_size_log2.unwrap_or(DEFAULT_MAX_DEGREE_LOG),
self.precompiles,
),
};
Session {
pipeline,
out_path: self.out_path,
}
.with_backend(powdr_backend::BackendType::Plonky3)
Expand All @@ -72,6 +79,12 @@ impl SessionBuilder {
self
}

/// Re-use a previously compiled guest program.
pub fn asm_file(mut self, asm_file: &str) -> Self {
self.asm_file = Some(asm_file.into());
self
}

/// Set the chunk size, represented by its log2.
/// Example: for a chunk size of 2^20, set chunk_size_log2 to 20.
/// If the execution trace is longer than the 2^chunk_size_log2,
Expand Down
193 changes: 122 additions & 71 deletions riscv-executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1865,9 +1865,9 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> {
let reg1 = args[0].u();
let reg2 = args[1].u();
let input_ptr = self.reg_read(0, reg1, 0);
assert_eq!(input_ptr.u() % 4, 0);
assert!(is_multiple_of_4(input_ptr.u()));
let output_ptr = self.reg_read(1, reg2, 1);
assert_eq!(output_ptr.u() % 4, 0);
assert!(is_multiple_of_4(output_ptr.u()));

set_col!(tmp1_col, input_ptr);
set_col!(tmp2_col, output_ptr);
Expand Down Expand Up @@ -1951,7 +1951,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> {
}
"poseidon2_gl" => {
let input_ptr = self.proc.get_reg_mem(args[0].u()).u();
assert_eq!(input_ptr % 4, 0);
assert!(is_multiple_of_4(input_ptr));

let inputs: [u64; 8] = (0..16)
.map(|i| self.proc.get_mem(input_ptr + i * 4, 0, 0)) // TODO: step/selector for poseidon2
Expand All @@ -1971,104 +1971,157 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> {
.flat_map(|v| vec![(v & 0xffffffff) as u32, (v >> 32) as u32]);

let output_ptr = self.proc.get_reg_mem(args[1].u()).u();
assert_eq!(output_ptr % 4, 0);
assert!(is_multiple_of_4(output_ptr));
result.enumerate().for_each(|(i, v)| {
self.proc.set_mem(output_ptr + i as u32 * 4, v, 0, 0); // TODO: step/selector for poseidon2
});

vec![]
}
"affine_256" => {
assert!(args.is_empty());
// take input from registers
let x1 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i)).into_fe())
// a * b + c = d
let input_ptr_a = self.proc.get_reg_mem(args[0].u()).u();
assert!(is_multiple_of_4(input_ptr_a));
let input_ptr_b = self.proc.get_reg_mem(args[1].u()).u();
assert!(is_multiple_of_4(input_ptr_b));
let input_ptr_c = self.proc.get_reg_mem(args[2].u()).u();
assert!(is_multiple_of_4(input_ptr_c));
let output_ptr_d = self.proc.get_reg_mem(args[3].u()).u();
assert!(is_multiple_of_4(output_ptr_d));

let a = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let y1 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 8)).into_fe())
let b = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_b + i * 4, 0, 0)))
.collect::<Vec<_>>();
let x2 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 16)).into_fe())
let c = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_c + i * 4, 0, 0)))
.collect::<Vec<_>>();
let result = arith::affine_256(&x1, &y1, &x2);
// store result in registers
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i), Elem::Field(result.0[i]))
let result = arith::affine_256(&a, &b, &c);

result.0.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_d + i as u32 * 4,
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i + 8), Elem::Field(result.1[i]))
result.1.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_d + (result.0.len() as u32 * 4) + (i as u32 * 4),
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});

vec![]
}
"mod_256" => {
assert!(args.is_empty());
// take input from registers
let y2 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i)).into_fe())
// a mod b = c
let input_ptr_a = self.proc.get_reg_mem(args[0].u()).u();
assert!(is_multiple_of_4(input_ptr_a));
let input_ptr_b = self.proc.get_reg_mem(args[1].u()).u();
assert!(is_multiple_of_4(input_ptr_b));
let output_ptr_c = self.proc.get_reg_mem(args[2].u()).u();
assert!(is_multiple_of_4(output_ptr_c));

let ah = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let y3 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 8)).into_fe())
let al = (8..16)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let x1 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 16)).into_fe())
let b = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_b + i * 4, 0, 0)))
.collect::<Vec<_>>();
let result = arith::mod_256(&y2, &y3, &x1);
// store result in registers
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i), Elem::Field(result[i]))
let result = arith::mod_256(&ah, &al, &b);

result.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_c + i as u32 * 4,
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});

vec![]
}
"ec_add" => {
assert!(args.is_empty());
// take input from registers
let x1 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i)).into_fe())
// a + b = c
let input_ptr_a = self.proc.get_reg_mem(args[0].u()).u();
assert!(is_multiple_of_4(input_ptr_a));
let input_ptr_b = self.proc.get_reg_mem(args[1].u()).u();
assert!(is_multiple_of_4(input_ptr_b));
let output_ptr_c = self.proc.get_reg_mem(args[2].u()).u();
assert!(is_multiple_of_4(output_ptr_c));

let ax = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let y1 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 8)).into_fe())
let ay = (8..16)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let x2 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 16)).into_fe())
let bx = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_b + i * 4, 0, 0)))
.collect::<Vec<_>>();
let y2 = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 24)).into_fe())
let by = (8..16)
.map(|i| F::from(self.proc.get_mem(input_ptr_b + i * 4, 0, 0)))
.collect::<Vec<_>>();
let result = arith::ec_add(&x1, &y1, &x2, &y2);
// store result in registers
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i), Elem::Field(result.0[i]))

let result = arith::ec_add(&ax, &ay, &bx, &by);
result.0.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_c + i as u32 * 4,
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i + 8), Elem::Field(result.1[i]))
result.1.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_c + (result.0.len() as u32 * 4) + (i as u32 * 4),
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});

vec![]
}
"ec_double" => {
assert!(args.is_empty());
// take input from registers
let x = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i)).into_fe())
// a * 2 = b
let input_ptr_a = self.proc.get_reg_mem(args[0].u()).u();
assert!(is_multiple_of_4(input_ptr_a));
let output_ptr_b = self.proc.get_reg_mem(args[1].u()).u();
assert!(is_multiple_of_4(output_ptr_b));

let ax = (0..8)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let y = (0..8)
.map(|i| self.proc.get_reg(&register_by_idx(i + 8)).into_fe())
let ay = (8..16)
.map(|i| F::from(self.proc.get_mem(input_ptr_a + i * 4, 0, 0)))
.collect::<Vec<_>>();
let result = arith::ec_double(&x, &y);
// store result in registers
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i), Elem::Field(result.0[i]))

let result = arith::ec_double(&ax, &ay);
result.0.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_b + i as u32 * 4,
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});
(0..8).for_each(|i| {
self.proc
.set_reg(&register_by_idx(i + 8), Elem::Field(result.1[i]))
result.1.iter().enumerate().for_each(|(i, &v)| {
self.proc.set_mem(
output_ptr_b + (result.0.len() as u32 * 4) + (i as u32 * 4),
v.to_integer().try_into_u32().unwrap(),
1,
1,
);
});

vec![]
Expand Down Expand Up @@ -2334,12 +2387,6 @@ pub fn execute<F: FieldElement>(
)
}

/// FIXME: copied from `riscv/runtime.rs` instead of adding dependency.
/// Helper function for register names used in submachine instruction params.
fn register_by_idx(idx: usize) -> String {
format!("xtra{idx}")
}

#[allow(clippy::too_many_arguments)]
fn execute_inner<F: FieldElement>(
asm: &AnalysisASMFile,
Expand Down Expand Up @@ -2629,3 +2676,7 @@ pub fn write_executor_csv<F: FieldElement, P: AsRef<Path>>(
&columns[..],
);
}

fn is_multiple_of_4(n: u32) -> bool {
n % 4 == 0
}
Loading

0 comments on commit 4f1aa4a

Please sign in to comment.