Skip to content

Commit

Permalink
Add Lazy alloc flag, set to lazy alloc flag when lazy wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Dec 10, 2023
1 parent 126f240 commit 396ddc9
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 28 deletions.
10 changes: 6 additions & 4 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<T> From<T> for Num<T> {
impl Device for () {
type Data<T, S: crate::Shape> = Self::Base<T, S>;
type Base<T, S> = Num<T>;

type Error = Infallible;

fn new() -> Result<Self, Infallible> {
Expand All @@ -71,12 +71,14 @@ impl Device for () {
fn base_to_data<T, S: crate::Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
base
}

#[inline(always)]
fn wrap_to_data<T, S: crate::Shape>(&self, wrap: Self::Wrap<T, Self::Base<T, S>>) -> Self::Data<T, S> {
fn wrap_to_data<T, S: crate::Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap
}

}

impl<T: Default> Alloc<T> for () {
Expand Down
5 changes: 1 addition & 4 deletions src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use core::{
convert::Infallible,
ops::DerefMut,
};
use core::{convert::Infallible, ops::DerefMut};

use crate::{
cpu::CPUPtr, flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, pass_down_grad_fn,
Expand Down
1 change: 1 addition & 0 deletions src/flag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum AllocFlag {
Num,
/// Similiar to `None`, but the resulting [`Buffer`](crate::Buffer) is borrowed and not owned.
BorrowedCache,
Lazy,
}

impl PartialEq for AllocFlag {
Expand Down
3 changes: 1 addition & 2 deletions src/modules/base.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
flag::AllocFlag, AddGradFn, AddOperation, Alloc, Device, ExecNow, HasId, HashLocation, Module,
OnDropBuffer, OnNewBuffer, Parents, PtrType, Retrieve, Setup, Shape,
WrappedData,
OnDropBuffer, OnNewBuffer, Parents, PtrType, Retrieve, Setup, Shape, WrappedData,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
Expand Down
4 changes: 2 additions & 2 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use core::{cell::RefCell, marker::PhantomData};

use crate::{
AddGradFn, AddOperation, Alloc, Buffer, Cache, Device, DeviceError, ExecNow, HasId, Module,
OnDropBuffer, OnNewBuffer, OptimizeMemGraph, Parents, PtrType, Retrieve, RunModule,
Setup, ShallowCopy, Shape, WrappedData,
OnDropBuffer, OnNewBuffer, OptimizeMemGraph, Parents, PtrType, Retrieve, RunModule, Setup,
ShallowCopy, Shape, WrappedData,
};

// creator struct
Expand Down
31 changes: 19 additions & 12 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ mod wrapper;
pub use ty::*;

use crate::{
pass_down_tape_actions, AddOperation, Alloc, Buffer, Device, ExecNow, HasId, Module, NoHasher,
OnDropBuffer, OnNewBuffer, Parents, PtrConv, Retrieve, RunModule, Setup, ShallowCopy, Shape,
UniqueId, UpdateArgs, WrappedData,
pass_down_tape_actions, AddOperation, Alloc, Buffer, Device, ExecNow, HasId, Id, Module,
NoHasher, OnDropBuffer, OnNewBuffer, Parents, PtrConv, Retrieve, RunModule, Setup, ShallowCopy,
Shape, UniqueId, UpdateArgs, WrappedData,
};
use core::{
any::Any,
cell::{Cell, RefCell},
fmt::Debug,
hash::BuildHasherDefault,
};
use core::{any::Any, cell::RefCell, fmt::Debug, hash::BuildHasherDefault};
use std::collections::HashMap;

pub use self::lazy_graph::LazyGraph;
Expand All @@ -18,6 +23,7 @@ use super::register_buf;
#[derive(Default)]
pub struct Lazy<Mods> {
pub modules: Mods,
pub id_count: Cell<u64>,
buffers: RefCell<HashMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>>,
graph: RefCell<LazyGraph>,
}
Expand Down Expand Up @@ -51,6 +57,7 @@ impl<Mods: Module<D>, D: LazySetup + Device> Module<D> for Lazy<Mods> {
modules: Mods::new(),
buffers: Default::default(),
graph: Default::default(),
id_count: Default::default(),
}
}
}
Expand Down Expand Up @@ -151,20 +158,21 @@ where
#[inline]
fn retrieve<const NUM_PARENTS: usize>(
&self,
device: &D,
_device: &D,
len: usize,
parents: impl Parents<NUM_PARENTS>,
// ) -> D::Data<T, S>
_parents: impl Parents<NUM_PARENTS>,
) -> Self::Wrap<T, D::Base<T, S>>
where
S: Shape,
D: Alloc<T>,
{
// self.modules.retrieve(device, len, parents)
LazyWrapper {
data: Some(self.modules.retrieve(device, len, parents)),
// id: Some(),
id: None,
data: None,
id: Some(Id {
id: self.id_count.get(),
len,
}),
_pd: core::marker::PhantomData,
}
}
Expand All @@ -174,7 +182,7 @@ where
where
D: Alloc<T>,
{
unsafe { register_buf(&mut self.buffers.borrow_mut(), retrieved_buf) };
// unsafe { register_buf(&mut self.buffers.borrow_mut(), retrieved_buf) };

// pass down
self.modules.on_retrieve_finish(retrieved_buf)
Expand All @@ -201,7 +209,6 @@ mod tests {

let x: Buffer<i32, _> = device.retrieve(10, ());
let res = &x.data;

}

#[test]
Expand Down
11 changes: 7 additions & 4 deletions src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::{
ops::{Deref, DerefMut},
};

use crate::{HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, Shape, WrappedData};
use crate::{flag::AllocFlag, HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, Shape, WrappedData};

#[derive(Debug, Default)]
pub struct LazyWrapper<Data, T> {
Expand Down Expand Up @@ -46,12 +46,15 @@ impl<Data: PtrType, T> PtrType for LazyWrapper<Data, T> {
}

#[inline]
fn flag(&self) -> crate::flag::AllocFlag {
self.data.as_ref().unwrap().flag()
fn flag(&self) -> AllocFlag {
self.data
.as_ref()
.map(|data| data.flag())
.unwrap_or(AllocFlag::Lazy)
}

#[inline]
unsafe fn set_flag(&mut self, flag: crate::flag::AllocFlag) {
unsafe fn set_flag(&mut self, flag: AllocFlag) {
self.data.as_mut().unwrap().set_flag(flag)
}
}
Expand Down

0 comments on commit 396ddc9

Please sign in to comment.