Skip to content

Commit

Permalink
Block machine processor (#2226)
Browse files Browse the repository at this point in the history
This PR adds a basic block machine processor. Currently, we assume a
rectangular block shape and just iterate over all rows and identities of
the block until no more progress is made. This is sufficient to generate
code for Poseidon.
  • Loading branch information
georgwiese authored Dec 17, 2024
1 parent a68a20c commit e2131f5
Show file tree
Hide file tree
Showing 5 changed files with 442 additions and 0 deletions.
304 changes: 304 additions & 0 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
#![allow(unused)]
use std::{
collections::{BTreeSet, HashSet},
fmt::Display,
};

use bit_vec::BitVec;
use powdr_ast::analyzed::{AlgebraicReference, Identity};
use powdr_number::FieldElement;

use crate::witgen::{
evaluators::fixed_evaluator, jit::affine_symbolic_expression::AffineSymbolicExpression,
machines::MachineParts, FixedData,
};

use super::{
affine_symbolic_expression::Effect,
variable::{Cell, Variable},
witgen_inference::{FixedEvaluator, WitgenInference},
};

/// A processor for generating JIT code for a block machine.
struct BlockMachineProcessor<'a, T: FieldElement> {
fixed_data: &'a FixedData<'a, T>,
machine_parts: MachineParts<'a, T>,
block_size: usize,
latch_row: usize,
}

impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
/// Generates the JIT code for a given combination of connection and known arguments.
/// Fails if it cannot solve for the outputs, or if any sub-machine calls cannot be completed.
pub fn generate_code(
&self,
identity_id: u64,
known_args: BitVec,
) -> Result<Vec<Effect<T, Variable>>, String> {
let connection_rhs = self.machine_parts.connections[&identity_id].right;
assert_eq!(connection_rhs.expressions.len(), known_args.len());

// Set up WitgenInference with known arguments.
let known_variables = known_args
.iter()
.enumerate()
.filter_map(|(i, is_input)| is_input.then_some(Variable::Param(i)))
.collect::<HashSet<_>>();
let mut witgen = WitgenInference::new(self.fixed_data, self, known_variables);

// In the latch row, set the RHS selector to 1.
witgen.assign(
&connection_rhs.selector,
self.latch_row as i32,
T::one().into(),
)?;

// For each known argument, transfer the value to the expression in the connection's RHS.
for (index, expr) in connection_rhs.expressions.iter().enumerate() {
if known_args[index] {
let param_i =
AffineSymbolicExpression::from_known_symbol(Variable::Param(index), None);
witgen.assign(expr, self.latch_row as i32, param_i)?;
}
}

// Solve for the block witness.
// Fails if any machine call cannot be completed.
self.solve_block(&mut witgen)?;

// For each unknown argument, get the value from the expression in the connection's RHS.
// If assign() fails, it means that we weren't able to to solve for the operation's output.
for (index, expr) in connection_rhs.expressions.iter().enumerate() {
if !known_args[index] {
let param_i =
AffineSymbolicExpression::from_unknown_variable(Variable::Param(index), None);
witgen
.assign(expr, self.latch_row as i32, param_i)
.map_err(|_| format!("Could not solve for params[{index}]"))?;
}
}

Ok(witgen.code())
}

/// Repeatedly processes all identities on all rows, until no progress is made.
/// Fails iff there are incomplete machine calls in the latch row.
fn solve_block(&self, witgen: &mut WitgenInference<T, &Self>) -> Result<(), String> {
let mut complete = HashSet::new();
for iteration in 0.. {
let mut progress = false;

// TODO: This algorithm is assuming a rectangular block shape.
for row in (0..self.block_size) {
for id in &self.machine_parts.identities {
if !complete.contains(&(id.id(), row)) {
let result = witgen.process_identity(id, row as i32);
if result.complete {
complete.insert((id.id(), row));
}
progress |= result.progress;
}
}
}
if !progress {
log::debug!("Finishing after {iteration} iterations");
break;
}
}

// If any machine call could not be completed, that's bad because machine calls typically have side effects.
// So, the underlying lookup / permutation / bus argument likely does not hold.
// TODO: This assumes a rectangular block shape.
let has_incomplete_machine_calls = (0..self.block_size)
.flat_map(|row| {
self.machine_parts
.identities
.iter()
.filter(|id| is_machine_call(id))
.map(move |id| (id, row))
})
.any(|(identity, row)| !complete.contains(&(identity.id(), row)));

match has_incomplete_machine_calls {
true => Err("Incomplete machine calls".to_string()),
false => Ok(()),
}
}
}

