Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MaybeData #65

Merged
merged 2 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde",]
default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"]

# default = ["cpu"]
# default = ["no-std"]
# default = ["opencl"]
Expand Down
10 changes: 5 additions & 5 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod ty;
mod wrapper;

pub use ty::*;
use wrapper::MaybeData;

use crate::{
op_hint::OpHint, register_buf_copyable, unregister_buf_copyable, AddLayer, AddOperation, Alloc,
Expand Down Expand Up @@ -362,8 +363,7 @@ where
unsafe { self.bump_cursor() };

Ok(LazyWrapper {
data: None,
id: Some(id),
maybe_data: MaybeData::Id(id),
_pd: core::marker::PhantomData,
})
}
Expand Down Expand Up @@ -508,15 +508,15 @@ mod tests {
let device = CPU::<Lazy<Base, i32>>::new();
let buf = Buffer::<i32, _>::new(&device, 10);
let res = &buf.data;
assert_eq!(res.id, None);
assert_eq!(res.maybe_data.id(), None);

let x: Buffer<i32, _> = device.retrieve(10, ()).unwrap();
let res = &x.data;
assert_eq!(res.id, Some(crate::Id { id: 0, len: 10 }));
assert_eq!(res.maybe_data.id(), Some(&crate::Id { id: 0, len: 10 }));

let x: Buffer<i32, _> = device.retrieve(10, ()).unwrap();
let res = &x.data;
assert_eq!(res.id, Some(crate::Id { id: 1, len: 10 }));
assert_eq!(res.maybe_data.id(), Some(&crate::Id { id: 1, len: 10 }));
}

#[test]
Expand Down
49 changes: 27 additions & 22 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
mod maybe_data;
pub use maybe_data::MaybeData;

use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};

use crate::{flag::AllocFlag, HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, WrappedData};
use crate::{flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedData};

#[derive(Debug, Default)]
pub struct LazyWrapper<Data, T> {
pub data: Option<Data>,
pub id: Option<Id>,
pub maybe_data: MaybeData<Data>,
pub _pd: PhantomData<T>,
}

Expand All @@ -18,53 +20,53 @@ impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
LazyWrapper {
data: Some(self.modules.wrap_in_base(base)),
id: None,
maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)),
_pd: PhantomData,
}
}

#[inline]
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap.data.as_ref().expect(MISSING_DATA))
Mods::wrapped_as_base(wrap.maybe_data.data().expect(MISSING_DATA))
}

#[inline]
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap.data.as_mut().expect(MISSING_DATA))
Mods::wrapped_as_base_mut(wrap.maybe_data.data_mut().expect(MISSING_DATA))
}
}

impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
#[inline]
fn id(&self) -> crate::Id {
match self.id {
Some(id) => id,
None => self.data.as_ref().unwrap().id(),
match self.maybe_data {
MaybeData::Data(ref data) => data.id(),
MaybeData::Id(id) => id,
MaybeData::None => unimplemented!()
}
}
}

impl<Data: PtrType, T> PtrType for LazyWrapper<Data, T> {
#[inline]
fn size(&self) -> usize {
match self.id {
Some(id) => id.len,
None => self.data.as_ref().unwrap().size(),
match self.maybe_data {
MaybeData::Data(ref data) => data.size(),
MaybeData::Id(id) => id.len,
MaybeData::None => unimplemented!()
}
}

#[inline]
fn flag(&self) -> AllocFlag {
self.data
.as_ref()
self.maybe_data.data()
.map(|data| data.flag())
.unwrap_or(AllocFlag::Lazy)
}

#[inline]
unsafe fn set_flag(&mut self, flag: AllocFlag) {
self.data.as_mut().unwrap().set_flag(flag)
self.maybe_data.data_mut().unwrap().set_flag(flag)
}
}

Expand All @@ -76,35 +78,38 @@ impl<Data: Deref<Target = [T]>, T> Deref for LazyWrapper<Data, T> {

#[inline]
fn deref(&self) -> &Self::Target {
self.data.as_ref().expect(MISSING_DATA)
self.maybe_data.data().expect(MISSING_DATA)
}
}

impl<Data: DerefMut<Target = [T]>, T> DerefMut for LazyWrapper<Data, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
self.data.as_mut().expect(MISSING_DATA)
self.maybe_data.data_mut().expect(MISSING_DATA)
}
}

impl<T, Data: HostPtr<T>> HostPtr<T> for LazyWrapper<Data, T> {
#[inline]
fn ptr(&self) -> *const T {
self.data.as_ref().unwrap().ptr()
self.maybe_data.data().unwrap().ptr()
}

#[inline]
fn ptr_mut(&mut self) -> *mut T {
self.data.as_mut().unwrap().ptr_mut()
self.maybe_data.data_mut().unwrap().ptr_mut()
}
}

impl<Data: ShallowCopy, T> ShallowCopy for LazyWrapper<Data, T> {
#[inline]
unsafe fn shallow(&self) -> Self {
LazyWrapper {
id: self.id,
data: self.data.as_ref().map(|data| data.shallow()),
maybe_data: match &self.maybe_data {
MaybeData::Data(data) => MaybeData::Data(data.shallow()),
MaybeData::Id(id) => MaybeData::Id(*id),
MaybeData::None => unimplemented!(),
},
_pd: PhantomData,
}
}
Expand Down
47 changes: 47 additions & 0 deletions src/modules/lazy/wrapper/maybe_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use crate::Id;

#[derive(Debug, Default)]
pub enum MaybeData<Data> {
Data(Data),
Id(Id),
#[default]
None,
}

impl<Data> MaybeData<Data> {
#[inline]
pub fn data(&self) -> Option<&Data> {
match self {
MaybeData::Data(data) => Some(data),
MaybeData::Id(_id) => None,
MaybeData::None => None,
}
}

#[inline]
pub fn data_mut(&mut self) -> Option<&mut Data> {
match self {
MaybeData::Data(data) => Some(data),
MaybeData::Id(_id) => None,
MaybeData::None => None
}
}

#[inline]
pub fn id(&self) -> Option<&Id> {
match self {
MaybeData::Data(_data) => None,
MaybeData::Id(id) => Some(id),
MaybeData::None => None,
}
}

#[inline]
pub fn id_mut(&mut self) -> Option<&mut Id> {
match self {
MaybeData::Data(_data) => None,
MaybeData::Id(id) => Some(id),
MaybeData::None => None,
}
}
}
Loading