Skip to content

Commit

Permalink
JIT for block machines with non-rectangular shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Dec 21, 2024
1 parent 38e371f commit 506e625
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 103 deletions.
237 changes: 196 additions & 41 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::collections::HashSet;
use std::collections::{BTreeSet, HashSet};

use bit_vec::BitVec;
use powdr_ast::analyzed::{AlgebraicReference, Identity};
use powdr_ast::analyzed::{AlgebraicReference, Identity, SelectedExpressions};
use powdr_number::FieldElement;

use crate::witgen::{machines::MachineParts, FixedData};
use crate::witgen::{
jit::{effect::format_code, witgen_inference::Value},
machines::MachineParts,
FixedData,
};

use super::{
effect::Effect,
variable::Variable,
variable::{Cell, Variable},
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
};

Expand Down Expand Up @@ -64,17 +68,21 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

// Solve for the block witness.
// Fails if any machine call cannot be completed.
self.solve_block(can_process, &mut witgen)?;

for (index, expr) in connection_rhs.expressions.iter().enumerate() {
if !witgen.is_known(&Variable::Param(index)) {
return Err(format!(
"Unable to derive algorithm to compute output value \"{expr}\""
));
match self.solve_block(can_process, &mut witgen, connection_rhs) {
Ok(()) => Ok(witgen.code()),
Err(e) => {
log::debug!("Code generation failed: {e}");
log::debug!(
"The following code was generated so far:\n{}",
format_code(witgen.code().as_slice())
);
Err(format!("Code generation failed: {e}\nRun with RUST_LOG=debug to see the code generated so far."))
}
}
}

Ok(witgen.code())
fn row_range(&self) -> std::ops::Range<i32> {
-1..(self.block_size + 1) as i32
}

/// Repeatedly processes all identities on all rows, until no progress is made.
Expand All @@ -83,16 +91,16 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
&self,
can_process: CanProcess,
witgen: &mut WitgenInference<'a, T, &Self>,
connection_rhs: &SelectedExpressions<T>,
) -> Result<(), String> {
let mut complete = HashSet::new();
for iteration in 0.. {
let mut progress = false;

// TODO: This algorithm is assuming a rectangular block shape.
for row in 0..self.block_size {
for row in self.row_range() {
for id in &self.machine_parts.identities {
if !complete.contains(&(id.id(), row)) {
let result = witgen.process_identity(can_process.clone(), id, row as i32);
let result = witgen.process_identity(can_process.clone(), id, row);
if result.complete {
complete.insert((id.id(), row));
}
Expand All @@ -108,22 +116,106 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
}
}

// If any machine call could not be completed, that's bad because machine calls typically have side effects.
// So, the underlying lookup / permutation / bus argument likely does not hold.
// TODO: This assumes a rectangular block shape.
let has_incomplete_machine_calls = (0..self.block_size)
.flat_map(|row| {
self.machine_parts
.identities
.iter()
.filter(|id| is_machine_call(id))
.map(move |id| (id, row))
self.check_block_shape(witgen)?;
self.check_incomplete_machine_calls(&complete)?;

for (index, expr) in connection_rhs.expressions.iter().enumerate() {
if !witgen.is_known(&Variable::Param(index)) {
return Err(format!(
"Unable to derive algorithm to compute output value \"{expr}\""
));
}
}

Ok(())
}

/// After solving, the known values should be such that we can stack different blocks.
fn check_block_shape(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> {
let known_columns = witgen
.known_variables()
.iter()
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell.id),
_ => None,
})
.any(|(identity, row)| !complete.contains(&(identity.id(), row)));
.collect::<BTreeSet<_>>();

let can_stack = known_columns.iter().all(|column_id| {
let values = self
.row_range()
.map(|row| {
witgen.value(&Variable::Cell(Cell {
id: *column_id,
row_offset: row,
column_name: "".to_string(),
}))
})
.collect::<Vec<_>>();
// TODO: Improve error message?
// let column_name = self.fixed_data.column_name(&PolyID {
// id: *column_id,
// ptype: PolynomialType::Committed,
// });
let is_compatible = |v1: Value<T>, v2: Value<T>| match (v1, v2) {
(Value::Unknown, _) | (_, Value::Unknown) => true,
(Value::Concrete(a), Value::Concrete(b)) => a == b,
_ => false,
};
is_compatible(values[0], values[self.block_size])
&& is_compatible(values[1], values[self.block_size + 1])
});

match has_incomplete_machine_calls {
true => Err("Incomplete machine calls".to_string()),
false => Ok(()),
match can_stack {
true => Ok(()),
false => Err("Block machine shape does not allow stacking".to_string()),
}
}

/// If any machine call could not be completed, that's bad because machine calls typically have side effects.
/// So, the underlying lookup / permutation / bus argument likely does not hold.
/// This function checks that all machine calls are complete, at least for a window of <block_size> rows.
fn check_incomplete_machine_calls(&self, complete: &HashSet<(u64, i32)>) -> Result<(), String> {
let machine_calls = self
.machine_parts
.identities
.iter()
.filter(|id| is_machine_call(id));

let incomplete_machine_calls = machine_calls
.flat_map(|call| {
let complete_rows = self
.row_range()
.filter(|row| complete.contains(&(call.id(), *row)))
.collect::<Vec<_>>();
// Because we process rows -1..block_size+1, it is fine to have two incomplete machine calls,
// as long as <block_size> consecutive rows are complete.
if complete_rows.len() >= self.block_size {
let is_consecutive = complete_rows.iter().max().unwrap()
- complete_rows.iter().min().unwrap()
== complete_rows.len() as i32 - 1;
if is_consecutive {
return vec![];
}
}
self.row_range()
.filter(|row| !complete.contains(&(call.id(), *row)))
.map(|row| (call, row))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();

if !incomplete_machine_calls.is_empty() {
Err(format!(
"Incomplete machine calls:\n {}",
incomplete_machine_calls
.iter()
.map(|(id, row)| format!("{id} (row {row})"))
.collect::<Vec<_>>()
.join("\n ")
))
} else {
Ok(())
}
}
}
Expand All @@ -143,7 +235,16 @@ impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option<T> {
assert!(var.is_fixed());
let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size();
let row = (row_offset + var.next as i32 + values.len() as i32) as usize % values.len();

// By assumption of the block machine, all fixed columns are cyclic with a period of <block_size>.
// An exception might be the first and last row.
assert!(row_offset >= -1);
assert!(self.block_size >= 1);
// The current row is guaranteed to be at least 1.
let current_row = (2 * self.block_size as i32 + row_offset) as usize;
let row = current_row + var.next as usize;

assert!(values.len() >= self.block_size * 4);
Some(values[row])
}
}
Expand All @@ -159,10 +260,7 @@ mod test {
use crate::witgen::{
data_structures::mutable_state::MutableState,
global_constraints,
jit::{
effect::Effect,
test_util::{format_code, read_pil},
},
jit::{effect::Effect, test_util::read_pil},
machines::{machine_extractor::MachineExtractor, KnownMachine, Machine},
};

Expand Down Expand Up @@ -241,18 +339,75 @@ params[2] = Add::c[0];"
let err_str = generate_for_block_machine(input, "Unconstrained", 2, 1)
.err()
.unwrap();
assert_eq!(
err_str,
"Unable to derive algorithm to compute output value \"Unconstrained::c\""
);
assert!(err_str
.contains("Unable to derive algorithm to compute output value \"Unconstrained::c\""));
}

#[test]
#[should_panic = "Block machine shape does not allow stacking"]
fn not_stackable() {
let input = "
namespace Main(256);
col witness a, b, c;
[a] is NotStackable.sel $ [NotStackable.a];
namespace NotStackable(256);
col witness sel, a;
a = a';
";
generate_for_block_machine(input, "NotStackable", 1, 0).unwrap();
}

#[test]
// TODO: Currently fails, because the machine has a non-rectangular block shape.
#[should_panic = "Incomplete machine calls"]
fn binary() {
let input = read_to_string("../test_data/pil/binary.pil").unwrap();
generate_for_block_machine(&input, "main_binary", 3, 1).unwrap();
let code = generate_for_block_machine(&input, "main_binary", 3, 1).unwrap();
assert_eq!(
format_code(&code),
"main_binary::sel[0][3] = 1;
main_binary::operation_id[3] = params[0];
main_binary::A[3] = params[1];
main_binary::B[3] = params[2];
main_binary::operation_id[2] = main_binary::operation_id[3];
main_binary::A_byte[2] = ((main_binary::A[3] & 4278190080) // 16777216);
main_binary::A[2] = (main_binary::A[3] & 16777215);
assert (main_binary::A[3] & 18446744069414584320) == 0;
main_binary::B_byte[2] = ((main_binary::B[3] & 4278190080) // 16777216);
main_binary::B[2] = (main_binary::B[3] & 16777215);
assert (main_binary::B[3] & 18446744069414584320) == 0;
main_binary::operation_id_next[2] = main_binary::operation_id[3];
machine_call(9, [Known(main_binary::operation_id_next[2]), Known(main_binary::A_byte[2]), Known(main_binary::B_byte[2]), Unknown(ret(9, 2, 3))]);
main_binary::C_byte[2] = ret(9, 2, 3);
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::A_byte[1] = ((main_binary::A[2] & 16711680) // 65536);
main_binary::A[1] = (main_binary::A[2] & 65535);
assert (main_binary::A[2] & 18446744073692774400) == 0;
main_binary::B_byte[1] = ((main_binary::B[2] & 16711680) // 65536);
main_binary::B[1] = (main_binary::B[2] & 65535);
assert (main_binary::B[2] & 18446744073692774400) == 0;
main_binary::operation_id_next[1] = main_binary::operation_id[2];
machine_call(9, [Known(main_binary::operation_id_next[1]), Known(main_binary::A_byte[1]), Known(main_binary::B_byte[1]), Unknown(ret(9, 1, 3))]);
main_binary::C_byte[1] = ret(9, 1, 3);
main_binary::operation_id[0] = main_binary::operation_id[1];
main_binary::A_byte[0] = ((main_binary::A[1] & 65280) // 256);
main_binary::A[0] = (main_binary::A[1] & 255);
assert (main_binary::A[1] & 18446744073709486080) == 0;
main_binary::B_byte[0] = ((main_binary::B[1] & 65280) // 256);
main_binary::B[0] = (main_binary::B[1] & 255);
assert (main_binary::B[1] & 18446744073709486080) == 0;
main_binary::operation_id_next[0] = main_binary::operation_id[1];
machine_call(9, [Known(main_binary::operation_id_next[0]), Known(main_binary::A_byte[0]), Known(main_binary::B_byte[0]), Unknown(ret(9, 0, 3))]);
main_binary::C_byte[0] = ret(9, 0, 3);
main_binary::A_byte[-1] = main_binary::A[0];
main_binary::B_byte[-1] = main_binary::B[0];
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
machine_call(9, [Known(main_binary::operation_id_next[-1]), Known(main_binary::A_byte[-1]), Known(main_binary::B_byte[-1]), Unknown(ret(9, -1, 3))]);
main_binary::C_byte[-1] = ret(9, -1, 3);
main_binary::C[0] = main_binary::C_byte[-1];
main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216));
params[3] = main_binary::C[3];"
)
}

#[test]
Expand Down
36 changes: 35 additions & 1 deletion executor/src/witgen/jit/effect.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use itertools::Itertools;
use powdr_number::FieldElement;

use crate::witgen::range_constraints::RangeConstraint;

use super::symbolic_expression::SymbolicExpression;
use super::{symbolic_expression::SymbolicExpression, variable::Variable};

/// The effect of solving a symbolic equation.
pub enum Effect<T: FieldElement, V> {
Expand Down Expand Up @@ -55,3 +56,36 @@ pub enum MachineCallArgument<T: FieldElement, V> {
Known(SymbolicExpression<T, V>),
Unknown(V),
}

pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
effects
.iter()
.map(|effect| match effect {
Effect::Assignment(v, expr) => format!("{v} = {expr};"),
Effect::Assertion(Assertion {
lhs,
rhs,
expected_equal,
}) => {
format!(
"assert {lhs} {} {rhs};",
if *expected_equal { "==" } else { "!=" }
)
}
Effect::MachineCall(id, args) => {
format!(
"machine_call({id}, [{}]);",
args.iter()
.map(|arg| match arg {
MachineCallArgument::Known(k) => format!("Known({k})"),
MachineCallArgument::Unknown(u) => format!("Unknown({u})"),
})
.join(", ")
)
}
Effect::RangeConstraint(..) => {
panic!("Range constraints should not be part of the code.")
}
})
.join("\n")
}
Loading

0 comments on commit 506e625

Please sign in to comment.