fn is_machine_call<T>(identity: &Identity<T>) -> bool {
match identity {
Identity::Lookup(_)
| Identity::Permutation(_)
| Identity::PhantomLookup(_)
| Identity::PhantomPermutation(_)
| Identity::PhantomBusInteraction(_) => true,
Identity::Polynomial(_) | Identity::Connect(_) => false,
}
}

impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option<T> {
assert!(var.is_fixed());
let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size();
let row = (row_offset + var.next as i32 + values.len() as i32) as usize % values.len();
Some(values[row])
}
}

#[cfg(test)]
mod test {
use std::{collections::BTreeMap, fs::read_to_string};

use bit_vec::BitVec;
use powdr_ast::analyzed::{
AlgebraicExpression, AlgebraicReference, Analyzed, SelectedExpressions,
};
use powdr_number::GoldilocksField;

use crate::{
constant_evaluator,
witgen::{
global_constraints,
jit::{affine_symbolic_expression::Effect, test_util::format_code},
machines::{Connection, ConnectionKind, MachineParts},
FixedData,
},
};

use super::{BlockMachineProcessor, Variable};

fn generate_for_block_machine(
input_pil: &str,
block_size: usize,
latch_row: usize,
selector_name: &str,
input_names: &[&str],
output_names: &[&str],
) -> Result<Vec<Effect<GoldilocksField, Variable>>, String> {
let analyzed: Analyzed<GoldilocksField> =
powdr_pil_analyzer::analyze_string(input_pil).unwrap();
let fixed_col_vals = constant_evaluator::generate(&analyzed);
let fixed_data = FixedData::new(&analyzed, &fixed_col_vals, &[], Default::default(), 0);
let (fixed_data, retained_identities) =
global_constraints::set_global_constraints(fixed_data, &analyzed.identities);

// Build a connection that encodes:
// [] is <selector_name> $ [<input_names...>, <output_names...>]
let witnesses_by_name = analyzed
.committed_polys_in_source_order()
.flat_map(|(symbol, _)| symbol.array_elements())
.collect::<BTreeMap<_, _>>();
let to_expr = |name: &str| {
let (column_name, next) = if let Some(name) = name.strip_suffix("'") {
(name, true)
} else {
(name, false)
};
AlgebraicExpression::Reference(AlgebraicReference {
name: name.to_string(),
poly_id: witnesses_by_name[column_name],
next,
})
};
let rhs = input_names
.iter()
.chain(output_names)
.map(|name| to_expr(name))
.collect::<Vec<_>>();
let right = SelectedExpressions {
selector: to_expr(selector_name),
expressions: rhs,
};
// The LHS is not used by the processor.
let left = SelectedExpressions::default();
let connection = Connection {
id: 0,
left: &left,
right: &right,
kind: ConnectionKind::Permutation,
multiplicity_column: None,
};

let machine_parts = MachineParts::new(
&fixed_data,
[(0, connection)].into_iter().collect(),
retained_identities,
witnesses_by_name.values().copied().collect(),
// No prover functions
Vec::new(),
);

let processor = BlockMachineProcessor {
fixed_data: &fixed_data,
machine_parts,
block_size,
latch_row,
};

let known_values = BitVec::from_iter(
input_names
.iter()
.map(|_| true)
.chain(output_names.iter().map(|_| false)),
);

processor.generate_code(0, known_values)
}

#[test]
fn add() {
let input = "
namespace Add(256);
col witness sel, a, b, c;
c = a + b;
";
let code =
generate_for_block_machine(input, 1, 0, "Add::sel", &["Add::a", "Add::b"], &["Add::c"]);
assert_eq!(
format_code(&code.unwrap()),
"Add::sel[0] = 1;
Add::a[0] = params[0];
Add::b[0] = params[1];
Add::c[0] = (Add::a[0] + Add::b[0]);
params[2] = Add::c[0];"
);
}

#[test]
// TODO: Currently fails, because the machine has a non-rectangular block shape.
#[should_panic = "Incomplete machine calls"]
fn binary() {
let input = read_to_string("../test_data/pil/binary.pil").unwrap();
generate_for_block_machine(
&input,
4,
3,
"main_binary::sel[0]",
&["main_binary::A", "main_binary::B"],
&["main_binary::C"],
)
.unwrap();
}

#[test]
fn poseidon() {
let input = read_to_string("../test_data/pil/poseidon_gl.pil").unwrap();
let array_element = |name: &str, i: usize| {
&*Box::leak(format!("main_poseidon::{name}[{i}]").into_boxed_str())
};
generate_for_block_machine(
&input,
31,
0,
"main_poseidon::sel[0]",
&(0..12)
.map(|i| array_element("state", i))
.collect::<Vec<_>>(),
&(0..4)
.map(|i| array_element("output", i))
.collect::<Vec<_>>(),
)
.unwrap();
}
}
1 change: 1 addition & 0 deletions executor/src/witgen/jit/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod affine_symbolic_expression;
mod block_machine_processor;
mod compiler;
pub(crate) mod jit_processor;
mod symbolic_expression;
Expand Down
21 changes: 21 additions & 0 deletions executor/src/witgen/jit/witgen_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
self.ingest_effects(result)
}

