From 396ddc91cb7c80660a5ad775809fcf99c8e20cab Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 10 Dec 2023 22:33:37 +0100 Subject: [PATCH] Add Lazy alloc flag, set to lazy alloc flag when lazy wrapper --- src/buffer/num.rs | 10 ++++++---- src/devices/cpu/cpu_device.rs | 5 +---- src/flag.rs | 1 + src/modules/base.rs | 3 +-- src/modules/cached.rs | 4 ++-- src/modules/lazy.rs | 31 +++++++++++++++++++------------ src/modules/lazy/wrapper.rs | 11 +++++++---- 7 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/buffer/num.rs b/src/buffer/num.rs index ae775882..abcf2920 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -60,7 +60,7 @@ impl From for Num { impl Device for () { type Data = Self::Base; type Base = Num; - + type Error = Infallible; fn new() -> Result { @@ -71,12 +71,14 @@ impl Device for () { fn base_to_data(&self, base: Self::Base) -> Self::Data { base } - + #[inline(always)] - fn wrap_to_data(&self, wrap: Self::Wrap>) -> Self::Data { + fn wrap_to_data( + &self, + wrap: Self::Wrap>, + ) -> Self::Data { wrap } - } impl Alloc for () { diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index 5b728303..367f6a1a 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -1,7 +1,4 @@ -use core::{ - convert::Infallible, - ops::DerefMut, -}; +use core::{convert::Infallible, ops::DerefMut}; use crate::{ cpu::CPUPtr, flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, pass_down_grad_fn, diff --git a/src/flag.rs b/src/flag.rs index f9aa4d1b..e790843f 100644 --- a/src/flag.rs +++ b/src/flag.rs @@ -13,6 +13,7 @@ pub enum AllocFlag { Num, /// Similiar to `None`, but the resulting [`Buffer`](crate::Buffer) is borrowed and not owned. BorrowedCache, + Lazy, } impl PartialEq for AllocFlag { diff --git a/src/modules/base.rs b/src/modules/base.rs index 8657a8ff..c6c5a1e5 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -1,7 +1,6 @@ use crate::{ flag::AllocFlag, AddGradFn, AddOperation, Alloc, Device, ExecNow, HasId, HashLocation, Module, - OnDropBuffer, OnNewBuffer, Parents, PtrType, Retrieve, Setup, Shape, - WrappedData, + OnDropBuffer, OnNewBuffer, Parents, PtrType, Retrieve, Setup, Shape, WrappedData, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 9dd22594..821c284c 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -2,8 +2,8 @@ use core::{cell::RefCell, marker::PhantomData}; use crate::{ AddGradFn, AddOperation, Alloc, Buffer, Cache, Device, DeviceError, ExecNow, HasId, Module, - OnDropBuffer, OnNewBuffer, OptimizeMemGraph, Parents, PtrType, Retrieve, RunModule, - Setup, ShallowCopy, Shape, WrappedData, + OnDropBuffer, OnNewBuffer, OptimizeMemGraph, Parents, PtrType, Retrieve, RunModule, Setup, + ShallowCopy, Shape, WrappedData, }; // creator struct diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index cb2b747b..f752623d 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -4,11 +4,16 @@ mod wrapper; pub use ty::*; use crate::{ - pass_down_tape_actions, AddOperation, Alloc, Buffer, Device, ExecNow, HasId, Module, NoHasher, - OnDropBuffer, OnNewBuffer, Parents, PtrConv, Retrieve, RunModule, Setup, ShallowCopy, Shape, - UniqueId, UpdateArgs, WrappedData, + pass_down_tape_actions, AddOperation, Alloc, Buffer, Device, ExecNow, HasId, Id, Module, + NoHasher, OnDropBuffer, OnNewBuffer, Parents, PtrConv, Retrieve, RunModule, Setup, ShallowCopy, + Shape, UniqueId, UpdateArgs, WrappedData, +}; +use core::{ + any::Any, + cell::{Cell, RefCell}, + fmt::Debug, + hash::BuildHasherDefault, }; -use core::{any::Any, cell::RefCell, fmt::Debug, hash::BuildHasherDefault}; use std::collections::HashMap; pub use self::lazy_graph::LazyGraph; @@ -18,6 +23,7 @@ use super::register_buf; #[derive(Default)] pub struct Lazy { pub modules: Mods, + pub id_count: Cell, buffers: RefCell, BuildHasherDefault>>, graph: RefCell, } @@ -51,6 +57,7 @@ impl, D: LazySetup + Device> Module for Lazy { modules: Mods::new(), buffers: Default::default(), graph: Default::default(), + id_count: Default::default(), } } } @@ -151,10 +158,9 @@ where #[inline] fn retrieve( &self, - device: &D, + _device: &D, len: usize, - parents: impl Parents, - // ) -> D::Data + _parents: impl Parents, ) -> Self::Wrap> where S: Shape, @@ -162,9 +168,11 @@ where { // self.modules.retrieve(device, len, parents) LazyWrapper { - data: Some(self.modules.retrieve(device, len, parents)), - // id: Some(), - id: None, + data: None, + id: Some(Id { + id: self.id_count.get(), + len, + }), _pd: core::marker::PhantomData, } } @@ -174,7 +182,7 @@ where where D: Alloc, { - unsafe { register_buf(&mut self.buffers.borrow_mut(), retrieved_buf) }; + // unsafe { register_buf(&mut self.buffers.borrow_mut(), retrieved_buf) }; // pass down self.modules.on_retrieve_finish(retrieved_buf) @@ -201,7 +209,6 @@ mod tests { let x: Buffer = device.retrieve(10, ()); let res = &x.data; - } #[test] diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 9257d262..8a9a0ae6 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -3,7 +3,7 @@ use core::{ ops::{Deref, DerefMut}, }; -use crate::{HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, Shape, WrappedData}; +use crate::{flag::AllocFlag, HasId, HostPtr, Id, Lazy, PtrType, ShallowCopy, Shape, WrappedData}; #[derive(Debug, Default)] pub struct LazyWrapper { @@ -46,12 +46,15 @@ impl PtrType for LazyWrapper { } #[inline] - fn flag(&self) -> crate::flag::AllocFlag { - self.data.as_ref().unwrap().flag() + fn flag(&self) -> AllocFlag { + self.data + .as_ref() + .map(|data| data.flag()) + .unwrap_or(AllocFlag::Lazy) } #[inline] - unsafe fn set_flag(&mut self, flag: crate::flag::AllocFlag) { + unsafe fn set_flag(&mut self, flag: AllocFlag) { self.data.as_mut().unwrap().set_flag(flag) } }