Skip to content

Commit

Permalink
Add support for explicitly typed constants (#118)
Browse files Browse the repository at this point in the history
* Add support for explicitly typed constants

* cargo fmt
  • Loading branch information
attackgoat authored Dec 20, 2023
1 parent cabff1d commit ea2336e
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions spirq-core/src/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ use crate::{
pub enum ConstantValue {
Typeless(Box<[u8]>),
Bool(bool),
S8(i8),
S16(i16),
S32(i32),
S64(i64),
U8(u8),
U16(u16),
U32(u32),
U64(u64),
F16(),
F32(OrderedFloat<f32>),
F64(OrderedFloat<f64>),
}
impl From<&[u32]> for ConstantValue {
fn from(x: &[u32]) -> Self {
Expand Down Expand Up @@ -71,6 +79,34 @@ impl ConstantValue {
if let Some(scalar_ty) = ty.as_scalar() {
match scalar_ty {
ScalarType::Boolean => Ok(ConstantValue::Bool(x.iter().any(|x| x != &0))),
ScalarType::Integer {
bits: 8,
is_signed: true,
} if x.len() == 4 => {
let x = i8::from_ne_bytes([x[0]]);
Ok(ConstantValue::S8(x))
}
ScalarType::Integer {
bits: 8,
is_signed: false,
} if x.len() == 4 => {
let x = u8::from_ne_bytes([x[0]]);
Ok(ConstantValue::U8(x))
}
ScalarType::Integer {
bits: 16,
is_signed: true,
} if x.len() == 4 => {
let x = i16::from_ne_bytes([x[0], x[1]]);
Ok(ConstantValue::S16(x))
}
ScalarType::Integer {
bits: 16,
is_signed: false,
} if x.len() == 4 => {
let x = u16::from_ne_bytes([x[0], x[1]]);
Ok(ConstantValue::U16(x))
}
ScalarType::Integer {
bits: 32,
is_signed: true,
Expand All @@ -85,10 +121,29 @@ impl ConstantValue {
let x = u32::from_ne_bytes([x[0], x[1], x[2], x[3]]);
Ok(ConstantValue::U32(x))
}
ScalarType::Integer {
bits: 64,
is_signed: true,
} if x.len() == 8 => {
let x = i64::from_ne_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
Ok(ConstantValue::S64(x))
}
ScalarType::Integer {
bits: 64,
is_signed: false,
} if x.len() == 8 => {
let x = u64::from_ne_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
Ok(ConstantValue::U64(x))
}
ScalarType::Float { bits: 16 } if x.len() == 4 => Ok(ConstantValue::F16()),
ScalarType::Float { bits: 32 } if x.len() == 4 => {
let x = f32::from_ne_bytes([x[0], x[1], x[2], x[3]]);
Ok(ConstantValue::F32(OrderedFloat(x)))
}
ScalarType::Float { bits: 64 } if x.len() == 8 => {
let x = f64::from_ne_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
Ok(ConstantValue::F64(OrderedFloat(x)))
}
_ => Err(anyhow!(
"cannot parse {:?} from {} bytes",
scalar_ty,
Expand Down

0 comments on commit ea2336e

Please sign in to comment.