-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds a basic block machine processor. Currently, we assume a rectangular block shape and just iterate over all rows and identities of the block until no more progress is made. This is sufficient to generate code for Poseidon.
- Loading branch information
1 parent
a68a20c
commit e2131f5
Showing
5 changed files
with
442 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,304 @@ | ||
#![allow(unused)] | ||
use std::{ | ||
collections::{BTreeSet, HashSet}, | ||
fmt::Display, | ||
}; | ||
|
||
use bit_vec::BitVec; | ||
use powdr_ast::analyzed::{AlgebraicReference, Identity}; | ||
use powdr_number::FieldElement; | ||
|
||
use crate::witgen::{ | ||
evaluators::fixed_evaluator, jit::affine_symbolic_expression::AffineSymbolicExpression, | ||
machines::MachineParts, FixedData, | ||
}; | ||
|
||
use super::{ | ||
affine_symbolic_expression::Effect, | ||
variable::{Cell, Variable}, | ||
witgen_inference::{FixedEvaluator, WitgenInference}, | ||
}; | ||
|
||
/// A processor for generating JIT code for a block machine. | ||
struct BlockMachineProcessor<'a, T: FieldElement> { | ||
fixed_data: &'a FixedData<'a, T>, | ||
machine_parts: MachineParts<'a, T>, | ||
block_size: usize, | ||
latch_row: usize, | ||
} | ||
|
||
impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { | ||
/// Generates the JIT code for a given combination of connection and known arguments. | ||
/// Fails if it cannot solve for the outputs, or if any sub-machine calls cannot be completed. | ||
pub fn generate_code( | ||
&self, | ||
identity_id: u64, | ||
known_args: BitVec, | ||
) -> Result<Vec<Effect<T, Variable>>, String> { | ||
let connection_rhs = self.machine_parts.connections[&identity_id].right; | ||
assert_eq!(connection_rhs.expressions.len(), known_args.len()); | ||
|
||
// Set up WitgenInference with known arguments. | ||
let known_variables = known_args | ||
.iter() | ||
.enumerate() | ||
.filter_map(|(i, is_input)| is_input.then_some(Variable::Param(i))) | ||
.collect::<HashSet<_>>(); | ||
let mut witgen = WitgenInference::new(self.fixed_data, self, known_variables); | ||
|
||
// In the latch row, set the RHS selector to 1. | ||
witgen.assign( | ||
&connection_rhs.selector, | ||
self.latch_row as i32, | ||
T::one().into(), | ||
)?; | ||
|
||
// For each known argument, transfer the value to the expression in the connection's RHS. | ||
for (index, expr) in connection_rhs.expressions.iter().enumerate() { | ||
if known_args[index] { | ||
let param_i = | ||
AffineSymbolicExpression::from_known_symbol(Variable::Param(index), None); | ||
witgen.assign(expr, self.latch_row as i32, param_i)?; | ||
} | ||
} | ||
|
||
// Solve for the block witness. | ||
// Fails if any machine call cannot be completed. | ||
self.solve_block(&mut witgen)?; | ||
|
||
// For each unknown argument, get the value from the expression in the connection's RHS. | ||
// If assign() fails, it means that we weren't able to to solve for the operation's output. | ||
for (index, expr) in connection_rhs.expressions.iter().enumerate() { | ||
if !known_args[index] { | ||
let param_i = | ||
AffineSymbolicExpression::from_unknown_variable(Variable::Param(index), None); | ||
witgen | ||
.assign(expr, self.latch_row as i32, param_i) | ||
.map_err(|_| format!("Could not solve for params[{index}]"))?; | ||
} | ||
} | ||
|
||
Ok(witgen.code()) | ||
} | ||
|
||
/// Repeatedly processes all identities on all rows, until no progress is made. | ||
/// Fails iff there are incomplete machine calls in the latch row. | ||
fn solve_block(&self, witgen: &mut WitgenInference<T, &Self>) -> Result<(), String> { | ||
let mut complete = HashSet::new(); | ||
for iteration in 0.. { | ||
let mut progress = false; | ||
|
||
// TODO: This algorithm is assuming a rectangular block shape. | ||
for row in (0..self.block_size) { | ||
for id in &self.machine_parts.identities { | ||
if !complete.contains(&(id.id(), row)) { | ||
let result = witgen.process_identity(id, row as i32); | ||
if result.complete { | ||
complete.insert((id.id(), row)); | ||
} | ||
progress |= result.progress; | ||
} | ||
} | ||
} | ||
if !progress { | ||
log::debug!("Finishing after {iteration} iterations"); | ||
break; | ||
} | ||
} | ||
|
||
// If any machine call could not be completed, that's bad because machine calls typically have side effects. | ||
// So, the underlying lookup / permutation / bus argument likely does not hold. | ||
// TODO: This assumes a rectangular block shape. | ||
let has_incomplete_machine_calls = (0..self.block_size) | ||
.flat_map(|row| { | ||
self.machine_parts | ||
.identities | ||
.iter() | ||
.filter(|id| is_machine_call(id)) | ||
.map(move |id| (id, row)) | ||
}) | ||
.any(|(identity, row)| !complete.contains(&(identity.id(), row))); | ||
|
||
match has_incomplete_machine_calls { | ||
true => Err("Incomplete machine calls".to_string()), | ||
false => Ok(()), | ||
} | ||
} | ||
} | ||
|
||
fn is_machine_call<T>(identity: &Identity<T>) -> bool { | ||
match identity { | ||
Identity::Lookup(_) | ||
| Identity::Permutation(_) | ||
| Identity::PhantomLookup(_) | ||
| Identity::PhantomPermutation(_) | ||
| Identity::PhantomBusInteraction(_) => true, | ||
Identity::Polynomial(_) | Identity::Connect(_) => false, | ||
} | ||
} | ||
|
||
impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> { | ||
fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option<T> { | ||
assert!(var.is_fixed()); | ||
let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size(); | ||
let row = (row_offset + var.next as i32 + values.len() as i32) as usize % values.len(); | ||
Some(values[row]) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use std::{collections::BTreeMap, fs::read_to_string}; | ||
|
||
use bit_vec::BitVec; | ||
use powdr_ast::analyzed::{ | ||
AlgebraicExpression, AlgebraicReference, Analyzed, SelectedExpressions, | ||
}; | ||
use powdr_number::GoldilocksField; | ||
|
||
use crate::{ | ||
constant_evaluator, | ||
witgen::{ | ||
global_constraints, | ||
jit::{affine_symbolic_expression::Effect, test_util::format_code}, | ||
machines::{Connection, ConnectionKind, MachineParts}, | ||
FixedData, | ||
}, | ||
}; | ||
|
||
use super::{BlockMachineProcessor, Variable}; | ||
|
||
fn generate_for_block_machine( | ||
input_pil: &str, | ||
block_size: usize, | ||
latch_row: usize, | ||
selector_name: &str, | ||
input_names: &[&str], | ||
output_names: &[&str], | ||
) -> Result<Vec<Effect<GoldilocksField, Variable>>, String> { | ||
let analyzed: Analyzed<GoldilocksField> = | ||
powdr_pil_analyzer::analyze_string(input_pil).unwrap(); | ||
let fixed_col_vals = constant_evaluator::generate(&analyzed); | ||
let fixed_data = FixedData::new(&analyzed, &fixed_col_vals, &[], Default::default(), 0); | ||
let (fixed_data, retained_identities) = | ||
global_constraints::set_global_constraints(fixed_data, &analyzed.identities); | ||
|
||
// Build a connection that encodes: | ||
// [] is <selector_name> $ [<input_names...>, <output_names...>] | ||
let witnesses_by_name = analyzed | ||
.committed_polys_in_source_order() | ||
.flat_map(|(symbol, _)| symbol.array_elements()) | ||
.collect::<BTreeMap<_, _>>(); | ||
let to_expr = |name: &str| { | ||
let (column_name, next) = if let Some(name) = name.strip_suffix("'") { | ||
(name, true) | ||
} else { | ||
(name, false) | ||
}; | ||
AlgebraicExpression::Reference(AlgebraicReference { | ||
name: name.to_string(), | ||
poly_id: witnesses_by_name[column_name], | ||
next, | ||
}) | ||
}; | ||
let rhs = input_names | ||
.iter() | ||
.chain(output_names) | ||
.map(|name| to_expr(name)) | ||
.collect::<Vec<_>>(); | ||
let right = SelectedExpressions { | ||
selector: to_expr(selector_name), | ||
expressions: rhs, | ||
}; | ||
// The LHS is not used by the processor. | ||
let left = SelectedExpressions::default(); | ||
let connection = Connection { | ||
id: 0, | ||
left: &left, | ||
right: &right, | ||
kind: ConnectionKind::Permutation, | ||
multiplicity_column: None, | ||
}; | ||
|
||
let machine_parts = MachineParts::new( | ||
&fixed_data, | ||
[(0, connection)].into_iter().collect(), | ||
retained_identities, | ||
witnesses_by_name.values().copied().collect(), | ||
// No prover functions | ||
Vec::new(), | ||
); | ||
|
||
let processor = BlockMachineProcessor { | ||
fixed_data: &fixed_data, | ||
machine_parts, | ||
block_size, | ||
latch_row, | ||
}; | ||
|
||
let known_values = BitVec::from_iter( | ||
input_names | ||
.iter() | ||
.map(|_| true) | ||
.chain(output_names.iter().map(|_| false)), | ||
); | ||
|
||
processor.generate_code(0, known_values) | ||
} | ||
|
||
#[test] | ||
fn add() { | ||
let input = " | ||
namespace Add(256); | ||
col witness sel, a, b, c; | ||
c = a + b; | ||
"; | ||
let code = | ||
generate_for_block_machine(input, 1, 0, "Add::sel", &["Add::a", "Add::b"], &["Add::c"]); | ||
assert_eq!( | ||
format_code(&code.unwrap()), | ||
"Add::sel[0] = 1; | ||
Add::a[0] = params[0]; | ||
Add::b[0] = params[1]; | ||
Add::c[0] = (Add::a[0] + Add::b[0]); | ||
params[2] = Add::c[0];" | ||
); | ||
} | ||
|
||
#[test] | ||
// TODO: Currently fails, because the machine has a non-rectangular block shape. | ||
#[should_panic = "Incomplete machine calls"] | ||
fn binary() { | ||
let input = read_to_string("../test_data/pil/binary.pil").unwrap(); | ||
generate_for_block_machine( | ||
&input, | ||
4, | ||
3, | ||
"main_binary::sel[0]", | ||
&["main_binary::A", "main_binary::B"], | ||
&["main_binary::C"], | ||
) | ||
.unwrap(); | ||
} | ||
|
||
#[test] | ||
fn poseidon() { | ||
let input = read_to_string("../test_data/pil/poseidon_gl.pil").unwrap(); | ||
let array_element = |name: &str, i: usize| { | ||
&*Box::leak(format!("main_poseidon::{name}[{i}]").into_boxed_str()) | ||
}; | ||
generate_for_block_machine( | ||
&input, | ||
31, | ||
0, | ||
"main_poseidon::sel[0]", | ||
&(0..12) | ||
.map(|i| array_element("state", i)) | ||
.collect::<Vec<_>>(), | ||
&(0..4) | ||
.map(|i| array_element("output", i)) | ||
.collect::<Vec<_>>(), | ||
) | ||
.unwrap(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// A compiled version of std/machines/large_field/binary.asm | ||
|
||
namespace main_binary(128); | ||
col witness operation_id; | ||
(main_binary::operation_id' - main_binary::operation_id) * (1 - main_binary::latch) = 0; | ||
col fixed latch(i) { if i % 4_int == 3_int { 1_fe } else { 0_fe } }; | ||
col fixed FACTOR(i) { 1_int << (i + 1_int) % 4_int * 8_int }; | ||
col witness A_byte; | ||
col witness B_byte; | ||
col witness C_byte; | ||
col witness A; | ||
col witness B; | ||
col witness C; | ||
main_binary::A' = main_binary::A * (1 - main_binary::latch) + main_binary::A_byte * main_binary::FACTOR; | ||
main_binary::B' = main_binary::B * (1 - main_binary::latch) + main_binary::B_byte * main_binary::FACTOR; | ||
main_binary::C' = main_binary::C * (1 - main_binary::latch) + main_binary::C_byte * main_binary::FACTOR; | ||
col witness operation_id_next; | ||
main_binary::operation_id' = main_binary::operation_id_next; | ||
col witness sel[3]; | ||
main_binary::sel[0] * (1 - main_binary::sel[0]) = 0; | ||
main_binary::sel[1] * (1 - main_binary::sel[1]) = 0; | ||
main_binary::sel[2] * (1 - main_binary::sel[2]) = 0; | ||
[main_binary::operation_id_next, main_binary::A_byte, main_binary::B_byte, main_binary::C_byte] in [main_byte_binary::P_operation, main_byte_binary::P_A, main_byte_binary::P_B, main_byte_binary::P_C]; | ||
|
||
namespace main_byte_binary(262144); | ||
let bit_counts: int[] = [256_int, 256_int, 3_int]; | ||
let inputs: (int -> int)[] = std::utils::cross_product(main_byte_binary::bit_counts); | ||
let a: int -> int = main_byte_binary::inputs[0_int]; | ||
let b: int -> int = main_byte_binary::inputs[1_int]; | ||
let op: int -> int = main_byte_binary::inputs[2_int]; | ||
let P_A: col = main_byte_binary::a; | ||
let P_B: col = main_byte_binary::b; | ||
let P_operation: col = main_byte_binary::op; | ||
col fixed P_C(i) { match main_byte_binary::op(i) { | ||
0 => main_byte_binary::a(i) & main_byte_binary::b(i), | ||
1 => main_byte_binary::a(i) | main_byte_binary::b(i), | ||
2 => main_byte_binary::a(i) ^ main_byte_binary::b(i), | ||
} }; | ||
namespace std::array; | ||
let<T> len: T[] -> int = []; | ||
|
||
namespace std::utils; | ||
let cross_product: int[] -> (int -> int)[] = |sizes| std::utils::cross_product_internal(1_int, 0_int, sizes); | ||
let cross_product_internal: int, int, int[] -> (int -> int)[] = |cycle_len, pos, sizes| if pos >= std::array::len::<int>(sizes) { [] } else { [|i| i / cycle_len % sizes[pos]] + std::utils::cross_product_internal(cycle_len * sizes[pos], pos + 1_int, sizes) }; |
Oops, something went wrong.