Skip to content

Commit

Permalink
Fix more tests, add retriever impl via trait bound variable impl_retr…
Browse files Browse the repository at this point in the history
…iever! macro
  • Loading branch information
elftausend committed Aug 15, 2023
1 parent e997f9b commit 2afadac
Show file tree
Hide file tree
Showing 21 changed files with 61 additions and 203 deletions.
1 change: 0 additions & 1 deletion src/device_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,5 @@ pub trait Retriever<T>: Device {
parents: impl Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S>
where
T: 'static,
S: Shape;
}
12 changes: 8 additions & 4 deletions src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub use cdatatype::*;
#[cfg(all(any(feature = "cpu", feature = "stack"), feature = "macro"))]
mod cpu_stack_ops;

use crate::{Alloc, Buffer, HasId, OnDropBuffer, PtrType, Shape};
use crate::{Buffer, HasId, OnDropBuffer, PtrType, Shape};

pub trait Device: OnDropBuffer + Sized {
type Data<T, S: Shape>: HasId + PtrType;
Expand Down Expand Up @@ -89,8 +89,8 @@ macro_rules! impl_buffer_hook_traits {

#[macro_export]
macro_rules! impl_retriever {
($device:ident) => {
impl<T: 'static, Mods: crate::Retrieve<Self, T>> crate::Retriever<T> for $device<Mods> {
($device:ident, $($trait_bounds:tt)*) => {
impl<T: $( $trait_bounds )*, Mods: crate::Retrieve<Self, T>> crate::Retriever<T> for $device<Mods> {
#[inline]
fn retrieve<S: Shape, const NUM_PARENTS: usize>(
&self,
Expand All @@ -99,7 +99,7 @@ macro_rules! impl_retriever {
) -> Buffer<T, Self, S> {
let data = self
.modules
.retrieve::<T, S, NUM_PARENTS>(self, len, parents);
.retrieve::<S, NUM_PARENTS>(self, len, parents);
let buf = Buffer {
data,
device: Some(self),
Expand All @@ -109,4 +109,8 @@ macro_rules! impl_retriever {
}
}
};

($device:ident) => {
impl_retriever!($device, Sized);
}
}
6 changes: 5 additions & 1 deletion src/devices/cpu/cpu_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use core::convert::Infallible;
use crate::{
cpu::CPUPtr, flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, Alloc, Base, Buffer,
Cached, CachedModule, CloneBuf, Device, HasModules, LazySetup, MainMemory, Module,
OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, Shape, TapeActions,
OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, Shape, TapeActions, DevicelessAble,
};

pub trait IsCPU {}
Expand Down Expand Up @@ -49,6 +49,10 @@ impl<Mods: OnDropBuffer> Device for CPU<Mods> {
}
}

impl<T, S: Shape> DevicelessAble<'_, T, S> for CPU<Base> {

}

impl<Mods: OnDropBuffer> MainMemory for CPU<Mods> {
#[inline]
fn as_ptr<T, S: Shape>(ptr: &Self::Data<T, S>) -> *const T {
Expand Down
4 changes: 2 additions & 2 deletions src/devices/stack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ mod tests {
impl<T, D> AddBuf<T, D> for CPU
where
D: MainMemory,
T: Add<Output = T> + Clone + 'static,
T: Add<Output = T> + Clone,
{
fn add(&self, lhs: &Buffer<T, D>, rhs: &Buffer<T, D>) -> Buffer<T, Self> {
let len = core::cmp::min(lhs.len(), rhs.len());
Expand All @@ -65,7 +65,7 @@ mod tests {
where
Stack: Alloc<T>,
D: MainMemory,
T: Add<Output = T> + Clone + 'static,
T: Add<Output = T> + Copy + Default,
{
fn add(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Self, S> {
let mut out = self.retrieve(S::LEN, (lhs, rhs));
Expand Down
19 changes: 3 additions & 16 deletions src/devices/stack/stack_device.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use core::{convert::Infallible, marker::PhantomData};
use core::convert::Infallible;

use crate::{
flag::AllocFlag, shape::Shape, Alloc, Base, Buffer, CloneBuf, Device, DevicelessAble,
MainMemory, OnDropBuffer, Read, StackArray, WriteBuf, impl_retriever, impl_buffer_hook_traits, Retriever,
MainMemory, OnDropBuffer, Read, StackArray, WriteBuf, impl_retriever, impl_buffer_hook_traits,
};

/// A device that allocates memory on the stack.
Expand All @@ -20,20 +20,7 @@ impl Stack {
}

impl_buffer_hook_traits!(Stack);

impl<T, Mods: OnDropBuffer> Retriever<T> for Stack<Mods> {
fn retrieve<S, const NUM_PARENTS: usize>(
&self,
len: usize,
parents: impl crate::Parents<NUM_PARENTS>,
) -> Buffer<T, Self, S>
where
T: 'static,
S: Shape {
todo!()
}
}
// impl_retriever!(Stack);
impl_retriever!(Stack, Copy + Default);


impl<'a, T: Copy + Default, S: Shape> DevicelessAble<'a, T, S> for Stack {}
Expand Down
8 changes: 3 additions & 5 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,23 @@ pub trait Feature: OnDropBuffer {}
// how to fix this:
// add retrieved buffer to no grads pool at the end of the chain (at device level (Retriever trait))
// => "generator", "actor"
pub trait Retrieve<D, G>: OnDropBuffer {
pub trait Retrieve<D, T>: OnDropBuffer {
// "generator"
#[track_caller]
fn retrieve<T, S, const NUM_PARENTS: usize>(
fn retrieve<S, const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
parents: impl Parents<NUM_PARENTS>,
) -> D::Data<T, S>
where
T: 'static, // if 'static causes any problems -> put T to => Retrieve<D, T>?
S: Shape,
D: Device + Alloc<T>;

// "actor"
#[inline]
fn on_retrieve_finish<T, S: Shape>(&self, _retrieved_buf: &Buffer<T, D, S>)
fn on_retrieve_finish<S: Shape>(&self, _retrieved_buf: &Buffer<T, D, S>)
where
T: 'static,
D: Alloc<T>,
{
}
Expand Down
8 changes: 3 additions & 5 deletions src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,29 +119,27 @@ impl<Mods: Setup<NewDev>, NewDev> Setup<NewDev> for Autograd<Mods> {
}
}

impl<G, Mods: Retrieve<D, G>, D> Retrieve<D, G> for Autograd<Mods>
impl<T:'static, Mods: Retrieve<D, T>, D> Retrieve<D, T> for Autograd<Mods>
where
D: PtrConv + Device + 'static,
{
#[inline]
fn retrieve<T, S, const NUM_PARENTS: usize>(
fn retrieve<S, const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
parents: impl Parents<NUM_PARENTS>,
) -> <D>::Data<T, S>
where
D: Alloc<T>,
T: 'static,
S: crate::Shape,
{
self.modules.retrieve(device, len, parents)
}

#[inline]
fn on_retrieve_finish<T, S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
fn on_retrieve_finish<S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
where
T: 'static,
D: Alloc<T>,
{
self.register_no_grad_buf(retrieved_buf);
Expand Down
4 changes: 2 additions & 2 deletions src/modules/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ impl<T, D: Device, S: Shape> OnNewBuffer<T, D, S> for Base {}

impl OnDropBuffer for Base {}

impl<D, G> Retrieve<D, G> for Base {
impl<D, T> Retrieve<D, T> for Base {
#[inline]
fn retrieve<T, S, const NUM_PARENTS: usize>(
fn retrieve<S, const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
Expand Down
9 changes: 4 additions & 5 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ impl<Mods: OnDropBuffer, SD: Device> OnDropBuffer for CachedModule<Mods, SD> {
}

// TODO: a more general OnDropBuffer => "Module"
impl<G, Mods: Retrieve<D, G>, D: Device + PtrConv<SimpleDevice>, SimpleDevice: Device + PtrConv<D>>
Retrieve<D, G> for CachedModule<Mods, SimpleDevice>
impl<T, Mods: Retrieve<D, T>, D: Device + PtrConv<SimpleDevice>, SimpleDevice: Device + PtrConv<D>>
Retrieve<D, T> for CachedModule<Mods, SimpleDevice>
{
#[inline]
fn retrieve<T, S: Shape, const NUM_PARENTS: usize>(
fn retrieve<S: Shape, const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
Expand All @@ -80,9 +80,8 @@ impl<G, Mods: Retrieve<D, G>, D: Device + PtrConv<SimpleDevice>, SimpleDevice: D
}

#[inline]
fn on_retrieve_finish<T, S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
fn on_retrieve_finish<S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
where
T: 'static,
D: Alloc<T>,
{
self.modules.on_retrieve_finish(retrieved_buf)
Expand Down
8 changes: 3 additions & 5 deletions src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,24 @@ impl<Mods: TapeActions> TapeActions for Lazy<Mods> {
}
}

impl<G, Mods: Retrieve<D, G>, D: PtrConv + 'static> Retrieve<D, G> for Lazy<Mods> {
impl<T: 'static, Mods: Retrieve<D, T>, D: PtrConv + 'static> Retrieve<D, T> for Lazy<Mods> {
#[inline]
fn retrieve<T, S, const NUM_PARENTS: usize>(
fn retrieve<S, const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
parents: impl Parents<NUM_PARENTS>,
) -> <D>::Data<T, S>
where
T: 'static,
S: Shape,
D: Alloc<T>,
{
self.mods.retrieve(device, len, parents)
}

#[inline]
fn on_retrieve_finish<T, S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
fn on_retrieve_finish<S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
where
T: 'static,
D: Alloc<T>,
{
unsafe { register_buf(&mut self.outs.borrow_mut(), retrieved_buf) };
Expand Down
4 changes: 2 additions & 2 deletions src/two_way_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ mod tests {

let a = f(4f32.to_val(), 3f32.to_val());

// roughly_eq_slices(&[a.eval()], &[22.2]);
assert_eq!(a.eval(), 22.2);
roughly_eq_slices(&[a.eval()], &[22.2]);
// assert_eq!(a.eval(), 22.2);

let r = f("x".to_marker(), "y".to_marker()).to_cl_source();
assert_eq!("(((x + y) * 3.6) - y)", r);
Expand Down
6 changes: 3 additions & 3 deletions tests/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ fn test_cached_cpu() {
let mut prev_ptr = None;

for _ in 0..100 {
let mut buf: Buffer<f32, _> = device.retrieve::<(), 0>(10, ());
let buf: Buffer<f32, _> = device.retrieve::<(), 0>(10, ());

if prev_ptr.is_some() {
assert_eq!(prev_ptr, Some(buf.data));
assert_eq!(prev_ptr, Some(buf.data.ptr));
}

prev_ptr = Some(buf.data);
prev_ptr = Some(buf.data.ptr);

}
}
Expand Down
15 changes: 9 additions & 6 deletions tests/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ use std::ptr::null_mut;

#[cfg(feature = "cpu")]
#[cfg(not(feature = "realloc"))]
use custos::{range, Buffer, CPU};
use custos::{Buffer, CPU};

#[cfg(feature = "cpu")]
#[cfg(not(feature = "realloc"))]
#[track_caller]
fn cached_add<'a>(device: &'a CPU, a: &[f32], b: &[f32]) -> Buffer<'a, f32, CPU> {
let mut out = custos::Device::retrieve::<f32, ()>(device, 10, ());
use custos::Retriever;

let mut out = device.retrieve(10, ());

for i in 0..out.len() {
out[i] = a[i] + b[i];
}
Expand All @@ -19,21 +23,20 @@ fn cached_add<'a>(device: &'a CPU, a: &[f32], b: &[f32]) -> Buffer<'a, f32, CPU>
#[cfg(not(feature = "realloc"))]
#[test]
fn test_caching_cpu() {
use custos::Base;

let device = CPU::<Base>::new();

let a = Buffer::<f32, _>::new(&device, 100);
let b = Buffer::<f32, _>::new(&device, 100);

let mut old_ptr = null_mut();

for _ in range(100) {
for _ in 0..100 {
let mut out = cached_add(&device, &a, &b);
if out.host_ptr() != old_ptr && !old_ptr.is_null() {
panic!("Should be the same pointer!");
}
old_ptr = out.host_ptr_mut();
let len = device.addons.cache.borrow().nodes.len();
//let len = CPU_CACHE.with(|cache| cache.borrow().nodes.len());
assert_eq!(len, 3);
}
}
37 changes: 2 additions & 35 deletions tests/caller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
ops::Add,
};

use custos::{range, Buffer, CacheReturn, Device, CPU};
use custos::prelude::*;

#[derive(Debug, Default, Clone)]
pub struct Call {
Expand All @@ -27,7 +27,7 @@ impl Add for &Call {
}
}

pub fn add<'a, T: Add<Output = T> + Copy>(
pub fn add<'a, T: Add<Output = T> + Copy + 'static>(
device: &'a CPU,
lhs: &Buffer<T>,
rhs: &Buffer<T>,
Expand All @@ -41,36 +41,3 @@ pub fn add<'a, T: Add<Output = T> + Copy>(

out
}

#[test]
fn test_caller() {
let device = CPU::<Base>::new();

let lhs = device.buffer([1, 2, 3, 4]);
let rhs = device.buffer([1, 2, 3, 4]);

for _ in range(100) {
add(&device, &lhs, &rhs);
}

assert_eq!(device.cache().nodes.len(), 3);

for _ in 0..100 {
add(&device, &lhs, &rhs);
}

assert_eq!(device.cache().nodes.len(), 102);

let cell = RefCell::new(10);

let x = cell.borrow();
// cell.borrow_mut();

let caller = Call::default();
caller.call();

let _ = &caller + &Call::default();

let loc = caller.location;
println!("location: {loc:?}");
}
2 changes: 1 addition & 1 deletion tests/clear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use custos_macro::stack_cpu_test;
#[stack_cpu_test]
#[test]
fn test_clear_cpu() {
let device = CPU::<Base>::new();
let device = CPU::<custos::Base>::new();

let mut buf = Buffer::with(&device, [1., 2., 3., 4., 5., 6.]);
assert_eq!(buf.read(), [1., 2., 3., 4., 5., 6.,]);
Expand Down
4 changes: 3 additions & 1 deletion tests/clone_buf.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use custos::{Buffer, CloneBuf, CPU};
use custos::prelude::*;

#[cfg(feature = "cpu")]
#[test]
fn test_buf_clone() {
use custos::CloneBuf;

let device = CPU::<Base>::new();
let buf = Buffer::from((&device, [1., 2., 6., 2., 4.]));

Expand Down
Loading

0 comments on commit 2afadac

Please sign in to comment.