Skip to content

Commit

Permalink
WIP: mem2reg speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb authored and Firestar99 committed Oct 10, 2024
1 parent 3f96056 commit fc0db1b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 45 deletions.
11 changes: 6 additions & 5 deletions crates/rustc_codegen_spirv/src/linker/dce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
//! concept.

use rspirv::dr::{Function, Instruction, Module, Operand};
use rspirv::dr::{Block, Function, Instruction, Module, Operand};
use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word};
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use std::hash::Hash;

pub fn dce(module: &mut Module) {
let mut rooted = collect_roots(module);
Expand Down Expand Up @@ -137,11 +138,11 @@ fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet<Word>) {
}
}

pub fn dce_phi(func: &mut Function) {
pub fn dce_phi(blocks: &mut FxIndexMap<impl Eq + Hash, &mut Block>) {
let mut used = FxIndexSet::default();
loop {
let mut changed = false;
for inst in func.all_inst_iter() {
for inst in blocks.values().flat_map(|block| &block.instructions) {
if inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()) {
for op in &inst.operands {
if let Some(id) = op.id_ref_any() {
Expand All @@ -154,7 +155,7 @@ pub fn dce_phi(func: &mut Function) {
break;
}
}
for block in &mut func.blocks {
for block in blocks.values_mut() {
block
.instructions
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));
Expand Down
91 changes: 52 additions & 39 deletions crates/rustc_codegen_spirv/src/linker/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
use super::{apply_rewrite_rules, id};
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
use rustc_middle::bug;
use std::collections::hash_map;

// HACK(eddyb) newtype instead of type alias to avoid mistakes.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct LabelId(Word);

pub fn mem2reg(
header: &mut ModuleHeader,
types_global_values: &mut Vec<Instruction>,
pointer_to_pointee: &FxHashMap<Word, Word>,
constants: &FxHashMap<Word, u32>,
func: &mut Function,
) {
let reachable = compute_reachable(&func.blocks);
let preds = compute_preds(&func.blocks, &reachable);
// HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
// it's made completely irrelevant by SPIR-T so only applies to legacy code.
let mut blocks: FxIndexMap<_, _> = func
.blocks
.iter_mut()
.map(|block| (LabelId(block.label_id().unwrap()), block))
.collect();

let reachable = compute_reachable(&blocks);
let preds = compute_preds(&blocks, &reachable);
let idom = compute_idom(&preds, &reachable);
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
loop {
Expand All @@ -34,31 +46,27 @@ pub fn mem2reg(
types_global_values,
pointer_to_pointee,
constants,
&mut func.blocks,
&mut blocks,
&dominance_frontier,
);
if !changed {
break;
}
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
super::dce::dce_phi(func);
super::dce::dce_phi(&mut blocks);
}
}

fn label_to_index(blocks: &[Block], id: Word) -> usize {
blocks
.iter()
.position(|b| b.label_id().unwrap() == id)
.unwrap()
}

fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
fn compute_reachable(blocks: &FxIndexMap<LabelId, &mut Block>) -> Vec<bool> {
fn recurse(blocks: &FxIndexMap<LabelId, &mut Block>, reachable: &mut [bool], block: usize) {
if !reachable[block] {
reachable[block] = true;
for dest_id in outgoing_edges(&blocks[block]) {
let dest_idx = label_to_index(blocks, dest_id);
recurse(blocks, reachable, dest_idx);
for dest_id in outgoing_edges(blocks[block]) {
recurse(
blocks,
reachable,
blocks.get_index_of(&LabelId(dest_id)).unwrap(),
);
}
}
}
Expand All @@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
reachable
}

fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
fn compute_preds(
blocks: &FxIndexMap<LabelId, &mut Block>,
reachable_blocks: &[bool],
) -> Vec<Vec<usize>> {
let mut result = vec![vec![]; blocks.len()];
// Do not count unreachable blocks as valid preds of blocks
for (source_idx, source) in blocks
.iter()
.values()
.enumerate()
.filter(|&(b, _)| reachable_blocks[b])
{
for dest_id in outgoing_edges(source) {
let dest_idx = label_to_index(blocks, dest_id);
result[dest_idx].push(source_idx);
result[blocks.get_index_of(&LabelId(dest_id)).unwrap()].push(source_idx);
}
}
result
Expand Down Expand Up @@ -161,7 +171,7 @@ fn insert_phis_all(
types_global_values: &mut Vec<Instruction>,
pointer_to_pointee: &FxHashMap<Word, Word>,
constants: &FxHashMap<Word, u32>,
blocks: &mut [Block],
blocks: &mut FxIndexMap<LabelId, &mut Block>,
dominance_frontier: &[FxHashSet<usize>],
) -> bool {
let var_maps_and_types = blocks[0]
Expand Down Expand Up @@ -198,7 +208,11 @@ fn insert_phis_all(
rewrite_rules: FxHashMap::default(),
};
renamer.rename(0, None);
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
// FIXME(eddyb) shouldn't this full rescan of the function be done once?
apply_rewrite_rules(
&renamer.rewrite_rules,
blocks.values_mut().map(|block| &mut **block),
);
remove_nops(blocks);
}
remove_old_variables(blocks, &var_maps_and_types);
Expand All @@ -216,7 +230,7 @@ struct VarInfo {
fn collect_access_chains(
pointer_to_pointee: &FxHashMap<Word, Word>,
constants: &FxHashMap<Word, u32>,
blocks: &[Block],
blocks: &FxIndexMap<LabelId, &mut Block>,
base_var: Word,
base_var_ty: Word,
) -> Option<FxHashMap<Word, VarInfo>> {
Expand Down Expand Up @@ -249,7 +263,7 @@ fn collect_access_chains(
// Loop in case a previous block references a later AccessChain
loop {
let mut changed = false;
for inst in blocks.iter().flat_map(|b| &b.instructions) {
for inst in blocks.values().flat_map(|b| &b.instructions) {
for (index, op) in inst.operands.iter().enumerate() {
if let Operand::IdRef(id) = op {
if variables.contains_key(id) {
Expand Down Expand Up @@ -307,10 +321,10 @@ fn collect_access_chains(
// same var map (e.g. `s.x = s.y;`).
fn split_copy_memory(
header: &mut ModuleHeader,
blocks: &mut [Block],
blocks: &mut FxIndexMap<LabelId, &mut Block>,
var_map: &FxHashMap<Word, VarInfo>,
) {
for block in blocks {
for block in blocks.values_mut() {
let mut inst_index = 0;
while inst_index < block.instructions.len() {
let inst = &block.instructions[inst_index];
Expand Down Expand Up @@ -369,7 +383,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
}

fn insert_phis(
blocks: &[Block],
blocks: &FxIndexMap<LabelId, &mut Block>,
dominance_frontier: &[FxHashSet<usize>],
var_map: &FxHashMap<Word, VarInfo>,
) -> FxHashSet<usize> {
Expand All @@ -378,7 +392,7 @@ fn insert_phis(
let mut ever_on_work_list = FxHashSet::default();
let mut work_list = Vec::new();
let mut blocks_with_phi = FxHashSet::default();
for (block_idx, block) in blocks.iter().enumerate() {
for (block_idx, block) in blocks.values().enumerate() {
if has_store(block, var_map) {
ever_on_work_list.insert(block_idx);
work_list.push(block_idx);
Expand Down Expand Up @@ -423,10 +437,10 @@ fn top_stack_or_undef(
}
}

struct Renamer<'a> {
struct Renamer<'a, 'b> {
header: &'a mut ModuleHeader,
types_global_values: &'a mut Vec<Instruction>,
blocks: &'a mut [Block],
blocks: &'a mut FxIndexMap<LabelId, &'b mut Block>,
blocks_with_phi: FxHashSet<usize>,
base_var_type: Word,
var_map: &'a FxHashMap<Word, VarInfo>,
Expand All @@ -436,7 +450,7 @@ struct Renamer<'a> {
rewrite_rules: FxHashMap<Word, Word>,
}

impl Renamer<'_> {
impl Renamer<'_, '_> {
// Returns the phi definition.
fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word {
let from_block_label = self.blocks[from_block].label_id().unwrap();
Expand Down Expand Up @@ -558,9 +572,8 @@ impl Renamer<'_> {
}
}

for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
// TODO: Don't do this find
let dest_idx = label_to_index(self.blocks, dest_id);
for dest_id in outgoing_edges(self.blocks[block]).collect::<Vec<_>>() {
let dest_idx = self.blocks.get_index_of(&LabelId(dest_id)).unwrap();
self.rename(dest_idx, Some(block));
}

Expand All @@ -570,16 +583,16 @@ impl Renamer<'_> {
}
}

fn remove_nops(blocks: &mut [Block]) {
for block in blocks {
fn remove_nops(blocks: &mut FxIndexMap<LabelId, &mut Block>) {
for block in blocks.values_mut() {
block
.instructions
.retain(|inst| inst.class.opcode != Op::Nop);
}
}

fn remove_old_variables(
blocks: &mut [Block],
blocks: &mut FxIndexMap<LabelId, &mut Block>,
var_maps_and_types: &[(FxHashMap<u32, VarInfo>, u32)],
) {
blocks[0].instructions.retain(|inst| {
Expand All @@ -590,7 +603,7 @@ fn remove_old_variables(
.all(|(var_map, _)| !var_map.contains_key(&result_id))
}
});
for block in blocks {
for block in blocks.values_mut() {
block.instructions.retain(|inst| {
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
|| inst.operands.iter().all(|op| {
Expand Down
5 changes: 4 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ fn id(header: &mut ModuleHeader) -> Word {
result
}

fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Block]) {
fn apply_rewrite_rules<'a>(
rewrite_rules: &FxHashMap<Word, Word>,
blocks: impl IntoIterator<Item = &'a mut Block>,
) {
let apply = |inst: &mut Instruction| {
if let Some(ref mut id) = &mut inst.result_id {
if let Some(&rewrite) = rewrite_rules.get(id) {
Expand Down

0 comments on commit fc0db1b

Please sign in to comment.