Skip to content

Commit

Permalink
A couple minor bug fixes (#144)
Browse files Browse the repository at this point in the history
* fix: write collector must visit calls to extracted body functions
* fix: SimplificationPass must record unknowns
* additional debug message
  • Loading branch information
tim-hoffman authored Jul 31, 2024
1 parent cdac086 commit 9447824
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 50 deletions.
32 changes: 28 additions & 4 deletions circuit_passes/src/bucket_interpreter/env/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use compiler::intermediate_representation::BucketId;
use function_env::FunctionEnvData;
use indexmap::IndexSet;
use crate::passes::loop_unroll::body_extractor::LoopBodyExtractor;
use crate::passes::loop_unroll::{ToOriginalLocation, FuncArgIdx};
use crate::passes::loop_unroll::{FuncArgIdx, ToOriginalLocation, LOOP_BODY_FN_PREFIX};
use crate::passes::GlobalPassData;
use self::extracted_func_env::ExtractedFuncEnvData;
use self::template_env::TemplateEnvData;
use self::unrolled_block_env::UnrolledBlockEnvData;
Expand All @@ -22,6 +23,7 @@ mod function_env;
mod unrolled_block_env;
mod extracted_func_env;

const DEBUG_INTERPRETER: bool = false;
const PRINT_ENV_SORTED: bool = true;

#[inline]
Expand Down Expand Up @@ -222,6 +224,10 @@ impl<'a> Env<'a> {
Env::Template(TemplateEnvData::new(libs))
}

pub fn new_unroll_block_env(base: Env<'a>, extractor: &'a LoopBodyExtractor) -> Self {
Env::UnrolledBlock(UnrolledBlockEnvData::new(base, extractor))
}

pub fn new_source_func_env(
base: Env<'a>,
caller: &BucketId,
Expand All @@ -231,7 +237,7 @@ impl<'a> Env<'a> {
Env::Function(FunctionEnvData::new(base, caller, call_stack, libs))
}

pub fn new_extracted_func_env(
fn _new_extracted_func_env(
base: Env<'a>,
caller: &BucketId,
remap: ToOriginalLocation,
Expand All @@ -240,10 +246,28 @@ impl<'a> Env<'a> {
Env::ExtractedFunction(ExtractedFuncEnvData::new(base, caller, remap, arenas))
}

pub fn new_unroll_block_env(base: Env<'a>, extractor: &'a LoopBodyExtractor) -> Self {
Env::UnrolledBlock(UnrolledBlockEnvData::new(base, extractor))
pub fn new_extracted_func_env(
base: Env<'a>,
caller: &BucketId,
callee_name: &str,
gdat: Ref<GlobalPassData>,
) -> Self {
if callee_name.starts_with(LOOP_BODY_FN_PREFIX) {
if DEBUG_INTERPRETER {
println!("\ncurrent env = {}", base);
println!("callee_name = {}", callee_name);
println!("base.get_vars_sort() = {:?}", base.get_vars_sort());
println!("callee function data = {:?}", gdat.get_data_for_func(callee_name));
}
let fdat = &gdat.get_data_for_func(callee_name)[&base.get_vars_sort()];
Self::_new_extracted_func_env(base, caller, fdat.0.clone(), fdat.1.clone())
} else {
Self::_new_extracted_func_env(base, caller, Default::default(), Default::default())
}
}
}

impl Env<'_> {
// READ OPERATIONS
pub fn peel_extracted_func(self) -> Self {
match self {
Expand Down
13 changes: 2 additions & 11 deletions circuit_passes/src/bucket_interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use compiler::num_bigint::BigInt;
use observer::Observer;
use program_structure::error_code::ReportCode;
use crate::passes::builders::{build_compute, build_u32_value};
use crate::passes::loop_unroll::LOOP_BODY_FN_PREFIX;
use crate::passes::GlobalPassData;
use self::env::{CallStackFrame, Env, LibraryAccess};
use self::error::BadInterp;
Expand Down Expand Up @@ -745,15 +744,8 @@ impl BucketInterpreter<'_, '_> {
// calls below (that give ownership of the 'env' object into the new Env instance)
// to avoid copying the entire 'env' instance (which is likely more expensive).
let instructions = env.get_function(name).body.clone();
let mut res = (vec![], {
if name.starts_with(LOOP_BODY_FN_PREFIX) {
let gdat = self.global_data.borrow();
let fdat = &gdat.get_data_for_func(name)[&env.get_vars_sort()];
Env::new_extracted_func_env(env, &bucket.id, fdat.0.clone(), fdat.1.clone())
} else {
Env::new_extracted_func_env(env, &bucket.id, Default::default(), Default::default())
}
});
let mut res =
(vec![], Env::new_extracted_func_env(env, &bucket.id, name, self.global_data.borrow()));
//NOTE: Do not change scope for the new interpreter because the mem lookups
// within 'write_collector.rs' need to use the original function context.
let interp = self.mem.build_interpreter_with_flags(
Expand Down Expand Up @@ -820,7 +812,6 @@ impl BucketInterpreter<'_, '_> {
// unless self.flags.allow_nondetermined_return() == false because
// that case could result in no return statements being observed.
let func_val = if body_val.is_empty() && !self.flags.allow_nondetermined_return() {
// Some(Value::Unknown) // TODO: return the correct number of Unknowns
Result::Ok(vec![Value::Unknown; callee.returns.iter().product::<usize>()])
} else {
let vals = into_result(body_val, "value returned from function");
Expand Down
48 changes: 34 additions & 14 deletions circuit_passes/src/bucket_interpreter/write_collector.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
use std::collections::HashSet;
use code_producers::llvm_elements::stdlib::GENERATED_FN_PREFIX;
use compiler::intermediate_representation::{
ir_interface::{AddressType, FinalData, LocationRule, LogBucketArg, ReturnType, StoreBucket},
Instruction, InstructionList, InstructionPointer,
};
use super::{env::Env, error::BadInterp, value::Value, BucketInterpreter, InterpRes};
use super::{
env::{Env, LibraryAccess},
error::BadInterp,
value::Value,
BucketInterpreter, InterpRes,
};

pub(crate) fn set_writes_to_unknown<'e>(
interp: &BucketInterpreter,
body: &InstructionList,
env: Env<'e>,
) -> Result<Env<'e>, BadInterp> {
let mut checker = Writes::default();
Result::from(checker.collect_writes(interp, body, env.clone())).map_or_else(
let mut collector = Writes::default();
Result::from(collector.check_body(interp, body, env.clone())).map_or_else(
|b| Result::Err(b),
// For the Ok case, ignore the Env computed within the body
// and just set Unknown to all writes that were found.
|_| checker.set_unknowns(env),
|_| collector.set_unknowns(env),
)
}

Expand Down Expand Up @@ -43,7 +49,7 @@ impl Writes {
.set_subcmps_to_unknown(self.subcmps)
}

fn collect_writes<'e>(
fn check_body<'e>(
&mut self,
interp: &BucketInterpreter,
body: &InstructionList,
Expand All @@ -65,13 +71,27 @@ impl Writes {
match inst {
Instruction::Store(b) => self.check_store_bucket(interp, b, env),
Instruction::Constraint(b) => self.check_inst(interp, b.unwrap(), env),
Instruction::Block(b) => self.collect_writes(interp, &b.body, env),
Instruction::Block(b) => self.check_body(interp, &b.body, env),
Instruction::Branch(b) => {
self.check_branch(interp, &b.cond, &b.if_branch, &b.else_branch, env)
}
Instruction::Loop(b) => {
self.check_branch(interp, &b.continue_condition, &b.body, &vec![], env)
}
Instruction::Call(b) if b.symbol.starts_with(GENERATED_FN_PREFIX) => {
let callee_name = &b.symbol;
let callee_body = env.get_function(callee_name).body.clone();
self.check_body(
interp,
&callee_body,
Env::new_extracted_func_env(
env,
&b.id,
callee_name,
interp.global_data.borrow(),
),
)
}
i => {
debug_assert!(!ContainsStore::contains_store(i));
InterpRes::Continue(env)
Expand All @@ -93,10 +113,10 @@ impl Writes {
// If the condition is unknown, collect all writes from both branches (even if
// there is a return in either, hence an InterpRes::Return result is ignored
// in both cases) and produce InterpRes::Continue with the original Env.
if let InterpRes::Err(e) = self.collect_writes(interp, true_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, true_branch, env.clone()) {
return InterpRes::Err(e);
}
if let InterpRes::Err(e) = self.collect_writes(interp, false_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, false_branch, env.clone()) {
return InterpRes::Err(e);
}
InterpRes::Continue(env)
Expand All @@ -105,17 +125,17 @@ impl Writes {
// If the condition is true, collect all writes from the false branch
// (ignoring an InterpRes::Return result as above) and then analyze
// and return the result from the true branch.
if let InterpRes::Err(e) = self.collect_writes(interp, false_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, false_branch, env.clone()) {
return InterpRes::Err(e);
}
self.collect_writes(interp, true_branch, env)
self.check_body(interp, true_branch, env)
}
Ok(Some(false)) => {
// Reverse of the true case.
if let InterpRes::Err(e) = self.collect_writes(interp, true_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, true_branch, env.clone()) {
return InterpRes::Err(e);
}
self.collect_writes(interp, false_branch, env)
self.check_body(interp, false_branch, env)
}
}
}
Expand Down Expand Up @@ -366,7 +386,7 @@ mod tests {
];

let mut checker = Writes::default();
let collect_res = checker.collect_writes(&interp, &body, env);
let collect_res = checker.check_body(&interp, &body, env);
assert!(!matches!(collect_res, InterpRes::Err(_)));
// EXPECT:
// - variables A, B (only index 0 in the vector), and C are written
Expand Down Expand Up @@ -434,7 +454,7 @@ mod tests {
];

let mut checker = Writes::default();
let collect_res = checker.collect_writes(&interp, &body, env);
let collect_res = checker.check_body(&interp, &body, env);
assert!(!matches!(collect_res, InterpRes::Err(_)));
// EXPECT:
// - no variables are written
Expand Down
7 changes: 6 additions & 1 deletion circuit_passes/src/passes/loop_unroll/observer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ impl LoopUnrollObserver<'_> {
}
match cond {
// If the conditional becomes unknown just give up.
None => return Ok(None),
None => {
if DEBUG_LOOP_UNROLL {
println!("[UNROLL][try_unroll_loop] OUTCOME: not safe to move or unroll, condition unknown");
}
return Ok(None);
}
// When conditional becomes `false`, iteration count is complete.
Some(false) => break,
// Otherwise, continue counting.
Expand Down
2 changes: 1 addition & 1 deletion circuit_passes/src/passes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ impl GlobalPassData {

pub fn get_data_for_func(
&self,
name: &String,
name: &str,
) -> &BTreeMap<UnrolledIterLvars, (ToOriginalLocation, HashSet<FuncArgIdx>)> {
match self.extract_func_orig_loc.get(name) {
Some(x) => x,
Expand Down
61 changes: 42 additions & 19 deletions circuit_passes/src/passes/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,45 @@ impl<'d> SimplificationPass<'d> {
fn insert<V: PartialEq>(
map: &RefCell<HashMap<BucketId, Option<V>>>,
bucket_id: BucketId,
v: V,
new_val: Option<V>,
) {
map.borrow_mut()
.entry(bucket_id)
// If the entry exists and it's not the same as the new value, set to None.
// If the entry exists and it's not the same as the new value,
// or the new value itself is None, then set to result to None
// to indicate that no replacement should be made.
.and_modify(|old| {
if let Some(old_val) = old {
if *old_val != v {
*old = None
match &new_val {
None => *old = None,
Some(x) => {
if *x != *old_val {
*old = None
}
}
}
}
})
// If no entry exists, store the new value.
.or_insert(Some(v));
.or_insert(new_val);
}

fn store_computed_value(
map: &RefCell<HashMap<BucketId, Option<Value>>>,
bucket_id: BucketId,
computed_value: Value,
) -> Result<bool, BadInterp> {
if computed_value.is_unknown() {
// When the bucket's value is Unknown from any execution, add None to the map so
// the bucket will not be replaced (even if known at an execution found later),
// return 'true' so buckets nested within this bucket will be observed.
Self::insert(map, bucket_id, None);
Ok(true)
} else {
// Add known value to the map, return 'false' so observation will not continue within.
Self::insert(map, bucket_id, Some(computed_value));
Ok(false)
}
}
}

Expand All @@ -84,12 +109,7 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
let interp = self.build_interpreter();
let v = interp.compute_compute_bucket(bucket, env, false)?;
let v = result_types::into_single_result(v, "ComputeBucket")?;
if !v.is_unknown() {
Self::insert(&self.compute_replacements, bucket.id, v);
Ok(false)
} else {
Ok(true)
}
Self::store_computed_value(&self.compute_replacements, bucket.id, v)
}

fn on_call_bucket(&self, bucket: &CallBucket, env: &Env) -> Result<bool, BadInterp> {
Expand All @@ -99,12 +119,10 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
// rather than 'into_single_result()' and return 'true' in the None case
// so buckets nested within this bucket will be observed.
if let Some(v) = result_types::into_single_option(v) {
if !v.is_unknown() {
Self::insert(&self.call_replacements, bucket.id, v);
return Ok(false);
}
Self::store_computed_value(&self.call_replacements, bucket.id, v)
} else {
Ok(true)
}
Ok(true)
}

fn on_constraint_bucket(
Expand All @@ -113,7 +131,8 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
env: &Env,
) -> Result<bool, BadInterp> {
self.within_constraint.replace(true);
// Match the expected structure of ConstraintBucket instances but don't fail if there's something different.
// Match the expected structure of ConstraintBucket instances
// but don't fail if there's something different.
match bucket {
ConstraintBucket::Equality(e) => {
if let Instruction::Assert(AssertBucket { evaluate, .. }) = e.as_ref() {
Expand All @@ -130,7 +149,11 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
}
// If at least one is a known value, then we can (likely) simplify
if values.iter().any(Value::is_known) {
Self::insert(&self.constraint_eq_replacements, e.get_id(), values);
Self::insert(
&self.constraint_eq_replacements,
e.get_id(),
Some(values),
);
}
}
}
Expand Down Expand Up @@ -172,7 +195,7 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
Self::insert(
&self.constraint_sub_replacements,
e.get_id(),
(src, dest, dest_address_type),
Some((src, dest, dest_address_type)),
);
}
}
Expand Down

0 comments on commit 9447824

Please sign in to comment.