From 9085a831c2bb1e916cdffb4636488b143ae93c3c Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 24 Nov 2024 12:43:32 +0100 Subject: [PATCH] Fix stack dev --- Cargo.toml | 2 +- src/devices/stack/mod.rs | 25 ++++++++-------- src/devices/stack/stack_device.rs | 48 +++++++++++++++++++------------ 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1a947d98..086cfbd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"] -default = ["cpu", "cached", "autograd"] +default = ["cpu", "cached", "autograd", "static-api", "blas", "macro", "stack"] # default = ["no-std"] # default = ["opencl"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"] diff --git a/src/devices/stack/mod.rs b/src/devices/stack/mod.rs index ed8fb85d..b5b0d0e8 100644 --- a/src/devices/stack/mod.rs +++ b/src/devices/stack/mod.rs @@ -8,8 +8,7 @@ use core::ops::{AddAssign, Deref, DerefMut}; pub use stack_device::*; use crate::{ - cpu_stack_ops::clear_slice, ApplyFunction, Buffer, ClearBuf, Device, Eval, MayToCLSource, - OnDropBuffer, Resolve, Retrieve, Retriever, Shape, ToVal, UnaryGrad, Unit, ZeroGrad, + cpu_stack_ops::clear_slice, ApplyFunction, Buffer, ClearBuf, Device, Eval, MayToCLSource, Resolve, Retrieve, Retriever, Shape, ToVal, UnaryGrad, Unit, WrappedData, ZeroGrad }; // #[impl_stack] @@ -29,7 +28,7 @@ where impl ZeroGrad for Stack where T: Unit + Default, - Mods: OnDropBuffer, + Mods: WrappedData, { #[inline] fn zero_grad(&self, data: &mut Self::Base) { @@ -37,15 +36,15 @@ where } } -impl ApplyFunction for Stack +impl<'a, Mods, T, D, S> ApplyFunction<'a, T, S, D> for Stack where - Mods: Retrieve, + Mods: Retrieve<'a, Self, T, S>, T: Unit + Copy + Default + ToVal + 'static, D: Device, D::Base: Deref, S: Shape, { - fn apply_fn(&self, buf: &Buffer, f: impl Fn(Resolve) -> F) -> Buffer + fn apply_fn(&'a self, buf: &Buffer, f: impl Fn(Resolve) -> F) -> Buffer<'a, T, Self, S> where F: Eval + MayToCLSource, { @@ -59,7 +58,7 @@ where impl UnaryGrad for Stack where - Mods: OnDropBuffer, + Mods: WrappedData, T: Unit + AddAssign + Copy + core::ops::Mul, S: Shape, D: Device, @@ -90,8 +89,8 @@ mod tests { use super::stack_device::Stack; - pub trait AddBuf: Device { - fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer; + pub trait AddBuf<'a, T: Unit, D: Device = Self, S: Shape = ()>: Device { + fn add(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S>; } /*// Without stack support @@ -111,13 +110,13 @@ mod tests { } }*/ - impl, T, D> AddBuf for CPU + impl<'a, Mods: Retrieve<'a, Self, T>, T, D> AddBuf<'a, T, D> for CPU where D: Device, D::Base: Deref, T: Unit + Add + Copy, { - fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { + fn add(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self> { let len = core::cmp::min(lhs.len(), rhs.len()); let mut out = self.retrieve(len, (lhs, rhs)).unwrap(); @@ -126,14 +125,14 @@ mod tests { } } - impl AddBuf for Stack + impl<'a, T, D, S: Shape> AddBuf<'a, T, D, S> for Stack where Stack: Alloc, D: Device, D::Base: Deref, T: Unit + Add + Copy + Default, { - fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { + fn add(&'a self, lhs: &Buffer, rhs: &Buffer) -> Buffer<'a, T, Self, S> { let mut out = self.retrieve(S::LEN, (lhs, rhs)).unwrap(); for i in 0..S::LEN { diff --git a/src/devices/stack/stack_device.rs b/src/devices/stack/stack_device.rs index 7a1c10cc..33cfb7f9 100644 --- a/src/devices/stack/stack_device.rs +++ b/src/devices/stack/stack_device.rs @@ -4,7 +4,7 @@ use crate::{ flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, impl_wrapped_data, pass_down_add_operation, pass_down_cursor, pass_down_grad_fn, pass_down_tape_actions, pass_down_use_gpu_or_cpu, shape::Shape, Alloc, Base, Buffer, CloneBuf, Device, DeviceError, - DevicelessAble, OnDropBuffer, Read, StackArray, Unit, WrappedData, WriteBuf, + DevicelessAble, Read, StackArray, Unit, WrappedData, WriteBuf, }; /// A device that allocates memory on the stack. @@ -32,8 +32,8 @@ pass_down_add_operation!(Stack); impl<'a, T: Unit + Copy + Default, S: Shape> DevicelessAble<'a, T, S> for Stack {} -impl Device for Stack { - type Data = Self::Wrap>; +impl Device for Stack { + type Data<'a, U: Unit, S: Shape> = Self::Wrap<'a, U, Self::Base>; type Base = StackArray; type Error = Infallible; @@ -42,34 +42,44 @@ impl Device for Stack { } #[inline] - fn base_to_data(&self, base: Self::Base) -> Self::Data { - self.wrap_in_base(base) - } - - #[inline] - fn wrap_to_data( + fn wrap_to_data<'a, T: Unit, S: Shape>( &self, - wrap: Self::Wrap>, - ) -> Self::Data { + wrap: Self::Wrap<'a, T, Self::Base>, + ) -> Self::Data<'a, T, S> { wrap } #[inline] - fn data_as_wrap( - data: &Self::Data, - ) -> &Self::Wrap> { + fn data_as_wrap<'a, 'b, T: Unit, S: Shape>( + data: &'b Self::Data<'a, T, S>, + ) -> &'b Self::Wrap<'a, T, Self::Base> { data } #[inline] - fn data_as_wrap_mut( - data: &mut Self::Data, - ) -> &mut Self::Wrap> { + fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>( + data: &'b mut Self::Data<'a, T, S>, + ) -> &'b mut Self::Wrap<'a, T, Self::Base> { data } + + #[inline] + fn default_base_to_data<'a, T: Unit, S: Shape>( + &'a self, + base: Self::Base, + ) -> Self::Data<'a, T, S> { + self.wrap_in_base(base) + } + + fn default_base_to_data_unbound<'a, T: Unit, S: Shape>( + &self, + base: Self::Base, + ) -> Self::Data<'a, T, S> { + self.wrap_in_base_unbound(base) + } } -impl Alloc for Stack { +impl Alloc for Stack { #[inline] fn alloc(&self, _len: usize, _flag: AllocFlag) -> crate::Result> { Ok(StackArray::new()) @@ -129,7 +139,7 @@ where impl<'a, T: Unit, S: Shape> CloneBuf<'a, T, S> for Stack where - ::Data: Copy, + ::Data<'a, T, S>: Copy, { #[inline] fn clone_buf(&'a self, buf: &Buffer<'a, T, Self, S>) -> Buffer<'a, T, Self, S> {