From 9447824b3963afe994c4e059cdc837eedf86f5a4 Mon Sep 17 00:00:00 2001 From: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Date: Wed, 31 Jul 2024 10:01:27 -0500 Subject: [PATCH] A couple minor bug fixes (#144) * fix: write collector must visit calls to extracted body functions * fix: SimplificationPass must record unknowns * additional debug message --- .../src/bucket_interpreter/env/mod.rs | 32 ++++++++-- circuit_passes/src/bucket_interpreter/mod.rs | 13 +--- .../src/bucket_interpreter/write_collector.rs | 48 ++++++++++----- .../src/passes/loop_unroll/observer.rs | 7 ++- circuit_passes/src/passes/mod.rs | 2 +- circuit_passes/src/passes/simplification.rs | 61 +++++++++++++------ 6 files changed, 113 insertions(+), 50 deletions(-) diff --git a/circuit_passes/src/bucket_interpreter/env/mod.rs b/circuit_passes/src/bucket_interpreter/env/mod.rs index d9b82a7ec..b070eb57f 100644 --- a/circuit_passes/src/bucket_interpreter/env/mod.rs +++ b/circuit_passes/src/bucket_interpreter/env/mod.rs @@ -8,7 +8,8 @@ use compiler::intermediate_representation::BucketId; use function_env::FunctionEnvData; use indexmap::IndexSet; use crate::passes::loop_unroll::body_extractor::LoopBodyExtractor; -use crate::passes::loop_unroll::{ToOriginalLocation, FuncArgIdx}; +use crate::passes::loop_unroll::{FuncArgIdx, ToOriginalLocation, LOOP_BODY_FN_PREFIX}; +use crate::passes::GlobalPassData; use self::extracted_func_env::ExtractedFuncEnvData; use self::template_env::TemplateEnvData; use self::unrolled_block_env::UnrolledBlockEnvData; @@ -22,6 +23,7 @@ mod function_env; mod unrolled_block_env; mod extracted_func_env; +const DEBUG_INTERPRETER: bool = false; const PRINT_ENV_SORTED: bool = true; #[inline] @@ -222,6 +224,10 @@ impl<'a> Env<'a> { Env::Template(TemplateEnvData::new(libs)) } + pub fn new_unroll_block_env(base: Env<'a>, extractor: &'a LoopBodyExtractor) -> Self { + Env::UnrolledBlock(UnrolledBlockEnvData::new(base, extractor)) + } + pub fn new_source_func_env( base: Env<'a>, caller: &BucketId, @@ -231,7 +237,7 @@ impl<'a> Env<'a> { Env::Function(FunctionEnvData::new(base, caller, call_stack, libs)) } - pub fn new_extracted_func_env( + fn _new_extracted_func_env( base: Env<'a>, caller: &BucketId, remap: ToOriginalLocation, @@ -240,10 +246,28 @@ impl<'a> Env<'a> { Env::ExtractedFunction(ExtractedFuncEnvData::new(base, caller, remap, arenas)) } - pub fn new_unroll_block_env(base: Env<'a>, extractor: &'a LoopBodyExtractor) -> Self { - Env::UnrolledBlock(UnrolledBlockEnvData::new(base, extractor)) + pub fn new_extracted_func_env( + base: Env<'a>, + caller: &BucketId, + callee_name: &str, + gdat: Ref, + ) -> Self { + if callee_name.starts_with(LOOP_BODY_FN_PREFIX) { + if DEBUG_INTERPRETER { + println!("\ncurrent env = {}", base); + println!("callee_name = {}", callee_name); + println!("base.get_vars_sort() = {:?}", base.get_vars_sort()); + println!("callee function data = {:?}", gdat.get_data_for_func(callee_name)); + } + let fdat = &gdat.get_data_for_func(callee_name)[&base.get_vars_sort()]; + Self::_new_extracted_func_env(base, caller, fdat.0.clone(), fdat.1.clone()) + } else { + Self::_new_extracted_func_env(base, caller, Default::default(), Default::default()) + } } +} +impl Env<'_> { // READ OPERATIONS pub fn peel_extracted_func(self) -> Self { match self { diff --git a/circuit_passes/src/bucket_interpreter/mod.rs b/circuit_passes/src/bucket_interpreter/mod.rs index 5cd9b2888..904273336 100644 --- a/circuit_passes/src/bucket_interpreter/mod.rs +++ b/circuit_passes/src/bucket_interpreter/mod.rs @@ -24,7 +24,6 @@ use compiler::num_bigint::BigInt; use observer::Observer; use program_structure::error_code::ReportCode; use crate::passes::builders::{build_compute, build_u32_value}; -use crate::passes::loop_unroll::LOOP_BODY_FN_PREFIX; use crate::passes::GlobalPassData; use self::env::{CallStackFrame, Env, LibraryAccess}; use self::error::BadInterp; @@ -745,15 +744,8 @@ impl BucketInterpreter<'_, '_> { // calls below (that give ownership of the 'env' object into the new Env instance) // to avoid copying the entire 'env' instance (which is likely more expensive). let instructions = env.get_function(name).body.clone(); - let mut res = (vec![], { - if name.starts_with(LOOP_BODY_FN_PREFIX) { - let gdat = self.global_data.borrow(); - let fdat = &gdat.get_data_for_func(name)[&env.get_vars_sort()]; - Env::new_extracted_func_env(env, &bucket.id, fdat.0.clone(), fdat.1.clone()) - } else { - Env::new_extracted_func_env(env, &bucket.id, Default::default(), Default::default()) - } - }); + let mut res = + (vec![], Env::new_extracted_func_env(env, &bucket.id, name, self.global_data.borrow())); //NOTE: Do not change scope for the new interpreter because the mem lookups // within 'write_collector.rs' need to use the original function context. let interp = self.mem.build_interpreter_with_flags( @@ -820,7 +812,6 @@ impl BucketInterpreter<'_, '_> { // unless self.flags.allow_nondetermined_return() == false because // that case could result in no return statements being observed. let func_val = if body_val.is_empty() && !self.flags.allow_nondetermined_return() { - // Some(Value::Unknown) // TODO: return the correct number of Unknowns Result::Ok(vec![Value::Unknown; callee.returns.iter().product::()]) } else { let vals = into_result(body_val, "value returned from function"); diff --git a/circuit_passes/src/bucket_interpreter/write_collector.rs b/circuit_passes/src/bucket_interpreter/write_collector.rs index 8d7650a63..9ad3e25b3 100644 --- a/circuit_passes/src/bucket_interpreter/write_collector.rs +++ b/circuit_passes/src/bucket_interpreter/write_collector.rs @@ -1,21 +1,27 @@ use std::collections::HashSet; +use code_producers::llvm_elements::stdlib::GENERATED_FN_PREFIX; use compiler::intermediate_representation::{ ir_interface::{AddressType, FinalData, LocationRule, LogBucketArg, ReturnType, StoreBucket}, Instruction, InstructionList, InstructionPointer, }; -use super::{env::Env, error::BadInterp, value::Value, BucketInterpreter, InterpRes}; +use super::{ + env::{Env, LibraryAccess}, + error::BadInterp, + value::Value, + BucketInterpreter, InterpRes, +}; pub(crate) fn set_writes_to_unknown<'e>( interp: &BucketInterpreter, body: &InstructionList, env: Env<'e>, ) -> Result, BadInterp> { - let mut checker = Writes::default(); - Result::from(checker.collect_writes(interp, body, env.clone())).map_or_else( + let mut collector = Writes::default(); + Result::from(collector.check_body(interp, body, env.clone())).map_or_else( |b| Result::Err(b), // For the Ok case, ignore the Env computed within the body // and just set Unknown to all writes that were found. - |_| checker.set_unknowns(env), + |_| collector.set_unknowns(env), ) } @@ -43,7 +49,7 @@ impl Writes { .set_subcmps_to_unknown(self.subcmps) } - fn collect_writes<'e>( + fn check_body<'e>( &mut self, interp: &BucketInterpreter, body: &InstructionList, @@ -65,13 +71,27 @@ impl Writes { match inst { Instruction::Store(b) => self.check_store_bucket(interp, b, env), Instruction::Constraint(b) => self.check_inst(interp, b.unwrap(), env), - Instruction::Block(b) => self.collect_writes(interp, &b.body, env), + Instruction::Block(b) => self.check_body(interp, &b.body, env), Instruction::Branch(b) => { self.check_branch(interp, &b.cond, &b.if_branch, &b.else_branch, env) } Instruction::Loop(b) => { self.check_branch(interp, &b.continue_condition, &b.body, &vec![], env) } + Instruction::Call(b) if b.symbol.starts_with(GENERATED_FN_PREFIX) => { + let callee_name = &b.symbol; + let callee_body = env.get_function(callee_name).body.clone(); + self.check_body( + interp, + &callee_body, + Env::new_extracted_func_env( + env, + &b.id, + callee_name, + interp.global_data.borrow(), + ), + ) + } i => { debug_assert!(!ContainsStore::contains_store(i)); InterpRes::Continue(env) @@ -93,10 +113,10 @@ impl Writes { // If the condition is unknown, collect all writes from both branches (even if // there is a return in either, hence an InterpRes::Return result is ignored // in both cases) and produce InterpRes::Continue with the original Env. - if let InterpRes::Err(e) = self.collect_writes(interp, true_branch, env.clone()) { + if let InterpRes::Err(e) = self.check_body(interp, true_branch, env.clone()) { return InterpRes::Err(e); } - if let InterpRes::Err(e) = self.collect_writes(interp, false_branch, env.clone()) { + if let InterpRes::Err(e) = self.check_body(interp, false_branch, env.clone()) { return InterpRes::Err(e); } InterpRes::Continue(env) @@ -105,17 +125,17 @@ impl Writes { // If the condition is true, collect all writes from the false branch // (ignoring an InterpRes::Return result as above) and then analyze // and return the result from the true branch. - if let InterpRes::Err(e) = self.collect_writes(interp, false_branch, env.clone()) { + if let InterpRes::Err(e) = self.check_body(interp, false_branch, env.clone()) { return InterpRes::Err(e); } - self.collect_writes(interp, true_branch, env) + self.check_body(interp, true_branch, env) } Ok(Some(false)) => { // Reverse of the true case. - if let InterpRes::Err(e) = self.collect_writes(interp, true_branch, env.clone()) { + if let InterpRes::Err(e) = self.check_body(interp, true_branch, env.clone()) { return InterpRes::Err(e); } - self.collect_writes(interp, false_branch, env) + self.check_body(interp, false_branch, env) } } } @@ -366,7 +386,7 @@ mod tests { ]; let mut checker = Writes::default(); - let collect_res = checker.collect_writes(&interp, &body, env); + let collect_res = checker.check_body(&interp, &body, env); assert!(!matches!(collect_res, InterpRes::Err(_))); // EXPECT: // - variables A, B (only index 0 in the vector), and C are written @@ -434,7 +454,7 @@ mod tests { ]; let mut checker = Writes::default(); - let collect_res = checker.collect_writes(&interp, &body, env); + let collect_res = checker.check_body(&interp, &body, env); assert!(!matches!(collect_res, InterpRes::Err(_))); // EXPECT: // - no variables are written diff --git a/circuit_passes/src/passes/loop_unroll/observer.rs b/circuit_passes/src/passes/loop_unroll/observer.rs index 538f40fab..f00e0277a 100644 --- a/circuit_passes/src/passes/loop_unroll/observer.rs +++ b/circuit_passes/src/passes/loop_unroll/observer.rs @@ -183,7 +183,12 @@ impl LoopUnrollObserver<'_> { } match cond { // If the conditional becomes unknown just give up. - None => return Ok(None), + None => { + if DEBUG_LOOP_UNROLL { + println!("[UNROLL][try_unroll_loop] OUTCOME: not safe to move or unroll, condition unknown"); + } + return Ok(None); + } // When conditional becomes `false`, iteration count is complete. Some(false) => break, // Otherwise, continue counting. diff --git a/circuit_passes/src/passes/mod.rs b/circuit_passes/src/passes/mod.rs index 049f2baa7..523222dc7 100644 --- a/circuit_passes/src/passes/mod.rs +++ b/circuit_passes/src/passes/mod.rs @@ -786,7 +786,7 @@ impl GlobalPassData { pub fn get_data_for_func( &self, - name: &String, + name: &str, ) -> &BTreeMap)> { match self.extract_func_orig_loc.get(name) { Some(x) => x, diff --git a/circuit_passes/src/passes/simplification.rs b/circuit_passes/src/passes/simplification.rs index fd1bd5ab7..bf9fc57c1 100644 --- a/circuit_passes/src/passes/simplification.rs +++ b/circuit_passes/src/passes/simplification.rs @@ -62,20 +62,45 @@ impl<'d> SimplificationPass<'d> { fn insert( map: &RefCell>>, bucket_id: BucketId, - v: V, + new_val: Option, ) { map.borrow_mut() .entry(bucket_id) - // If the entry exists and it's not the same as the new value, set to None. + // If the entry exists and it's not the same as the new value, + // or the new value itself is None, then set to result to None + // to indicate that no replacement should be made. .and_modify(|old| { if let Some(old_val) = old { - if *old_val != v { - *old = None + match &new_val { + None => *old = None, + Some(x) => { + if *x != *old_val { + *old = None + } + } } } }) // If no entry exists, store the new value. - .or_insert(Some(v)); + .or_insert(new_val); + } + + fn store_computed_value( + map: &RefCell>>, + bucket_id: BucketId, + computed_value: Value, + ) -> Result { + if computed_value.is_unknown() { + // When the bucket's value is Unknown from any execution, add None to the map so + // the bucket will not be replaced (even if known at an execution found later), + // return 'true' so buckets nested within this bucket will be observed. + Self::insert(map, bucket_id, None); + Ok(true) + } else { + // Add known value to the map, return 'false' so observation will not continue within. + Self::insert(map, bucket_id, Some(computed_value)); + Ok(false) + } } } @@ -84,12 +109,7 @@ impl Observer> for SimplificationPass<'_> { let interp = self.build_interpreter(); let v = interp.compute_compute_bucket(bucket, env, false)?; let v = result_types::into_single_result(v, "ComputeBucket")?; - if !v.is_unknown() { - Self::insert(&self.compute_replacements, bucket.id, v); - Ok(false) - } else { - Ok(true) - } + Self::store_computed_value(&self.compute_replacements, bucket.id, v) } fn on_call_bucket(&self, bucket: &CallBucket, env: &Env) -> Result { @@ -99,12 +119,10 @@ impl Observer> for SimplificationPass<'_> { // rather than 'into_single_result()' and return 'true' in the None case // so buckets nested within this bucket will be observed. if let Some(v) = result_types::into_single_option(v) { - if !v.is_unknown() { - Self::insert(&self.call_replacements, bucket.id, v); - return Ok(false); - } + Self::store_computed_value(&self.call_replacements, bucket.id, v) + } else { + Ok(true) } - Ok(true) } fn on_constraint_bucket( @@ -113,7 +131,8 @@ impl Observer> for SimplificationPass<'_> { env: &Env, ) -> Result { self.within_constraint.replace(true); - // Match the expected structure of ConstraintBucket instances but don't fail if there's something different. + // Match the expected structure of ConstraintBucket instances + // but don't fail if there's something different. match bucket { ConstraintBucket::Equality(e) => { if let Instruction::Assert(AssertBucket { evaluate, .. }) = e.as_ref() { @@ -130,7 +149,11 @@ impl Observer> for SimplificationPass<'_> { } // If at least one is a known value, then we can (likely) simplify if values.iter().any(Value::is_known) { - Self::insert(&self.constraint_eq_replacements, e.get_id(), values); + Self::insert( + &self.constraint_eq_replacements, + e.get_id(), + Some(values), + ); } } } @@ -172,7 +195,7 @@ impl Observer> for SimplificationPass<'_> { Self::insert( &self.constraint_sub_replacements, e.get_id(), - (src, dest, dest_address_type), + Some((src, dest, dest_address_type)), ); } }