diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 0ef6db52f4..34a4886ca0 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -94,7 +94,6 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { header, debug_string_source: &mut module.debug_string_source, annotations: &mut module.annotations, - types_global_values: &mut module.types_global_values, legal_globals, @@ -493,7 +492,6 @@ struct Inliner<'a, 'b> { header: &'b mut ModuleHeader, debug_string_source: &'b mut Vec<Instruction>, annotations: &'b mut Vec<Instruction>, - types_global_values: &'b mut Vec<Instruction>, legal_globals: FxHashMap<Word, LegalGlobal>, functions_that_may_abort: FxHashSet<Word>, @@ -523,29 +521,6 @@ impl Inliner<'_, '_> { } } - fn ptr_ty(&mut self, pointee: Word) -> Word { - // TODO: This is horribly slow, fix this - let existing = self.types_global_values.iter().find(|inst| { - inst.class.opcode == Op::TypePointer - && inst.operands[0].unwrap_storage_class() == StorageClass::Function - && inst.operands[1].unwrap_id_ref() == pointee - }); - if let Some(existing) = existing { - return existing.result_id.unwrap(); - } - let inst_id = self.id(); - self.types_global_values.push(Instruction::new( - Op::TypePointer, - None, - Some(inst_id), - vec![ - Operand::StorageClass(StorageClass::Function), - Operand::IdRef(pointee), - ], - )); - inst_id - } - fn inline_fn( &mut self, function: &mut Function, @@ -622,15 +597,19 @@ impl Inliner<'_, '_> { .insert(caller.def_id().unwrap()); } - let call_result_type = { + let mut maybe_call_result_phi = { let ty = call_inst.result_type.unwrap(); if ty == self.op_type_void_id { None } else { - Some(ty) + Some(Instruction::new( + Op::Phi, + Some(ty), + Some(call_inst.result_id.unwrap()), + vec![], + )) } }; - let call_result_id = call_inst.result_id.unwrap(); // Get the debug "source location" instruction that applies to the call. 
let custom_ext_inst_set_import = self.custom_ext_inst_set_import; @@ -667,17 +646,12 @@ impl Inliner<'_, '_> { }); let mut rewrite_rules = callee_parameters.zip(call_arguments).collect(); - let return_variable = if call_result_type.is_some() { - Some(self.id()) - } else { - None - }; let return_jump = self.id(); // Rewrite OpReturns of the callee. let mut inlined_callee_blocks = self.get_inlined_blocks( callee, call_debug_src_loc_inst, - return_variable, + maybe_call_result_phi.as_mut(), return_jump, ); // Clone the IDs of the callee, because otherwise they'd be defined multiple times if the @@ -686,6 +660,55 @@ impl Inliner<'_, '_> { apply_rewrite_rules(&rewrite_rules, &mut inlined_callee_blocks); self.apply_rewrite_for_decorations(&rewrite_rules); + if let Some(call_result_phi) = &mut maybe_call_result_phi { + // HACK(eddyb) new IDs should be generated earlier, to avoid pushing + // callee IDs to `call_result_phi.operands` only to rewrite them here. + for op in &mut call_result_phi.operands { + if let Some(id) = op.id_ref_any_mut() { + if let Some(&rewrite) = rewrite_rules.get(id) { + *id = rewrite; + } + } + } + + // HACK(eddyb) this special-casing of the single-return case is + // really necessary for passes like `mem2reg` which are not capable + // of skipping through the extraneous `OpPhi`s on their own. + if let [returned_value, _return_block] = &call_result_phi.operands[..] { + let call_result_id = call_result_phi.result_id.unwrap(); + let returned_value_id = returned_value.unwrap_id_ref(); + + maybe_call_result_phi = None; + + // HACK(eddyb) this is a conservative approximation of all the + // instructions that could potentially reference the call result. 
+ let reaching_insts = { + let (pre_call_blocks, call_and_post_call_blocks) = + caller.blocks.split_at_mut(block_idx); + (pre_call_blocks.iter_mut().flat_map(|block| { + block + .instructions + .iter_mut() + .take_while(|inst| inst.class.opcode == Op::Phi) + })) + .chain( + call_and_post_call_blocks + .iter_mut() + .flat_map(|block| &mut block.instructions), + ) + }; + for reaching_inst in reaching_insts { + for op in &mut reaching_inst.operands { + if let Some(id) = op.id_ref_any_mut() { + if *id == call_result_id { + *id = returned_value_id; + } + } + } + } + } + } + // Split the block containing the `OpFunctionCall` into pre-call vs post-call. let pre_call_block_idx = block_idx; #[expect(unused)] @@ -701,18 +724,6 @@ impl Inliner<'_, '_> { .unwrap(); assert!(call.class.opcode == Op::FunctionCall); - if let Some(call_result_type) = call_result_type { - // Generate the storage space for the return value: Do this *after* the split above, - // because if block_idx=0, inserting a variable here shifts call_index. - let ret_var_inst = Instruction::new( - Op::Variable, - Some(self.ptr_ty(call_result_type)), - Some(return_variable.unwrap()), - vec![Operand::StorageClass(StorageClass::Function)], - ); - self.insert_opvariables(&mut caller.blocks[0], [ret_var_inst]); - } - // Insert non-entry inlined callee blocks just after the pre-call block. let non_entry_inlined_callee_blocks = inlined_callee_blocks.drain(1..); let num_non_entry_inlined_callee_blocks = non_entry_inlined_callee_blocks.len(); @@ -721,18 +732,9 @@ impl Inliner<'_, '_> { non_entry_inlined_callee_blocks, ); - if let Some(call_result_type) = call_result_type { - // Add the load of the result value after the inlined function. Note there's guaranteed no - // OpPhi instructions since we just split this block. 
- post_call_block_insts.insert( - 0, - Instruction::new( - Op::Load, - Some(call_result_type), - Some(call_result_id), - vec![Operand::IdRef(return_variable.unwrap())], - ), - ); + if let Some(call_result_phi) = maybe_call_result_phi { + // Add the `OpPhi` for the call result value, after the inlined function. + post_call_block_insts.insert(0, call_result_phi); } // Insert the post-call block, after all the inlined callee blocks. @@ -899,7 +901,7 @@ impl Inliner<'_, '_> { &mut self, callee: &Function, call_debug_src_loc_inst: Option<&Instruction>, - return_variable: Option<Word>, + mut maybe_call_result_phi: Option<&mut Instruction>, return_jump: Word, ) -> Vec<Block> { let Self { @@ -997,17 +999,13 @@ impl Inliner<'_, '_> { if let Op::Return | Op::ReturnValue = terminator.class.opcode { if Op::ReturnValue == terminator.class.opcode { let return_value = terminator.operands[0].id_ref_any().unwrap(); - block.instructions.push(Instruction::new( - Op::Store, - None, - None, - vec![ - Operand::IdRef(return_variable.unwrap()), - Operand::IdRef(return_value), - ], - )); + let call_result_phi = maybe_call_result_phi.as_deref_mut().unwrap(); + call_result_phi.operands.extend([ + Operand::IdRef(return_value), + Operand::IdRef(block.label_id().unwrap()), + ]); } else { - assert!(return_variable.is_none()); + assert!(maybe_call_result_phi.is_none()); } terminator = Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);