Skip to content

Commit

Permalink
use can_process (#2256)
Browse files Browse the repository at this point in the history
Co-authored-by: Georg Wiese <[email protected]>
  • Loading branch information
chriseth and georgwiese authored Dec 20, 2024
1 parent 550a1e6 commit c1b7763
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 105 deletions.
18 changes: 18 additions & 0 deletions executor/src/witgen/data_structures/mutable_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use std::{
collections::{BTreeMap, HashMap},
};

use bit_vec::BitVec;
use powdr_number::FieldElement;

use crate::witgen::{
machines::{KnownMachine, LookupCell, Machine},
range_constraints::RangeConstraint,
rows::RowPair,
EvalError, EvalResult, QueryCallback,
};
Expand Down Expand Up @@ -49,6 +51,22 @@ impl<'a, T: FieldElement, Q: QueryCallback<T>> MutableState<'a, T, Q> {
self.take_witness_col_values()
}

pub fn can_process_call_fully(
&self,
identity_id: u64,
known_inputs: &BitVec,
range_constraints: &[Option<RangeConstraint<T>>],
) -> bool {
// TODO We are currently ignoring bus interaction (also, but not only because there is no
// unique machine responsible for handling a bus send), so just answer "false" if the identity
// has no responsible machine.
self.responsible_machine(identity_id)
.ok()
.map_or(false, |mut machine| {
machine.can_process_call_fully(identity_id, known_inputs, range_constraints)
})
}

/// Call the machine responsible for the right-hand-side of an identity given its ID
/// and the row pair of the caller.
pub fn call(&self, identity_id: u64, caller_rows: &RowPair<'_, 'a, T>) -> EvalResult<'a, T> {
Expand Down
40 changes: 29 additions & 11 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::witgen::{machines::MachineParts, FixedData};
use super::{
effect::Effect,
variable::Variable,
witgen_inference::{FixedEvaluator, WitgenInference},
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
};

/// A processor for generating JIT code for a block machine.
Expand Down Expand Up @@ -37,8 +37,9 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

/// Generates the JIT code for a given combination of connection and known arguments.
/// Fails if it cannot solve for the outputs, or if any sub-machine calls cannot be completed.
pub fn generate_code(
pub fn generate_code<CanProcess: CanProcessCall<T> + Clone>(
&self,
can_process: CanProcess,
identity_id: u64,
known_args: &BitVec,
) -> Result<Vec<Effect<T, Variable>>, String> {
Expand All @@ -63,7 +64,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

// Solve for the block witness.
// Fails if any machine call cannot be completed.
self.solve_block(&mut witgen)?;
self.solve_block(can_process, &mut witgen)?;

for (index, expr) in connection_rhs.expressions.iter().enumerate() {
if !witgen.is_known(&Variable::Param(index)) {
Expand All @@ -78,7 +79,11 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

/// Repeatedly processes all identities on all rows, until no progress is made.
/// Fails iff there are incomplete machine calls in the latch row.
fn solve_block(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> {
fn solve_block<CanProcess: CanProcessCall<T> + Clone>(
&self,
can_process: CanProcess,
witgen: &mut WitgenInference<'a, T, &Self>,
) -> Result<(), String> {
let mut complete = HashSet::new();
for iteration in 0.. {
let mut progress = false;
Expand All @@ -87,7 +92,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
for row in 0..self.block_size {
for id in &self.machine_parts.identities {
if !complete.contains(&(id.id(), row)) {
let result = witgen.process_identity(id, row as i32);
let result = witgen.process_identity(can_process.clone(), id, row as i32);
if result.complete {
complete.insert((id.id(), row));
}
Expand Down Expand Up @@ -147,10 +152,9 @@ impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
mod test {
use std::{collections::BTreeMap, fs::read_to_string};

use bit_vec::BitVec;
use powdr_ast::analyzed::{
AlgebraicExpression, AlgebraicReference, Analyzed, SelectedExpressions,
};
use test_log::test;

use powdr_ast::analyzed::{AlgebraicExpression, Analyzed, SelectedExpressions};
use powdr_number::GoldilocksField;

use crate::{
Expand All @@ -159,11 +163,25 @@ mod test {
global_constraints,
jit::{effect::Effect, test_util::format_code},
machines::{Connection, ConnectionKind, MachineParts},
range_constraints::RangeConstraint,
FixedData,
},
};

use super::{BlockMachineProcessor, Variable};
use super::*;

#[derive(Clone)]
struct CannotProcessSubcalls;
impl<T: FieldElement> CanProcessCall<T> for CannotProcessSubcalls {
fn can_process_call_fully(
&self,
_identity_id: u64,
_known_inputs: &BitVec,
_range_constraints: &[Option<RangeConstraint<T>>],
) -> bool {
false
}
}

fn generate_for_block_machine(
input_pil: &str,
Expand Down Expand Up @@ -240,7 +258,7 @@ mod test {
.chain(output_names.iter().map(|_| false)),
);

processor.generate_code(0, &known_values)
processor.generate_code(CannotProcessSubcalls, 0, &known_values)
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ fn util_code<T: FieldElement>(first_column_id: u64, column_count: usize) -> Resu
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use test_log::test;

use powdr_number::GoldilocksField;

Expand Down
23 changes: 17 additions & 6 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,47 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {

/// Compiles the JIT function for the given identity and known arguments.
/// Returns true if the function was successfully compiled.
pub fn compile_cached(
pub fn compile_cached<Q: QueryCallback<T>>(
&mut self,
mutable_state: &MutableState<'a, T, Q>,
identity_id: u64,
known_args: &BitVec,
) -> &Option<WitgenFunction<T>> {
let cache_key = CacheKey {
identity_id,
known_args: known_args.clone(),
};
self.ensure_cache(&cache_key);
self.ensure_cache(mutable_state, &cache_key);
self.witgen_functions.get(&cache_key).unwrap()
}

fn ensure_cache(&mut self, cache_key: &CacheKey) {
fn ensure_cache<Q: QueryCallback<T>>(
&mut self,
mutable_state: &MutableState<'a, T, Q>,
cache_key: &CacheKey,
) {
if self.witgen_functions.contains_key(cache_key) {
return;
}

let f = match T::known_field() {
// Currently, we only support the Goldilocks fields
Some(KnownField::GoldilocksField) => self.compile_witgen_function(cache_key),
Some(KnownField::GoldilocksField) => {
self.compile_witgen_function(mutable_state, cache_key)
}
_ => None,
};
assert!(self.witgen_functions.insert(cache_key.clone(), f).is_none())
}

fn compile_witgen_function(&self, cache_key: &CacheKey) -> Option<WitgenFunction<T>> {
fn compile_witgen_function<Q: QueryCallback<T>>(
&self,
mutable_state: &MutableState<'a, T, Q>,
cache_key: &CacheKey,
) -> Option<WitgenFunction<T>> {
log::trace!("Compiling JIT function for {:?}", cache_key);
self.processor
.generate_code(cache_key.identity_id, &cache_key.known_args)
.generate_code(mutable_state, cache_key.identity_id, &cache_key.known_args)
.ok()
.map(|code| {
log::trace!("Generated code ({} steps)", code.len());
Expand Down
Loading

0 comments on commit c1b7763

Please sign in to comment.