diff --git a/Cargo.toml b/Cargo.toml index b3d11d0b..9f556d70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,8 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } # min-cl = { version = "0.3.0", optional=true } [features] -default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde",] +default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] + # default = ["cpu"] # default = ["no-std"] # default = ["opencl"] diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 9fb01468..1ac7333c 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -6,6 +6,7 @@ mod ty; mod wrapper; pub use ty::*; +use wrapper::MaybeData; use crate::{ op_hint::OpHint, register_buf_copyable, unregister_buf_copyable, AddLayer, AddOperation, Alloc, @@ -362,8 +363,7 @@ where unsafe { self.bump_cursor() }; Ok(LazyWrapper { - data: None, - id: Some(id), + maybe_data: MaybeData::Id(id), _pd: core::marker::PhantomData, }) } @@ -508,15 +508,15 @@ mod tests { let device = CPU::>::new(); let buf = Buffer::::new(&device, 10); let res = &buf.data; - assert_eq!(res.id, None); + assert_eq!(res.maybe_data.id(), None); let x: Buffer = device.retrieve(10, ()).unwrap(); let res = &x.data; - assert_eq!(res.id, Some(crate::Id { id: 0, len: 10 })); + assert_eq!(res.maybe_data.id(), Some(&crate::Id { id: 0, len: 10 })); let x: Buffer = device.retrieve(10, ()).unwrap(); let res = &x.data; - assert_eq!(res.id, Some(crate::Id { id: 1, len: 10 })); + assert_eq!(res.maybe_data.id(), Some(&crate::Id { id: 1, len: 10 })); } #[test] diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index ebf2e313..9ac1000a 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -1,14 +1,16 @@ +mod maybe_data; +pub use maybe_data::MaybeData; + use core::{ marker::PhantomData, ops::{Deref, DerefMut}, }; -use crate::{flag::AllocFlag, HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, WrappedData}; +use crate::{flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedData}; #[derive(Debug, Default)] pub struct LazyWrapper { - pub data: Option, - pub id: Option, + pub maybe_data: MaybeData, pub _pd: PhantomData, } @@ -18,29 +20,29 @@ impl WrappedData for Lazy<'_, Mods, T2> { #[inline] fn wrap_in_base(&self, base: Base) -> Self::Wrap { LazyWrapper { - data: Some(self.modules.wrap_in_base(base)), - id: None, + maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)), _pd: PhantomData, } } #[inline] fn wrapped_as_base(wrap: &Self::Wrap) -> &Base { - Mods::wrapped_as_base(wrap.data.as_ref().expect(MISSING_DATA)) + Mods::wrapped_as_base(wrap.maybe_data.data().expect(MISSING_DATA)) } #[inline] fn wrapped_as_base_mut(wrap: &mut Self::Wrap) -> &mut Base { - Mods::wrapped_as_base_mut(wrap.data.as_mut().expect(MISSING_DATA)) + Mods::wrapped_as_base_mut(wrap.maybe_data.data_mut().expect(MISSING_DATA)) } } impl HasId for LazyWrapper { #[inline] fn id(&self) -> crate::Id { - match self.id { - Some(id) => id, - None => self.data.as_ref().unwrap().id(), + match self.maybe_data { + MaybeData::Data(ref data) => data.id(), + MaybeData::Id(id) => id, + MaybeData::None => unimplemented!() } } } @@ -48,23 +50,23 @@ impl HasId for LazyWrapper { impl PtrType for LazyWrapper { #[inline] fn size(&self) -> usize { - match self.id { - Some(id) => id.len, - None => self.data.as_ref().unwrap().size(), + match self.maybe_data { + MaybeData::Data(ref data) => data.size(), + MaybeData::Id(id) => id.len, + MaybeData::None => unimplemented!() } } #[inline] fn flag(&self) -> AllocFlag { - self.data - .as_ref() + self.maybe_data.data() .map(|data| data.flag()) .unwrap_or(AllocFlag::Lazy) } #[inline] unsafe fn set_flag(&mut self, flag: AllocFlag) { - self.data.as_mut().unwrap().set_flag(flag) + self.maybe_data.data_mut().unwrap().set_flag(flag) } } @@ -76,26 +78,26 @@ impl, T> Deref for LazyWrapper { #[inline] fn deref(&self) -> &Self::Target { - self.data.as_ref().expect(MISSING_DATA) + self.maybe_data.data().expect(MISSING_DATA) } } impl, T> DerefMut for LazyWrapper { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { - self.data.as_mut().expect(MISSING_DATA) + self.maybe_data.data_mut().expect(MISSING_DATA) } } impl> HostPtr for LazyWrapper { #[inline] fn ptr(&self) -> *const T { - self.data.as_ref().unwrap().ptr() + self.maybe_data.data().unwrap().ptr() } #[inline] fn ptr_mut(&mut self) -> *mut T { - self.data.as_mut().unwrap().ptr_mut() + self.maybe_data.data_mut().unwrap().ptr_mut() } } @@ -103,8 +105,11 @@ impl ShallowCopy for LazyWrapper { #[inline] unsafe fn shallow(&self) -> Self { LazyWrapper { - id: self.id, - data: self.data.as_ref().map(|data| data.shallow()), + maybe_data: match &self.maybe_data { + MaybeData::Data(data) => MaybeData::Data(data.shallow()), + MaybeData::Id(id) => MaybeData::Id(*id), + MaybeData::None => unimplemented!(), + }, _pd: PhantomData, } } diff --git a/src/modules/lazy/wrapper/maybe_data.rs b/src/modules/lazy/wrapper/maybe_data.rs new file mode 100644 index 00000000..dfeab997 --- /dev/null +++ b/src/modules/lazy/wrapper/maybe_data.rs @@ -0,0 +1,47 @@ +use crate::Id; + +#[derive(Debug, Default)] +pub enum MaybeData { + Data(Data), + Id(Id), + #[default] + None, +} + +impl MaybeData { + #[inline] + pub fn data(&self) -> Option<&Data> { + match self { + MaybeData::Data(data) => Some(data), + MaybeData::Id(_id) => None, + MaybeData::None => None, + } + } + + #[inline] + pub fn data_mut(&mut self) -> Option<&mut Data> { + match self { + MaybeData::Data(data) => Some(data), + MaybeData::Id(_id) => None, + MaybeData::None => None + } + } + + #[inline] + pub fn id(&self) -> Option<&Id> { + match self { + MaybeData::Data(_data) => None, + MaybeData::Id(id) => Some(id), + MaybeData::None => None, + } + } + + #[inline] + pub fn id_mut(&mut self) -> Option<&mut Id> { + match self { + MaybeData::Data(_data) => None, + MaybeData::Id(id) => Some(id), + MaybeData::None => None, + } + } +}