fix(ssa): Track all local allocations during flattening (#6619)
Co-authored-by: Tom French <[email protected]>
vezenovm and TomAFrench authored Nov 26, 2024
1 parent 10a9f81 commit 6491175
Showing 2 changed files with 71 additions and 96 deletions.
47 changes: 30 additions & 17 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
@@ -131,7 +131,7 @@
//! v11 = mul v4, Field 12
//! v12 = add v10, v11
//! store v12 at v5 (new store)
use fxhash::FxHashMap as HashMap;
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};

use acvm::{acir::AcirField, acir::BlackBoxFunc, FieldElement};
use iter_extended::vecmap;
@@ -201,6 +201,15 @@ struct Context<'f> {
/// When processing a block, we pop this stack to get its arguments
/// and at the end we push the arguments for his successor
arguments_stack: Vec<Vec<ValueId>>,

/// Stores all allocations local to the current branch.
///
/// Since these allocations are local to the current branch (i.e. only defined within one branch of
/// an if expression), they should not be merged with their previous value or stored value in
/// the other branch since there is no such value.
///
/// The `ValueId` here is that which is returned by the allocate instruction.
local_allocations: HashSet<ValueId>,
}
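
For intuition, a minimal sketch of the rule this field enables (illustrative Rust, not the compiler's API; plain integers stand in for `ValueId`s): a store needs a conditional merge only when its target address predates the current branch.

    use std::collections::HashSet;

    // Illustrative model: an address allocated inside the current branch has no
    // value in the other branch, so a store to it must pass through unmerged.
    fn store_needs_merge(address: u32, local_allocations: &HashSet<u32>) -> bool {
        !local_allocations.contains(&address)
    }

This mirrors the `local_allocations.contains(&address)` check added to the `Instruction::Store` handling further down in this diff.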

#[derive(Clone)]
@@ -211,6 +220,8 @@ struct ConditionalBranch {
old_condition: ValueId,
// The condition of the branch
condition: ValueId,
// The allocations accumulated when processing the branch
local_allocations: HashSet<ValueId>,
}

struct ConditionalContext {
Expand Down Expand Up @@ -243,6 +254,7 @@ fn flatten_function_cfg(function: &mut Function, no_predicates: &HashMap<Functio
slice_sizes: HashMap::default(),
condition_stack: Vec::new(),
arguments_stack: Vec::new(),
local_allocations: HashSet::default(),
};
context.flatten(no_predicates);
}
@@ -317,7 +329,6 @@ impl<'f> Context<'f> {
// If this is not a separate variable, clippy gets confused and says the to_vec is
// unnecessary, when removing it actually causes an aliasing/mutability error.
let instructions = self.inserter.function.dfg[block].instructions().to_vec();
let mut previous_allocate_result = None;

for instruction in instructions.iter() {
if self.is_no_predicate(no_predicates, instruction) {
@@ -332,10 +343,10 @@
None,
im::Vector::new(),
);
self.push_instruction(*instruction, &mut previous_allocate_result);
self.push_instruction(*instruction);
self.insert_current_side_effects_enabled();
} else {
self.push_instruction(*instruction, &mut previous_allocate_result);
self.push_instruction(*instruction);
}
}
}
@@ -405,10 +416,12 @@ impl<'f> Context<'f> {
let old_condition = *condition;
let then_condition = self.inserter.resolve(old_condition);

let old_allocations = std::mem::take(&mut self.local_allocations);
let branch = ConditionalBranch {
old_condition,
condition: self.link_condition(then_condition),
last_block: *then_destination,
local_allocations: old_allocations,
};
let cond_context = ConditionalContext {
condition: then_condition,
@@ -435,11 +448,14 @@
);
let else_condition = self.link_condition(else_condition);

let old_allocations = std::mem::take(&mut self.local_allocations);
let else_branch = ConditionalBranch {
old_condition: cond_context.then_branch.old_condition,
condition: else_condition,
last_block: *block,
local_allocations: old_allocations,
};
cond_context.then_branch.local_allocations.clear();
cond_context.else_branch = Some(else_branch);
self.condition_stack.push(cond_context);

@@ -461,6 +477,7 @@
}

let mut else_branch = cond_context.else_branch.unwrap();
self.local_allocations = std::mem::take(&mut else_branch.local_allocations);
else_branch.last_block = *block;
cond_context.else_branch = Some(else_branch);

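The `std::mem::take` calls above follow a save/restore discipline around each branch. A standalone sketch of the pattern (hypothetical `Flattener` type; the real code stores the saved set inside `ConditionalBranch`):

    use std::collections::HashSet;

    struct Flattener {
        local_allocations: HashSet<u32>,
    }

    impl Flattener {
        // Entering a branch: save the caller's set and start from empty, so
        // membership checks only see allocations made inside this branch.
        fn enter_branch(&mut self) -> HashSet<u32> {
            std::mem::take(&mut self.local_allocations)
        }

        // Leaving a branch: drop the branch-local entries and restore the
        // saved outer set.
        fn exit_branch(&mut self, saved: HashSet<u32>) {
            self.local_allocations = saved;
        }
    }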
@@ -593,22 +610,19 @@ impl<'f> Context<'f> {
/// `previous_allocate_result` should only be set to the result of an allocate instruction
/// if that instruction was the instruction immediately previous to this one - if there are
/// any instructions in between it should be None.
fn push_instruction(
&mut self,
id: InstructionId,
previous_allocate_result: &mut Option<ValueId>,
) {
fn push_instruction(&mut self, id: InstructionId) {
let (instruction, call_stack) = self.inserter.map_instruction(id);
let instruction = self.handle_instruction_side_effects(
instruction,
call_stack.clone(),
*previous_allocate_result,
);
let instruction = self.handle_instruction_side_effects(instruction, call_stack.clone());

let instruction_is_allocate = matches!(&instruction, Instruction::Allocate);
let entry = self.inserter.function.entry_block();
let results = self.inserter.push_instruction_value(instruction, id, entry, call_stack);
*previous_allocate_result = instruction_is_allocate.then(|| results.first());

// Remember an allocate was created local to this branch so that we do not try to merge store
// values across branches for it later.
if instruction_is_allocate {
self.local_allocations.insert(results.first());
}
}

/// If we are currently in a branch, we need to modify constrain instructions
@@ -621,7 +635,6 @@
&mut self,
instruction: Instruction,
call_stack: CallStack,
previous_allocate_result: Option<ValueId>,
) -> Instruction {
if let Some(condition) = self.get_last_condition() {
match instruction {
@@ -652,7 +665,7 @@
Instruction::Store { address, value } => {
// If this store's address is an allocation local to the current branch,
// there is no previous value to load and we don't need a merge anyway.
if Some(address) == previous_allocate_result {
if self.local_allocations.contains(&address) {
Instruction::Store { address, value }
} else {
// Instead of storing `value`, store `if condition { value } else { previous_value }`
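For stores that do need a merge, the stored value becomes `if condition { value } else { previous_value }`, which for fields lowers to the mul/add pattern shown in the module doc comment at the top of this file. A minimal arithmetic sketch (unsigned integers standing in for field elements):

    // With `condition` constrained to 0 or 1, this computes
    // `if condition { value } else { previous_value }`:
    //   v10 = mul condition, value
    //   v11 = mul (1 - condition), previous_value
    //   v12 = add v10, v11
    fn merge_stored_value(condition: u64, value: u64, previous_value: u64) -> u64 {
        debug_assert!(condition <= 1);
        condition * value + (1 - condition) * previous_value
    }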
120 changes: 41 additions & 79 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
@@ -415,13 +415,11 @@ impl<'f> PerFunctionContext<'f> {
let address = self.inserter.function.dfg.resolve(*address);
let value = self.inserter.function.dfg.resolve(*value);

// FIXME: This causes errors in the sha256 tests
//
// If there was another store to this instruction without any (unremoved) loads or
// function calls in-between, we can remove the previous store.
// if let Some(last_store) = references.last_stores.get(&address) {
// self.instructions_to_remove.insert(*last_store);
// }
if let Some(last_store) = references.last_stores.get(&address) {
self.instructions_to_remove.insert(*last_store);
}

if self.inserter.function.dfg.value_is_reference(value) {
if let Some(expression) = references.expressions.get(&value) {
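
The hunk above re-enables classic dead-store elimination. A self-contained sketch of the idea (simplified bookkeeping; the real pass tracks this in `references.last_stores` and, per the comment above, also invalidates entries on intervening loads and function calls):

    use std::collections::{HashMap, HashSet};

    // If a store to `address` is followed by another store to the same address
    // with no intervening load or call, the earlier store is unobservable.
    fn record_store(
        address: u32,
        store_id: u32,
        last_stores: &mut HashMap<u32, u32>,
        to_remove: &mut HashSet<u32>,
    ) {
        if let Some(previous_store) = last_stores.insert(address, store_id) {
            to_remove.insert(previous_store);
        }
    }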
@@ -614,6 +612,8 @@ mod tests {
map::Id,
types::Type,
},
opt::assert_normalized_ssa_equals,
Ssa,
};

#[test]
@@ -824,91 +824,53 @@
// is later stored in a successor block
#[test]
fn load_aliases_in_predecessor_block() {
// fn main {
// b0():
// v0 = allocate
// store Field 0 at v0
// v2 = allocate
// store v0 at v2
// v3 = load v2
// v4 = load v2
// jmp b1()
// b1():
// store Field 1 at v3
// store Field 2 at v4
// v7 = load v3
// v8 = eq v7, Field 2
// return
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id);

let v0 = builder.insert_allocate(Type::field());

let zero = builder.field_constant(0u128);
builder.insert_store(v0, zero);

let v2 = builder.insert_allocate(Type::Reference(Arc::new(Type::field())));
builder.insert_store(v2, v0);

let v3 = builder.insert_load(v2, Type::field());
let v4 = builder.insert_load(v2, Type::field());
let b1 = builder.insert_block();
builder.terminate_with_jmp(b1, vec![]);

builder.switch_to_block(b1);

let one = builder.field_constant(1u128);
builder.insert_store(v3, one);

let two = builder.field_constant(2u128);
builder.insert_store(v4, two);

let v8 = builder.insert_load(v3, Type::field());
let _ = builder.insert_binary(v8, BinaryOp::Eq, two);

builder.terminate_with_return(vec![]);

let ssa = builder.finish();
assert_eq!(ssa.main().reachable_blocks().len(), 2);
let src = "
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut Field
store Field 0 at v0
v2 = allocate -> &mut &mut Field
store v0 at v2
v3 = load v2 -> &mut Field
v4 = load v2 -> &mut Field
jmp b1()
b1():
store Field 1 at v3
store Field 2 at v4
v7 = load v3 -> Field
v8 = eq v7, Field 2
return
}
";

// Expected result:
// acir fn main f0 {
// b0():
// v9 = allocate
// store Field 0 at v9
// v10 = allocate
// jmp b1()
// b1():
// return
// }
let ssa = ssa.mem2reg();
println!("{}", ssa);
let mut ssa = Ssa::from_str(src).unwrap();
let main = ssa.main_mut();

let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 2);
let instructions = main.dfg[main.entry_block()].instructions();
assert_eq!(instructions.len(), 6); // The final return is not counted

// All loads should be removed
assert_eq!(count_loads(main.entry_block(), &main.dfg), 0);
assert_eq!(count_loads(b1, &main.dfg), 0);

// The first store is not removed as it is used as a nested reference in another store.
// We would need to track whether the store where `v9` is the store value gets removed to know whether
// We would need to track whether the store where `v0` is the store value gets removed to know whether
// to remove it.
assert_eq!(count_stores(main.entry_block(), &main.dfg), 1);

// The first store in b1 is removed since there is another store to the same reference
// in the same block, and the store is not needed before the later store.
// The rest of the stores are also removed as no loads are done within any blocks
// to the stored values.
//
// NOTE: This store is not removed due to the FIXME when handling Instruction::Store.
assert_eq!(count_stores(b1, &main.dfg), 1);

let b1_instructions = main.dfg[b1].instructions();
let expected = "
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut Field
store Field 0 at v0
v2 = allocate -> &mut &mut Field
jmp b1()
b1():
return
}
";

// We expect the last eq to be optimized out; only the store from above remains
assert_eq!(b1_instructions.len(), 1);
let ssa = ssa.mem2reg();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]