diff --git a/spirq-core/src/constant.rs b/spirq-core/src/constant.rs index 970172f..0db5e80 100644 --- a/spirq-core/src/constant.rs +++ b/spirq-core/src/constant.rs @@ -15,9 +15,17 @@ use crate::{ pub enum ConstantValue { Typeless(Box<[u8]>), Bool(bool), + S8(i8), + S16(i16), S32(i32), + S64(i64), + U8(u8), + U16(u16), U32(u32), + U64(u64), + F16(), F32(OrderedFloat), + F64(OrderedFloat), } impl From<&[u32]> for ConstantValue { fn from(x: &[u32]) -> Self { @@ -71,6 +79,34 @@ impl ConstantValue { if let Some(scalar_ty) = ty.as_scalar() { match scalar_ty { ScalarType::Boolean => Ok(ConstantValue::Bool(x.iter().any(|x| x != &0))), + ScalarType::Integer { + bits: 8, + is_signed: true, + } if x.len() == 4 => { + let x = i8::from_ne_bytes([x[0]]); + Ok(ConstantValue::S8(x)) + } + ScalarType::Integer { + bits: 8, + is_signed: false, + } if x.len() == 4 => { + let x = u8::from_ne_bytes([x[0]]); + Ok(ConstantValue::U8(x)) + } + ScalarType::Integer { + bits: 16, + is_signed: true, + } if x.len() == 4 => { + let x = i16::from_ne_bytes([x[0], x[1]]); + Ok(ConstantValue::S16(x)) + } + ScalarType::Integer { + bits: 16, + is_signed: false, + } if x.len() == 4 => { + let x = u16::from_ne_bytes([x[0], x[1]]); + Ok(ConstantValue::U16(x)) + } ScalarType::Integer { bits: 32, is_signed: true, @@ -85,10 +121,29 @@ impl ConstantValue { let x = u32::from_ne_bytes([x[0], x[1], x[2], x[3]]); Ok(ConstantValue::U32(x)) } + ScalarType::Integer { + bits: 64, + is_signed: true, + } if x.len() == 8 => { + let x = i64::from_ne_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]); + Ok(ConstantValue::S64(x)) + } + ScalarType::Integer { + bits: 64, + is_signed: false, + } if x.len() == 8 => { + let x = u64::from_ne_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]); + Ok(ConstantValue::U64(x)) + } + ScalarType::Float { bits: 16 } if x.len() == 4 => Ok(ConstantValue::F16()), ScalarType::Float { bits: 32 } if x.len() == 4 => { let x = f32::from_ne_bytes([x[0], x[1], x[2], x[3]]); Ok(ConstantValue::F32(OrderedFloat(x))) } + ScalarType::Float { bits: 64 } if x.len() == 8 => { + let x = f64::from_ne_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]); + Ok(ConstantValue::F64(OrderedFloat(x))) + } _ => Err(anyhow!( "cannot parse {:?} from {} bytes", scalar_ty,