From 07a37d7fcdd96699c91772dc28ab6c57ecded246 Mon Sep 17 00:00:00 2001 From: John Wells Date: Mon, 18 Dec 2023 01:46:24 -0500 Subject: [PATCH] Fixes needed by screen-13 (#114) * Fix for functions called before their body declaration * Fix for user-specified specialization constants --- spirq-core/src/func.rs | 14 +++++++++++++- spirq/src/reflect.rs | 11 +++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/spirq-core/src/func.rs b/spirq-core/src/func.rs index 3633cef..da938b2 100644 --- a/spirq-core/src/func.rs +++ b/spirq-core/src/func.rs @@ -26,13 +26,25 @@ pub struct Function { #[derive(Default)] pub struct FunctionRegistry { + called: HashSet, func_map: HashMap, } impl FunctionRegistry { - pub fn set(&mut self, id: FunctionId, func: Function) -> Result<()> { + pub fn called(&mut self, id: FunctionId) { + if let Some(func) = self.func_map.get_mut(&id) { + func.callees.insert(id); + } else { + self.called.insert(id); + } + } + + pub fn set(&mut self, id: FunctionId, mut func: Function) -> Result<()> { use std::collections::hash_map::Entry; match self.func_map.entry(id) { Entry::Vacant(entry) => { + if self.called.remove(&id) { + func.callees.insert(id); + } entry.insert(func); Ok(()) } diff --git a/spirq/src/reflect.rs b/spirq/src/reflect.rs index 83cf959..9e869b2 100644 --- a/spirq/src/reflect.rs +++ b/spirq/src/reflect.rs @@ -537,7 +537,12 @@ impl<'a> ReflectIntermediate<'a> { .get_u32(op.const_id, spirv::Decoration::SpecId)?; let ty = self.ty_reg.get(op.ty_id)?.clone(); let constant = if let Some(user_value) = self.cfg.spec_values.get(&spec_id) { - Constant::new(name, ty, user_value.clone()) + let user_value = if matches!(user_value, ConstantValue::Typeless(_)) { + user_value.to_typed(&ty)? + } else { + user_value.clone() + }; + Constant::new(name, ty, user_value) } else { let value = match opcode { Op::SpecConstantTrue => ConstantValue::from(true), @@ -617,9 +622,7 @@ impl Inspector for FunctionInspector { } Op::FunctionCall => { let op = OpFunctionCall::try_from(instr)?; - let func_id = op.func_id; - let func = itm.func_reg.get_mut(func_id)?; - func.callees.insert(func_id); + itm.func_reg.called(op.func_id); } _ => { if let Some((_func_id, func)) = self.cur_func.as_mut() {