Skip to content

Commit

Permalink
Refactor block machine test
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Dec 20, 2024
1 parent ff40622 commit eaa9c65
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 123 deletions.
11 changes: 11 additions & 0 deletions executor/src/witgen/data_structures/mutable_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
}
}

#[cfg(test)]
pub fn get_machine(&self, name: &str) -> &RefCell<KnownMachine<'a, T>> {
for m in &self.machines {
println!("Machine name: {}", m.borrow().name());
}
self.machines
.iter()
.find(|m| m.borrow().name() == name)
.unwrap()
}

/// Runs the first machine (unless there are no machines) end returns the generated columns.
/// The first machine might call other machines, which is handled automatically.
pub fn run(self) -> HashMap<String, Vec<T>> {
Expand Down
143 changes: 41 additions & 102 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,99 +150,46 @@ impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {

#[cfg(test)]
mod test {
use std::{collections::BTreeMap, fs::read_to_string};
use std::fs::read_to_string;

use test_log::test;

use powdr_ast::analyzed::{AlgebraicExpression, Analyzed, SelectedExpressions};
use powdr_number::GoldilocksField;

use crate::{
constant_evaluator,
witgen::{
global_constraints,
jit::{effect::Effect, test_util::format_code},
machines::{Connection, ConnectionKind, MachineParts},
range_constraints::RangeConstraint,
FixedData,
use crate::witgen::{
data_structures::mutable_state::MutableState,
global_constraints,
jit::{
effect::Effect,
test_util::{format_code, read_pil},
},
machines::{machine_extractor::MachineExtractor, KnownMachine, Machine},
};

use super::*;

#[derive(Clone)]
struct CannotProcessSubcalls;
impl<T: FieldElement> CanProcessCall<T> for CannotProcessSubcalls {
fn can_process_call_fully(
&self,
_identity_id: u64,
_known_inputs: &BitVec,
_range_constraints: &[Option<RangeConstraint<T>>],
) -> bool {
false
}
}

fn generate_for_block_machine(
input_pil: &str,
block_size: usize,
latch_row: usize,
selector_name: &str,
input_names: &[&str],
output_names: &[&str],
machine_name: &str,
num_inputs: usize,
num_outputs: usize,
) -> Result<Vec<Effect<GoldilocksField, Variable>>, String> {
let analyzed: Analyzed<GoldilocksField> =
powdr_pil_analyzer::analyze_string(input_pil).unwrap();
let fixed_col_vals = constant_evaluator::generate(&analyzed);
let (analyzed, fixed_col_vals) = read_pil(input_pil);

let fixed_data = FixedData::new(&analyzed, &fixed_col_vals, &[], Default::default(), 0);
let (fixed_data, retained_identities) =
global_constraints::set_global_constraints(fixed_data, &analyzed.identities);

// Build a connection that encodes:
// [] is <selector_name> $ [<input_names...>, <output_names...>]
let witnesses_by_name = analyzed
.committed_polys_in_source_order()
.flat_map(|(symbol, _)| symbol.array_elements())
.collect::<BTreeMap<_, _>>();
let to_expr = |name: &str| {
let (column_name, next) = if let Some(name) = name.strip_suffix("'") {
(name, true)
} else {
(name, false)
};
AlgebraicExpression::Reference(AlgebraicReference {
name: name.to_string(),
poly_id: witnesses_by_name[column_name],
next,
})
let machines = MachineExtractor::new(&fixed_data).split_out_machines(retained_identities);
let mutable_state = MutableState::new(machines.into_iter(), &|_| {
Err("Query not implemented".to_string())
});

let machine = mutable_state.get_machine(machine_name);
let ((machine_parts, block_size, latch_row), connection_ids) = match *machine.borrow() {
KnownMachine::BlockMachine(ref m) => (m.machine_info(), m.identity_ids()),
_ => panic!("Expected a block machine"),
};
let rhs = input_names
.iter()
.chain(output_names)
.map(|name| to_expr(name))
.collect::<Vec<_>>();
let right = SelectedExpressions {
selector: to_expr(selector_name),
expressions: rhs,
};
// The LHS is not used by the processor.
let left = SelectedExpressions::default();
let connection = Connection {
id: 0,
left: &left,
right: &right,
kind: ConnectionKind::Permutation,
multiplicity_column: None,
};

let machine_parts = MachineParts::new(
&fixed_data,
[(0, connection)].into_iter().collect(),
retained_identities,
witnesses_by_name.values().copied().collect(),
// No prover functions
Vec::new(),
);
assert_eq!(connection_ids.len(), 1);

let processor = BlockMachineProcessor {
fixed_data: &fixed_data,
Expand All @@ -252,24 +199,26 @@ mod test {
};

let known_values = BitVec::from_iter(
input_names
.iter()
(0..num_inputs)
.map(|_| true)
.chain(output_names.iter().map(|_| false)),
.chain((0..num_outputs).map(|_| false)),
);

processor.generate_code(CannotProcessSubcalls, 0, &known_values)
processor.generate_code(&mutable_state, connection_ids[0], &known_values)
}

#[test]
fn add() {
let input = "
namespace Main(256);
col witness a, b, c;
[a, b, c] is Add.sel $ [Add.a, Add.b, Add.c];
namespace Add(256);
col witness sel, a, b, c;
c = a + b;
";
let code =
generate_for_block_machine(input, 1, 0, "Add::sel", &["Add::a", "Add::b"], &["Add::c"]);
generate_for_block_machine(input, "Secondary machine 0: Add (BlockMachine)", 2, 1);
assert_eq!(
format_code(&code.unwrap()),
"Add::sel[0] = 1;
Expand All @@ -283,17 +232,18 @@ params[2] = Add::c[0];"
#[test]
fn unconstrained_output() {
let input = "
namespace Main(256);
col witness a, b, c;
[a, b, c] is Unconstrained.sel $ [Unconstrained.a, Unconstrained.b, Unconstrained.c];
namespace Unconstrained(256);
col witness sel, a, b, c;
a + b = 0;
";
let err_str = generate_for_block_machine(
input,
"Secondary machine 0: Unconstrained (BlockMachine)",
2,
1,
0,
"Unconstrained::sel",
&["Unconstrained::a", "Unconstrained::b"],
&["Unconstrained::c"],
)
.err()
.unwrap();
Expand All @@ -310,32 +260,21 @@ params[2] = Add::c[0];"
let input = read_to_string("../test_data/pil/binary.pil").unwrap();
generate_for_block_machine(
&input,
4,
3,
"main_binary::sel[0]",
&["main_binary::A", "main_binary::B"],
&["main_binary::C"],
"Secondary machine 0: main_binary (BlockMachine)",
2,
1,
)
.unwrap();
}

#[test]
fn poseidon() {
let input = read_to_string("../test_data/pil/poseidon_gl.pil").unwrap();
let array_element = |name: &str, i: usize| {
&*Box::leak(format!("main_poseidon::{name}[{i}]").into_boxed_str())
};
generate_for_block_machine(
&input,
31,
0,
"main_poseidon::sel[0]",
&(0..12)
.map(|i| array_element("state", i))
.collect::<Vec<_>>(),
&(0..4)
.map(|i| array_element("output", i))
.collect::<Vec<_>>(),
"Secondary machine 0: main_poseidon (BlockMachine)",
12,
4,
)
.unwrap();
}
Expand Down
49 changes: 30 additions & 19 deletions executor/src/witgen/jit/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ use powdr_number::{FieldElement, GoldilocksField};
use crate::{
constant_evaluator,
witgen::{
data_structures::mutable_state::MutableState, global_constraints,
jit::effect::MachineCallArgument, machines::machine_extractor::MachineExtractor, FixedData,
data_structures::mutable_state::MutableState,
global_constraints,
jit::effect::MachineCallArgument,
machines::{machine_extractor::MachineExtractor, KnownMachine},
FixedData, QueryCallback,
},
};

Expand Down Expand Up @@ -57,21 +60,29 @@ pub fn read_pil<T: FieldElement>(
(analyzed, fixed_col_vals)
}

pub fn prepare<'a, T: FieldElement>(
analyzed: &'a Analyzed<T>,
fixed_col_vals: &'a [(String, VariablySizedColumn<T>)],
) -> (
FixedData<'a, T>,
MutableState<'a, T, _>,
Vec<&'a Identity<T>>,
) {
let fixed_data = FixedData::new(analyzed, fixed_col_vals, &[], Default::default(), 0);
let (fixed_data, retained_identities) =
global_constraints::set_global_constraints(fixed_data, &analyzed.identities);
// TODO: Doesn't compile
// pub fn prepare<'a, T: FieldElement, Q: QueryCallback<T>>(
// analyzed: &'a Analyzed<T>,
// fixed_col_vals: &'a [(String, VariablySizedColumn<T>)],
// external_witness_values: &'a [(String, Vec<T>)],
// query_callback: &'a Q,
// ) -> (
// FixedData<'a, T>,
// MutableState<'a, T, _>,
// Vec<&'a Identity<T>>,
// ) {
// let fixed_data = FixedData::new(
// analyzed,
// fixed_col_vals,
// external_witness_values,
// Default::default(),
// 0,
// );
// let (fixed_data, retained_identities) =
// global_constraints::set_global_constraints(fixed_data, &analyzed.identities);

let machines = MachineExtractor::new(&fixed_data).split_out_machines(retained_identities);
let mutable_state = MutableState::new(machines.into_iter(), &|_| {
Err("Query not implemented".to_string())
});
(fixed_data, mutable_state, retained_identities)
}
// let machines =
// MachineExtractor::new(&fixed_data).split_out_machines(retained_identities.clone());
// let mutable_state = MutableState::new(machines.into_iter(), query_callback);
// (fixed_data, mutable_state, retained_identities)
// }
22 changes: 20 additions & 2 deletions executor/src/witgen/jit/witgen_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,12 @@ mod test {
use test_log::test;

use crate::witgen::{
global_constraints,
jit::{
test_util::{format_code, prepare, read_pil},
test_util::{format_code, read_pil},
variable::Cell,
},
machines::{Connection, FixedLookup, KnownMachine},
FixedData,
};

Expand All @@ -503,7 +505,23 @@ mod test {
expected_complete: Option<usize>,
) -> String {
let (analyzed, fixed_col_vals) = read_pil(input);
let (fixed_data, mutable_state, retained_identities) = prepare(&analyzed, &fixed_col_vals);
let fixed_data = FixedData::new(&analyzed, &fixed_col_vals, &[], Default::default(), 0);
let (fixed_data, retained_identities) =
global_constraints::set_global_constraints(fixed_data, &analyzed.identities);

let fixed_lookup_connections = retained_identities
.iter()
.filter_map(|i| Connection::try_from(*i).ok())
.filter(|c| FixedLookup::is_responsible(c))
.map(|c| (c.id, c))
.collect();

let global_constr = fixed_data.global_range_constraints.clone();
let fixed_machine = FixedLookup::new(global_constr, &fixed_data, fixed_lookup_connections);
let known_fixed = KnownMachine::FixedLookup(fixed_machine);
let mutable_state = MutableState::new([known_fixed].into_iter(), &|_| {
Err("Query not implemented".to_string())
});

let known_cells = known_cells.iter().map(|(name, row_offset)| {
let id = fixed_data.try_column_by_name(name).unwrap().id;
Expand Down
5 changes: 5 additions & 0 deletions executor/src/witgen/machines/block_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
block_count_runtime: 0,
})
}

#[cfg(test)]
pub fn machine_info(&self) -> (MachineParts<'a, T>, usize, usize) {
(self.parts.clone(), self.block_size, self.latch_row)
}
}

impl<'a, T: FieldElement> Machine<'a, T> for BlockMachine<'a, T> {
Expand Down
5 changes: 5 additions & 0 deletions test_data/pil/binary.pil
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
// A compiled version of std/machines/large_field/binary.asm

namespace main(128);
col witness a, b, c;
// Dummy connection constraint
[a, b, c] is main_binary::latch * main_binary::sel[0] $ [main_binary::A, main_binary::B, main_binary::C];

namespace main_binary(128);
col witness operation_id;
(main_binary::operation_id' - main_binary::operation_id) * (1 - main_binary::latch) = 0;
Expand Down
24 changes: 24 additions & 0 deletions test_data/pil/poseidon_gl.pil
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
// A compiled version of std/machines/hash/poseidon_gl.asm

namespace main(256);
col witness i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, o1, o2, o3, o4;
// Dummy connection constraint
[i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, o1, o2, o3, o4]
is
main_poseidon::FIRSTBLOCK * main_poseidon::sel[0] $ [
main_poseidon::state[0],
main_poseidon::state[1],
main_poseidon::state[2],
main_poseidon::state[3],
main_poseidon::state[4],
main_poseidon::state[5],
main_poseidon::state[6],
main_poseidon::state[7],
main_poseidon::state[8],
main_poseidon::state[9],
main_poseidon::state[10],
main_poseidon::state[11],
main_poseidon::output[0],
main_poseidon::output[1],
main_poseidon::output[2],
main_poseidon::output[3]
];

namespace main_poseidon(256);
let FULL_ROUNDS: int = 8_int;
let PARTIAL_ROUNDS: int = 22_int;
Expand Down

0 comments on commit eaa9c65

Please sign in to comment.