/// Process the constraint that the expression evaluated at the given offset equals the given affine expression.
/// Note that either the expression or the value might contain unknown variables, but if we are not able to
/// solve the equation, we return an error.
pub fn assign(
&mut self,
expression: &Expression<T>,
offset: i32,
value: AffineSymbolicExpression<T, Variable>,
) -> Result<(), String> {
let affine_expression = self
.evaluate(expression, offset)
.ok_or_else(|| format!("Expression is not affine: {expression}"))?;
let result = (affine_expression - value.clone())
.solve()
.map_err(|err| format!("Could not solve ({expression} - {value}): {err}"))?;
match self.ingest_effects(result).complete {
true => Ok(()),
false => Err("Wasn't able to complete the assignment".to_string()),
}
}

fn process_polynomial_identity(
&self,
expression: &Expression<T>,
Expand Down
44 changes: 44 additions & 0 deletions test_data/pil/binary.pil
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// A compiled version of std/machines/large_field/binary.asm

namespace main_binary(128);
col witness operation_id;
(main_binary::operation_id' - main_binary::operation_id) * (1 - main_binary::latch) = 0;
col fixed latch(i) { if i % 4_int == 3_int { 1_fe } else { 0_fe } };
col fixed FACTOR(i) { 1_int << (i + 1_int) % 4_int * 8_int };
col witness A_byte;
col witness B_byte;
col witness C_byte;
col witness A;
col witness B;
col witness C;
main_binary::A' = main_binary::A * (1 - main_binary::latch) + main_binary::A_byte * main_binary::FACTOR;
main_binary::B' = main_binary::B * (1 - main_binary::latch) + main_binary::B_byte * main_binary::FACTOR;
main_binary::C' = main_binary::C * (1 - main_binary::latch) + main_binary::C_byte * main_binary::FACTOR;
col witness operation_id_next;
main_binary::operation_id' = main_binary::operation_id_next;
col witness sel[3];
main_binary::sel[0] * (1 - main_binary::sel[0]) = 0;
main_binary::sel[1] * (1 - main_binary::sel[1]) = 0;
main_binary::sel[2] * (1 - main_binary::sel[2]) = 0;
[main_binary::operation_id_next, main_binary::A_byte, main_binary::B_byte, main_binary::C_byte] in [main_byte_binary::P_operation, main_byte_binary::P_A, main_byte_binary::P_B, main_byte_binary::P_C];

namespace main_byte_binary(262144);
let bit_counts: int[] = [256_int, 256_int, 3_int];
let inputs: (int -> int)[] = std::utils::cross_product(main_byte_binary::bit_counts);
let a: int -> int = main_byte_binary::inputs[0_int];
let b: int -> int = main_byte_binary::inputs[1_int];
let op: int -> int = main_byte_binary::inputs[2_int];
let P_A: col = main_byte_binary::a;
let P_B: col = main_byte_binary::b;
let P_operation: col = main_byte_binary::op;
col fixed P_C(i) { match main_byte_binary::op(i) {
0 => main_byte_binary::a(i) & main_byte_binary::b(i),
1 => main_byte_binary::a(i) | main_byte_binary::b(i),
2 => main_byte_binary::a(i) ^ main_byte_binary::b(i),
} };
namespace std::array;
let<T> len: T[] -> int = [];

namespace std::utils;
let cross_product: int[] -> (int -> int)[] = |sizes| std::utils::cross_product_internal(1_int, 0_int, sizes);
let cross_product_internal: int, int, int[] -> (int -> int)[] = |cycle_len, pos, sizes| if pos >= std::array::len::<int>(sizes) { [] } else { [|i| i / cycle_len % sizes[pos]] + std::utils::cross_product_internal(cycle_len * sizes[pos], pos + 1_int, sizes) };
Loading

0 comments on commit e2131f5

Please sign in to comment.