Skip to content

Commit

Permalink
Fix stack dev
Browse files: browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 24, 2024
1 parent 92d3db4 commit 9085a83
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
[features]
# default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]

default = ["cpu", "cached", "autograd"]
default = ["cpu", "cached", "autograd", "static-api", "blas", "macro", "stack"]
# default = ["no-std"]
# default = ["opencl"]
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]
Expand Down
25 changes: 12 additions & 13 deletions src/devices/stack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ use core::ops::{AddAssign, Deref, DerefMut};
pub use stack_device::*;

use crate::{
cpu_stack_ops::clear_slice, ApplyFunction, Buffer, ClearBuf, Device, Eval, MayToCLSource,
OnDropBuffer, Resolve, Retrieve, Retriever, Shape, ToVal, UnaryGrad, Unit, ZeroGrad,
cpu_stack_ops::clear_slice, ApplyFunction, Buffer, ClearBuf, Device, Eval, MayToCLSource, Resolve, Retrieve, Retriever, Shape, ToVal, UnaryGrad, Unit, WrappedData, ZeroGrad
};

// #[impl_stack]
Expand All @@ -29,23 +28,23 @@ where
impl<Mods, T> ZeroGrad<T> for Stack<Mods>
where
T: Unit + Default,
Mods: OnDropBuffer,
Mods: WrappedData,
{
#[inline]
fn zero_grad<S: Shape>(&self, data: &mut Self::Base<T, S>) {
clear_slice(data)
}
}

impl<Mods, T, D, S> ApplyFunction<T, S, D> for Stack<Mods>
impl<'a, Mods, T, D, S> ApplyFunction<'a, T, S, D> for Stack<Mods>
where
Mods: Retrieve<Self, T, S>,
Mods: Retrieve<'a, Self, T, S>,
T: Unit + Copy + Default + ToVal + 'static,
D: Device,
D::Base<T, S>: Deref<Target = [T]>,
S: Shape,
{
fn apply_fn<F>(&self, buf: &Buffer<T, D, S>, f: impl Fn(Resolve<T>) -> F) -> Buffer<T, Self, S>
fn apply_fn<F>(&'a self, buf: &Buffer<T, D, S>, f: impl Fn(Resolve<T>) -> F) -> Buffer<'a, T, Self, S>
where
F: Eval<T> + MayToCLSource,
{
Expand All @@ -59,7 +58,7 @@ where

impl<Mods, T, D, S> UnaryGrad<T, S, D> for Stack<Mods>
where
Mods: OnDropBuffer,
Mods: WrappedData,
T: Unit + AddAssign + Copy + core::ops::Mul<Output = T>,
S: Shape,
D: Device,
Expand Down Expand Up @@ -90,8 +89,8 @@ mod tests {

use super::stack_device::Stack;

pub trait AddBuf<T: Unit, D: Device = Self, S: Shape = ()>: Device {
fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Self, S>;
pub trait AddBuf<'a, T: Unit, D: Device = Self, S: Shape = ()>: Device {
fn add(&'a self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<'a, T, Self, S>;
}

/*// Without stack support
Expand All @@ -111,13 +110,13 @@ mod tests {
}
}*/

impl<Mods: Retrieve<Self, T>, T, D> AddBuf<T, D> for CPU<Mods>
impl<'a, Mods: Retrieve<'a, Self, T>, T, D> AddBuf<'a, T, D> for CPU<Mods>
where
D: Device,
D::Base<T, ()>: Deref<Target = [T]>,
T: Unit + Add<Output = T> + Copy,
{
fn add(&self, lhs: &Buffer<T, D>, rhs: &Buffer<T, D>) -> Buffer<T, Self> {
fn add(&'a self, lhs: &Buffer<T, D>, rhs: &Buffer<T, D>) -> Buffer<'a, T, Self> {
let len = core::cmp::min(lhs.len(), rhs.len());

let mut out = self.retrieve(len, (lhs, rhs)).unwrap();
Expand All @@ -126,14 +125,14 @@ mod tests {
}
}

impl<T, D, S: Shape> AddBuf<T, D, S> for Stack
impl<'a, T, D, S: Shape> AddBuf<'a, T, D, S> for Stack
where
Stack: Alloc<T>,
D: Device,
D::Base<T, S>: Deref<Target = [T]>,
T: Unit + Add<Output = T> + Copy + Default,
{
fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Self, S> {
fn add(&'a self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<'a, T, Self, S> {
let mut out = self.retrieve(S::LEN, (lhs, rhs)).unwrap();

for i in 0..S::LEN {
Expand Down
48 changes: 29 additions & 19 deletions src/devices/stack/stack_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, impl_wrapped_data,
pass_down_add_operation, pass_down_cursor, pass_down_grad_fn, pass_down_tape_actions,
pass_down_use_gpu_or_cpu, shape::Shape, Alloc, Base, Buffer, CloneBuf, Device, DeviceError,
DevicelessAble, OnDropBuffer, Read, StackArray, Unit, WrappedData, WriteBuf,
DevicelessAble, Read, StackArray, Unit, WrappedData, WriteBuf,
};

/// A device that allocates memory on the stack.
Expand Down Expand Up @@ -32,8 +32,8 @@ pass_down_add_operation!(Stack);

impl<'a, T: Unit + Copy + Default, S: Shape> DevicelessAble<'a, T, S> for Stack {}

impl<Mods: OnDropBuffer> Device for Stack<Mods> {
type Data<U: Unit, S: Shape> = Self::Wrap<U, Self::Base<U, S>>;
impl<Mods: WrappedData> Device for Stack<Mods> {
type Data<'a, U: Unit, S: Shape> = Self::Wrap<'a, U, Self::Base<U, S>>;
type Base<T: Unit, S: Shape> = StackArray<S, T>;
type Error = Infallible;

Expand All @@ -42,34 +42,44 @@ impl<Mods: OnDropBuffer> Device for Stack<Mods> {
}

#[inline]
fn base_to_data<T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
self.wrap_in_base(base)
}

#[inline]
fn wrap_to_data<T: Unit, S: Shape>(
fn wrap_to_data<'a, T: Unit, S: Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S> {
wrap
}

#[inline]
fn data_as_wrap<T: Unit, S: Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

#[inline]
fn data_as_wrap_mut<T: Unit, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

#[inline]
fn default_base_to_data<'a, T: Unit, S: Shape>(
&'a self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
self.wrap_in_base(base)
}

fn default_base_to_data_unbound<'a, T: Unit, S: Shape>(
&self,
base: Self::Base<T, S>,
) -> Self::Data<'a, T, S> {
self.wrap_in_base_unbound(base)
}
}

impl<Mods: OnDropBuffer, T: Unit + Copy + Default> Alloc<T> for Stack<Mods> {
impl<Mods: WrappedData, T: Unit + Copy + Default> Alloc<T> for Stack<Mods> {
#[inline]
fn alloc<S: Shape>(&self, _len: usize, _flag: AllocFlag) -> crate::Result<StackArray<S, T>> {
Ok(StackArray::new())
Expand Down Expand Up @@ -129,7 +139,7 @@ where

impl<'a, T: Unit, S: Shape> CloneBuf<'a, T, S> for Stack
where
<Stack as Device>::Data<T, S>: Copy,
<Stack as Device>::Data<'a, T, S>: Copy,
{
#[inline]
fn clone_buf(&'a self, buf: &Buffer<'a, T, Self, S>) -> Buffer<'a, T, Self, S> {
Expand Down

0 comments on commit 9085a83

Please sign in to comment.