Skip to content

Commit

Permalink
Fixes needed by screen-13 (#114)
Browse files Browse the repository at this point in the history
* Fix for functions called before their body declaration

* Fix for user-specified specialization constants
  • Loading branch information
attackgoat authored Dec 18, 2023
1 parent f07396b commit 07a37d7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
14 changes: 13 additions & 1 deletion spirq-core/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,25 @@ pub struct Function {

#[derive(Default)]
pub struct FunctionRegistry {
called: HashSet<FunctionId>,
func_map: HashMap<FunctionId, Function>,
}
impl FunctionRegistry {
pub fn set(&mut self, id: FunctionId, func: Function) -> Result<()> {
pub fn called(&mut self, id: FunctionId) {
if let Some(func) = self.func_map.get_mut(&id) {
func.callees.insert(id);
} else {
self.called.insert(id);
}
}

pub fn set(&mut self, id: FunctionId, mut func: Function) -> Result<()> {
use std::collections::hash_map::Entry;
match self.func_map.entry(id) {
Entry::Vacant(entry) => {
if self.called.remove(&id) {
func.callees.insert(id);
}
entry.insert(func);
Ok(())
}
Expand Down
11 changes: 7 additions & 4 deletions spirq/src/reflect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,12 @@ impl<'a> ReflectIntermediate<'a> {
.get_u32(op.const_id, spirv::Decoration::SpecId)?;
let ty = self.ty_reg.get(op.ty_id)?.clone();
let constant = if let Some(user_value) = self.cfg.spec_values.get(&spec_id) {
Constant::new(name, ty, user_value.clone())
let user_value = if matches!(user_value, ConstantValue::Typeless(_)) {
user_value.to_typed(&ty)?
} else {
user_value.clone()
};
Constant::new(name, ty, user_value)
} else {
let value = match opcode {
Op::SpecConstantTrue => ConstantValue::from(true),
Expand Down Expand Up @@ -617,9 +622,7 @@ impl Inspector for FunctionInspector {
}
Op::FunctionCall => {
let op = OpFunctionCall::try_from(instr)?;
let func_id = op.func_id;
let func = itm.func_reg.get_mut(func_id)?;
func.callees.insert(func_id);
itm.func_reg.called(op.func_id);
}
_ => {
if let Some((_func_id, func)) = self.cur_func.as_mut() {
Expand Down

0 comments on commit 07a37d7

Please sign in to comment.