From c1b776312ad97a311631ee52935ed73ff88527b6 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 20 Dec 2024 14:31:41 +0100 Subject: [PATCH] use can_process (#2256) Co-authored-by: Georg Wiese --- .../witgen/data_structures/mutable_state.rs | 18 ++ .../src/witgen/jit/block_machine_processor.rs | 40 +++- executor/src/witgen/jit/compiler.rs | 1 + executor/src/witgen/jit/function_cache.rs | 23 +- executor/src/witgen/jit/witgen_inference.rs | 215 +++++++++++------- executor/src/witgen/machines/block_machine.rs | 2 +- 6 files changed, 194 insertions(+), 105 deletions(-) diff --git a/executor/src/witgen/data_structures/mutable_state.rs b/executor/src/witgen/data_structures/mutable_state.rs index 904a73e1e3..7ef23ca84c 100644 --- a/executor/src/witgen/data_structures/mutable_state.rs +++ b/executor/src/witgen/data_structures/mutable_state.rs @@ -3,10 +3,12 @@ use std::{ collections::{BTreeMap, HashMap}, }; +use bit_vec::BitVec; use powdr_number::FieldElement; use crate::witgen::{ machines::{KnownMachine, LookupCell, Machine}, + range_constraints::RangeConstraint, rows::RowPair, EvalError, EvalResult, QueryCallback, }; @@ -49,6 +51,22 @@ impl<'a, T: FieldElement, Q: QueryCallback> MutableState<'a, T, Q> { self.take_witness_col_values() } + pub fn can_process_call_fully( + &self, + identity_id: u64, + known_inputs: &BitVec, + range_constraints: &[Option>], + ) -> bool { + // TODO We are currently ignoring bus interaction (also, but not only because there is no + // unique machine responsible for handling a bus send), so just answer "false" if the identity + // has no responsible machine. + self.responsible_machine(identity_id) + .ok() + .map_or(false, |mut machine| { + machine.can_process_call_fully(identity_id, known_inputs, range_constraints) + }) + } + /// Call the machine responsible for the right-hand-side of an identity given its ID /// and the row pair of the caller. pub fn call(&self, identity_id: u64, caller_rows: &RowPair<'_, 'a, T>) -> EvalResult<'a, T> { diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index 2b6151d197..48114c9818 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -9,7 +9,7 @@ use crate::witgen::{machines::MachineParts, FixedData}; use super::{ effect::Effect, variable::Variable, - witgen_inference::{FixedEvaluator, WitgenInference}, + witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference}, }; /// A processor for generating JIT code for a block machine. @@ -37,8 +37,9 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { /// Generates the JIT code for a given combination of connection and known arguments. /// Fails if it cannot solve for the outputs, or if any sub-machine calls cannot be completed. - pub fn generate_code( + pub fn generate_code + Clone>( &self, + can_process: CanProcess, identity_id: u64, known_args: &BitVec, ) -> Result>, String> { @@ -63,7 +64,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { // Solve for the block witness. // Fails if any machine call cannot be completed. - self.solve_block(&mut witgen)?; + self.solve_block(can_process, &mut witgen)?; for (index, expr) in connection_rhs.expressions.iter().enumerate() { if !witgen.is_known(&Variable::Param(index)) { @@ -78,7 +79,11 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { /// Repeatedly processes all identities on all rows, until no progress is made. /// Fails iff there are incomplete machine calls in the latch row. - fn solve_block(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> { + fn solve_block + Clone>( + &self, + can_process: CanProcess, + witgen: &mut WitgenInference<'a, T, &Self>, + ) -> Result<(), String> { let mut complete = HashSet::new(); for iteration in 0.. { let mut progress = false; @@ -87,7 +92,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { for row in 0..self.block_size { for id in &self.machine_parts.identities { if !complete.contains(&(id.id(), row)) { - let result = witgen.process_identity(id, row as i32); + let result = witgen.process_identity(can_process.clone(), id, row as i32); if result.complete { complete.insert((id.id(), row)); } @@ -147,10 +152,9 @@ impl FixedEvaluator for &BlockMachineProcessor<'_, T> { mod test { use std::{collections::BTreeMap, fs::read_to_string}; - use bit_vec::BitVec; - use powdr_ast::analyzed::{ - AlgebraicExpression, AlgebraicReference, Analyzed, SelectedExpressions, - }; + use test_log::test; + + use powdr_ast::analyzed::{AlgebraicExpression, Analyzed, SelectedExpressions}; use powdr_number::GoldilocksField; use crate::{ @@ -159,11 +163,25 @@ mod test { global_constraints, jit::{effect::Effect, test_util::format_code}, machines::{Connection, ConnectionKind, MachineParts}, + range_constraints::RangeConstraint, FixedData, }, }; - use super::{BlockMachineProcessor, Variable}; + use super::*; + + #[derive(Clone)] + struct CannotProcessSubcalls; + impl CanProcessCall for CannotProcessSubcalls { + fn can_process_call_fully( + &self, + _identity_id: u64, + _known_inputs: &BitVec, + _range_constraints: &[Option>], + ) -> bool { + false + } + } fn generate_for_block_machine( input_pil: &str, @@ -240,7 +258,7 @@ mod test { .chain(output_names.iter().map(|_| false)), ); - processor.generate_code(0, &known_values) + processor.generate_code(CannotProcessSubcalls, 0, &known_values) } #[test] diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index bd1b73147d..656a1713f6 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -417,6 +417,7 @@ fn util_code(first_column_id: u64, column_count: usize) -> Resu #[cfg(test)] mod tests { use pretty_assertions::assert_eq; + use test_log::test; use powdr_number::GoldilocksField; diff --git a/executor/src/witgen/jit/function_cache.rs b/executor/src/witgen/jit/function_cache.rs index 4b0b2f73eb..2a79ba2d1f 100644 --- a/executor/src/witgen/jit/function_cache.rs +++ b/executor/src/witgen/jit/function_cache.rs @@ -50,8 +50,9 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { /// Compiles the JIT function for the given identity and known arguments. /// Returns true if the function was successfully compiled. - pub fn compile_cached( + pub fn compile_cached>( &mut self, + mutable_state: &MutableState<'a, T, Q>, identity_id: u64, known_args: &BitVec, ) -> &Option> { @@ -59,27 +60,37 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { identity_id, known_args: known_args.clone(), }; - self.ensure_cache(&cache_key); + self.ensure_cache(mutable_state, &cache_key); self.witgen_functions.get(&cache_key).unwrap() } - fn ensure_cache(&mut self, cache_key: &CacheKey) { + fn ensure_cache>( + &mut self, + mutable_state: &MutableState<'a, T, Q>, + cache_key: &CacheKey, + ) { if self.witgen_functions.contains_key(cache_key) { return; } let f = match T::known_field() { // Currently, we only support the Goldilocks fields - Some(KnownField::GoldilocksField) => self.compile_witgen_function(cache_key), + Some(KnownField::GoldilocksField) => { + self.compile_witgen_function(mutable_state, cache_key) + } _ => None, }; assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none()) } - fn compile_witgen_function(&self, cache_key: &CacheKey) -> Option> { + fn compile_witgen_function>( + &self, + mutable_state: &MutableState<'a, T, Q>, + cache_key: &CacheKey, + ) -> Option> { log::trace!("Compiling JIT function for {:?}", cache_key); self.processor - .generate_code(cache_key.identity_id, &cache_key.known_args) + .generate_code(mutable_state, cache_key.identity_id, &cache_key.known_args) .ok() .map(|code| { log::trace!("Generated code ({} steps)", code.len()); diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index f52d18a293..ffb2e98f0e 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -1,18 +1,21 @@ use std::collections::{HashMap, HashSet}; +use bit_vec::BitVec; use itertools::Itertools; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, AlgebraicReference, AlgebraicUnaryOperation, AlgebraicUnaryOperator, Identity, LookupIdentity, - PermutationIdentity, PhantomLookupIdentity, PhantomPermutationIdentity, PolynomialIdentity, - PolynomialType, SelectedExpressions, + PermutationIdentity, PhantomBusInteractionIdentity, PhantomLookupIdentity, + PhantomPermutationIdentity, PolynomialIdentity, PolynomialType, }; use powdr_number::FieldElement; -use crate::witgen::global_constraints::RangeConstraintSet; +use crate::witgen::{ + data_structures::mutable_state::MutableState, global_constraints::RangeConstraintSet, + range_constraints::RangeConstraint, FixedData, QueryCallback, +}; use super::{ - super::{range_constraints::RangeConstraint, FixedData}, affine_symbolic_expression::{AffineSymbolicExpression, ProcessResult}, effect::{Effect, MachineCallArgument}, variable::{MachineCallReturnVariable, Variable}, @@ -63,28 +66,32 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F } /// Process an identity on a certain row. - pub fn process_identity(&mut self, id: &'a Identity, row_offset: i32) -> ProcessSummary { + pub fn process_identity>( + &mut self, + can_process: CanProcess, + id: &'a Identity, + row_offset: i32, + ) -> ProcessSummary { let result = match id { Identity::Polynomial(PolynomialIdentity { expression, .. }) => { self.process_equality_on_row(expression, row_offset, T::from(0).into()) } - Identity::Lookup(LookupIdentity { - id, left, right, .. - }) - | Identity::Permutation(PermutationIdentity { - id, left, right, .. - }) - | Identity::PhantomPermutation(PhantomPermutationIdentity { - id, left, right, .. - }) - | Identity::PhantomLookup(PhantomLookupIdentity { - id, left, right, .. - }) => self.process_lookup(*id, left, right, row_offset), - Identity::PhantomBusInteraction(_) => { - // TODO(bus_interaction) Once we have a concept of "can_be_answered", bus interactions - // should be as easy as lookups. - ProcessResult::empty() - } + Identity::Lookup(LookupIdentity { id, left, .. }) + | Identity::Permutation(PermutationIdentity { id, left, .. }) + | Identity::PhantomPermutation(PhantomPermutationIdentity { id, left, .. }) + | Identity::PhantomLookup(PhantomLookupIdentity { id, left, .. }) => self.process_call( + can_process, + *id, + &left.selector, + &left.expressions, + row_offset, + ), + Identity::PhantomBusInteraction(PhantomBusInteractionIdentity { + id, + multiplicity, + tuple, + .. + }) => self.process_call(can_process, *id, multiplicity, &tuple.0, row_offset), Identity::Connect(_) => ProcessResult::empty(), }; self.ingest_effects(result) @@ -136,74 +143,69 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F } } - fn process_lookup( + fn process_call>( &mut self, - identity_id: u64, - left: &'a SelectedExpressions, - right: &'a SelectedExpressions, + can_process_call: CanProcess, + lookup_id: u64, + selector: &Expression, + arguments: &'a [Expression], row_offset: i32, ) -> ProcessResult { - // TODO: In the future, call the 'mutable state' to check if the - // lookup can always be answered. - - // If the RHS is fully fixed columns... - if right.expressions.iter().all(|e| match e { - Expression::Reference(r) => r.is_fixed(), - Expression::Number(_) => true, - _ => false, - }) { - // and the selector is known to be 1... - if self - .evaluate(&left.selector, row_offset) - .and_then(|s| s.try_to_known().map(|k| k.is_known_one())) - == Some(true) - { - if let Some(lhs) = left - .expressions - .iter() - .map(|e| self.evaluate(e, row_offset)) - .collect::>>() - { - // and all except one expression is known on the LHS. - let unknown = lhs - .iter() - .filter(|e| e.try_to_known().is_none()) - .collect_vec(); - if unknown.len() == 1 && unknown[0].single_unknown_variable().is_some() { - let args = left - .expressions - .iter() - .enumerate() - .map(|(index, e)| { - if let Some(known_value) = self - .evaluate(e, row_offset) - .and_then(|e| e.try_to_known().cloned()) - { - MachineCallArgument::Known(known_value.clone()) - } else { - let ret_var = MachineCallReturnVariable { - identity_id, - row_offset, - index, - }; - self.assign_variable( - e, - row_offset, - Variable::MachineCallReturnValue(ret_var.clone()), - ); - ret_var.into_argument() - } - }) - .collect(); - return ProcessResult::complete(vec![Effect::MachineCall( - identity_id, - args, - )]); - } + // We need to know the selector. + if self + .evaluate(selector, row_offset) + .and_then(|s| s.try_to_known().map(|k| k.is_known_one())) + != Some(true) + { + return ProcessResult::empty(); + } + let evaluated = arguments + .iter() + .map(|a| { + self.evaluate(a, row_offset) + .and_then(|a| a.try_to_known().cloned()) + }) + .collect::>(); + let range_constraints = evaluated + .iter() + .map(|e| e.as_ref().and_then(|e| e.range_constraint())) + .collect_vec(); + let known = evaluated.iter().map(|e| e.is_some()).collect(); + + if !can_process_call.can_process_call_fully(lookup_id, &known, &range_constraints) { + log::trace!( + "Sub-machine cannot process call fully (will retry later): {lookup_id}, arguments: {}", + arguments.iter().zip(known).map(|(arg, known)| { + format!("{arg} [{}]", if known { "known" } else { "unknown" }) + }).format(", ")); + return ProcessResult::empty(); + } + let args = evaluated + .into_iter() + .zip(arguments) + .enumerate() + .map(|(index, (eval_expr, arg))| { + if let Some(e) = eval_expr { + MachineCallArgument::Known(e) + } else { + let ret_var = MachineCallReturnVariable { + identity_id: lookup_id, + row_offset, + index, + }; + self.assign_variable( + arg, + row_offset, + Variable::MachineCallReturnValue(ret_var.clone()), + ); + ret_var.into_argument() } - } + }) + .collect_vec(); + ProcessResult { + effects: vec![Effect::MachineCall(lookup_id, args)], + complete: true, } - ProcessResult::empty() } fn process_assignments(&mut self) { @@ -445,10 +447,34 @@ pub trait FixedEvaluator { } } +pub trait CanProcessCall { + /// Returns true if a call to the machine that handles the given identity + /// can always be processed with the given known inputs and range constraints + /// on the parameters. + /// @see Machine::can_process_call + fn can_process_call_fully( + &self, + _identity_id: u64, + _known_inputs: &BitVec, + _range_constraints: &[Option>], + ) -> bool; +} + +impl<'a, T: FieldElement, Q: QueryCallback> CanProcessCall for &MutableState<'a, T, Q> { + fn can_process_call_fully( + &self, + identity_id: u64, + known_inputs: &BitVec, + range_constraints: &[Option>], + ) -> bool { + MutableState::can_process_call_fully(self, identity_id, known_inputs, range_constraints) + } +} + #[cfg(test)] mod test { - use pretty_assertions::assert_eq; + use test_log::test; use powdr_ast::analyzed::Analyzed; use powdr_number::GoldilocksField; @@ -458,6 +484,7 @@ mod test { witgen::{ global_constraints, jit::{test_util::format_code, variable::Cell}, + machines::{Connection, FixedLookup, KnownMachine}, FixedData, }, }; @@ -487,6 +514,20 @@ mod test { let (fixed_data, retained_identities) = global_constraints::set_global_constraints(fixed_data, &analyzed.identities); + let fixed_lookup_connections = retained_identities + .iter() + .filter_map(|i| Connection::try_from(*i).ok()) + .filter(|c| FixedLookup::is_responsible(c)) + .map(|c| (c.id, c)) + .collect(); + + let global_constr = fixed_data.global_range_constraints.clone(); + let fixed_machine = FixedLookup::new(global_constr, &fixed_data, fixed_lookup_connections); + let known_fixed = KnownMachine::FixedLookup(fixed_machine); + let mutable_state = MutableState::new([known_fixed].into_iter(), &|_| { + Err("Query not implemented".to_string()) + }); + let known_cells = known_cells.iter().map(|(name, row_offset)| { let id = fixed_data.try_column_by_name(name).unwrap().id; Variable::Cell(Cell { @@ -506,7 +547,7 @@ mod test { for row in rows { for id in retained_identities.iter() { if !complete.contains(&(id.id(), *row)) - && witgen.process_identity(id, *row).complete + && witgen.process_identity(&mutable_state, id, *row).complete { complete.insert((id.id(), *row)); } diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index fafe29624e..561e53a8a7 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -399,7 +399,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { let known_inputs = outer_query.left.iter().map(|e| e.is_constant()).collect(); if self .function_cache - .compile_cached(identity_id, &known_inputs) + .compile_cached(mutable_state, identity_id, &known_inputs) .is_some() { let updates = self.process_lookup_via_jit(mutable_state, identity_id, outer_query)?;