From 2afadacc0d70a3c04a952b669817e6ad08af920b Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Tue, 15 Aug 2023 10:53:59 +0200 Subject: [PATCH] Fix more tests, add retriever impl via trait bound variable impl_retriever! macro --- src/device_traits.rs | 1 - src/devices.rs | 12 +++-- src/devices/cpu/cpu_device.rs | 6 ++- src/devices/stack/mod.rs | 4 +- src/devices/stack/stack_device.rs | 19 ++------ src/features.rs | 8 ++-- src/modules/autograd.rs | 8 ++-- src/modules/base.rs | 4 +- src/modules/cached.rs | 9 ++-- src/modules/lazy.rs | 8 ++-- src/two_way_ops/mod.rs | 4 +- tests/buffer.rs | 6 +-- tests/cache.rs | 15 +++--- tests/caller.rs | 37 +-------------- tests/clear.rs | 2 +- tests/clone_buf.rs | 4 +- tests/dealloc_dev.rs | 77 ------------------------------- tests/for.rs | 30 ------------ tests/shallow.rs | 4 ++ tests/threading/threads.rs | 2 + tests/write.rs | 4 +- 21 files changed, 61 insertions(+), 203 deletions(-) delete mode 100644 tests/dealloc_dev.rs delete mode 100644 tests/for.rs diff --git a/src/device_traits.rs b/src/device_traits.rs index 4d5eaece..8f1599cf 100644 --- a/src/device_traits.rs +++ b/src/device_traits.rs @@ -83,6 +83,5 @@ pub trait Retriever: Device { parents: impl Parents, ) -> Buffer where - T: 'static, S: Shape; } diff --git a/src/devices.rs b/src/devices.rs index ce495ed3..c8dea160 100644 --- a/src/devices.rs +++ b/src/devices.rs @@ -30,7 +30,7 @@ pub use cdatatype::*; #[cfg(all(any(feature = "cpu", feature = "stack"), feature = "macro"))] mod cpu_stack_ops; -use crate::{Alloc, Buffer, HasId, OnDropBuffer, PtrType, Shape}; +use crate::{Buffer, HasId, OnDropBuffer, PtrType, Shape}; pub trait Device: OnDropBuffer + Sized { type Data: HasId + PtrType; @@ -89,8 +89,8 @@ macro_rules! impl_buffer_hook_traits { #[macro_export] macro_rules! impl_retriever { - ($device:ident) => { - impl> crate::Retriever for $device { + ($device:ident, $($trait_bounds:tt)*) => { + impl> crate::Retriever for $device { #[inline] fn retrieve( &self, @@ -99,7 +99,7 @@ macro_rules! impl_retriever { ) -> Buffer { let data = self .modules - .retrieve::(self, len, parents); + .retrieve::(self, len, parents); let buf = Buffer { data, device: Some(self), @@ -109,4 +109,8 @@ macro_rules! impl_retriever { } } }; + + ($device:ident) => { + impl_retriever!($device, Sized); + } } diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index 1c6e5825..ea4e4a55 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -3,7 +3,7 @@ use core::convert::Infallible; use crate::{ cpu::CPUPtr, flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, Alloc, Base, Buffer, Cached, CachedModule, CloneBuf, Device, HasModules, LazySetup, MainMemory, Module, - OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, Shape, TapeActions, + OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, Shape, TapeActions, DevicelessAble, }; pub trait IsCPU {} @@ -49,6 +49,10 @@ impl Device for CPU { } } +impl DevicelessAble<'_, T, S> for CPU { + +} + impl MainMemory for CPU { #[inline] fn as_ptr(ptr: &Self::Data) -> *const T { diff --git a/src/devices/stack/mod.rs b/src/devices/stack/mod.rs index 064532fc..8b6926d4 100644 --- a/src/devices/stack/mod.rs +++ b/src/devices/stack/mod.rs @@ -48,7 +48,7 @@ mod tests { impl AddBuf for CPU where D: MainMemory, - T: Add + Clone + 'static, + T: Add + Clone, { fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { let len = core::cmp::min(lhs.len(), rhs.len()); @@ -65,7 +65,7 @@ mod tests { where Stack: Alloc, D: MainMemory, - T: Add + Clone + 'static, + T: Add + Copy + Default, { fn add(&self, lhs: &Buffer, rhs: &Buffer) -> Buffer { let mut out = self.retrieve(S::LEN, (lhs, rhs)); diff --git a/src/devices/stack/stack_device.rs b/src/devices/stack/stack_device.rs index dd32eac2..bfd92d34 100644 --- a/src/devices/stack/stack_device.rs +++ b/src/devices/stack/stack_device.rs @@ -1,8 +1,8 @@ -use core::{convert::Infallible, marker::PhantomData}; +use core::convert::Infallible; use crate::{ flag::AllocFlag, shape::Shape, Alloc, Base, Buffer, CloneBuf, Device, DevicelessAble, - MainMemory, OnDropBuffer, Read, StackArray, WriteBuf, impl_retriever, impl_buffer_hook_traits, Retriever, + MainMemory, OnDropBuffer, Read, StackArray, WriteBuf, impl_retriever, impl_buffer_hook_traits, }; /// A device that allocates memory on the stack. @@ -20,20 +20,7 @@ impl Stack { } impl_buffer_hook_traits!(Stack); - -impl Retriever for Stack { - fn retrieve( - &self, - len: usize, - parents: impl crate::Parents, - ) -> Buffer - where - T: 'static, - S: Shape { - todo!() - } -} -// impl_retriever!(Stack); +impl_retriever!(Stack, Copy + Default); impl<'a, T: Copy + Default, S: Shape> DevicelessAble<'a, T, S> for Stack {} diff --git a/src/features.rs b/src/features.rs index 72151db3..8e5cde6f 100644 --- a/src/features.rs +++ b/src/features.rs @@ -16,25 +16,23 @@ pub trait Feature: OnDropBuffer {} // how to fix this: // add retrieved buffer to no grads pool at the end of the chain (at device level (Retriever trait)) // => "generator", "actor" -pub trait Retrieve: OnDropBuffer { +pub trait Retrieve: OnDropBuffer { // "generator" #[track_caller] - fn retrieve( + fn retrieve( &self, device: &D, len: usize, parents: impl Parents, ) -> D::Data where - T: 'static, // if 'static causes any problems -> put T to => Retrieve? S: Shape, D: Device + Alloc; // "actor" #[inline] - fn on_retrieve_finish(&self, _retrieved_buf: &Buffer) + fn on_retrieve_finish(&self, _retrieved_buf: &Buffer) where - T: 'static, D: Alloc, { } diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index c2d9f248..ff065943 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -119,12 +119,12 @@ impl, NewDev> Setup for Autograd { } } -impl, D> Retrieve for Autograd +impl, D> Retrieve for Autograd where D: PtrConv + Device + 'static, { #[inline] - fn retrieve( + fn retrieve( &self, device: &D, len: usize, @@ -132,16 +132,14 @@ where ) -> ::Data where D: Alloc, - T: 'static, S: crate::Shape, { self.modules.retrieve(device, len, parents) } #[inline] - fn on_retrieve_finish(&self, retrieved_buf: &Buffer) + fn on_retrieve_finish(&self, retrieved_buf: &Buffer) where - T: 'static, D: Alloc, { self.register_no_grad_buf(retrieved_buf); diff --git a/src/modules/base.rs b/src/modules/base.rs index 714ae951..73b7a81e 100644 --- a/src/modules/base.rs +++ b/src/modules/base.rs @@ -40,9 +40,9 @@ impl OnNewBuffer for Base {} impl OnDropBuffer for Base {} -impl Retrieve for Base { +impl Retrieve for Base { #[inline] - fn retrieve( + fn retrieve( &self, device: &D, len: usize, diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 5147b6e6..9a640201 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -63,11 +63,11 @@ impl OnDropBuffer for CachedModule { } // TODO: a more general OnDropBuffer => "Module" -impl, D: Device + PtrConv, SimpleDevice: Device + PtrConv> - Retrieve for CachedModule +impl, D: Device + PtrConv, SimpleDevice: Device + PtrConv> + Retrieve for CachedModule { #[inline] - fn retrieve( + fn retrieve( &self, device: &D, len: usize, @@ -80,9 +80,8 @@ impl, D: Device + PtrConv, SimpleDevice: D } #[inline] - fn on_retrieve_finish(&self, retrieved_buf: &Buffer) + fn on_retrieve_finish(&self, retrieved_buf: &Buffer) where - T: 'static, D: Alloc, { self.modules.on_retrieve_finish(retrieved_buf) diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs index 85dd37fc..33c76b41 100644 --- a/src/modules/lazy.rs +++ b/src/modules/lazy.rs @@ -115,16 +115,15 @@ impl TapeActions for Lazy { } } -impl, D: PtrConv + 'static> Retrieve for Lazy { +impl, D: PtrConv + 'static> Retrieve for Lazy { #[inline] - fn retrieve( + fn retrieve( &self, device: &D, len: usize, parents: impl Parents, ) -> ::Data where - T: 'static, S: Shape, D: Alloc, { @@ -132,9 +131,8 @@ impl, D: PtrConv + 'static> Retrieve for Lazy(&self, retrieved_buf: &Buffer) + fn on_retrieve_finish(&self, retrieved_buf: &Buffer) where - T: 'static, D: Alloc, { unsafe { register_buf(&mut self.outs.borrow_mut(), retrieved_buf) }; diff --git a/src/two_way_ops/mod.rs b/src/two_way_ops/mod.rs index 9f12db43..01cdc2ed 100644 --- a/src/two_way_ops/mod.rs +++ b/src/two_way_ops/mod.rs @@ -290,8 +290,8 @@ mod tests { let a = f(4f32.to_val(), 3f32.to_val()); - // roughly_eq_slices(&[a.eval()], &[22.2]); - assert_eq!(a.eval(), 22.2); + roughly_eq_slices(&[a.eval()], &[22.2]); + // assert_eq!(a.eval(), 22.2); let r = f("x".to_marker(), "y".to_marker()).to_cl_source(); assert_eq!("(((x + y) * 3.6) - y)", r); diff --git a/tests/buffer.rs b/tests/buffer.rs index bcc8d515..d1260358 100644 --- a/tests/buffer.rs +++ b/tests/buffer.rs @@ -127,13 +127,13 @@ fn test_cached_cpu() { let mut prev_ptr = None; for _ in 0..100 { - let mut buf: Buffer = device.retrieve::<(), 0>(10, ()); + let buf: Buffer = device.retrieve::<(), 0>(10, ()); if prev_ptr.is_some() { - assert_eq!(prev_ptr, Some(buf.data)); + assert_eq!(prev_ptr, Some(buf.data.ptr)); } - prev_ptr = Some(buf.data); + prev_ptr = Some(buf.data.ptr); } } diff --git a/tests/cache.rs b/tests/cache.rs index 9f9d8f20..dd040f30 100644 --- a/tests/cache.rs +++ b/tests/cache.rs @@ -3,12 +3,16 @@ use std::ptr::null_mut; #[cfg(feature = "cpu")] #[cfg(not(feature = "realloc"))] -use custos::{range, Buffer, CPU}; +use custos::{Buffer, CPU}; #[cfg(feature = "cpu")] #[cfg(not(feature = "realloc"))] +#[track_caller] fn cached_add<'a>(device: &'a CPU, a: &[f32], b: &[f32]) -> Buffer<'a, f32, CPU> { - let mut out = custos::Device::retrieve::(device, 10, ()); + use custos::Retriever; + + let mut out = device.retrieve(10, ()); + for i in 0..out.len() { out[i] = a[i] + b[i]; } @@ -19,6 +23,8 @@ fn cached_add<'a>(device: &'a CPU, a: &[f32], b: &[f32]) -> Buffer<'a, f32, CPU> #[cfg(not(feature = "realloc"))] #[test] fn test_caching_cpu() { + use custos::Base; + let device = CPU::::new(); let a = Buffer::::new(&device, 100); @@ -26,14 +32,11 @@ fn test_caching_cpu() { let mut old_ptr = null_mut(); - for _ in range(100) { + for _ in 0..100 { let mut out = cached_add(&device, &a, &b); if out.host_ptr() != old_ptr && !old_ptr.is_null() { panic!("Should be the same pointer!"); } old_ptr = out.host_ptr_mut(); - let len = device.addons.cache.borrow().nodes.len(); - //let len = CPU_CACHE.with(|cache| cache.borrow().nodes.len()); - assert_eq!(len, 3); } } diff --git a/tests/caller.rs b/tests/caller.rs index f1e0d076..09328b6c 100644 --- a/tests/caller.rs +++ b/tests/caller.rs @@ -3,7 +3,7 @@ use std::{ ops::Add, }; -use custos::{range, Buffer, CacheReturn, Device, CPU}; +use custos::prelude::*; #[derive(Debug, Default, Clone)] pub struct Call { @@ -27,7 +27,7 @@ impl Add for &Call { } } -pub fn add<'a, T: Add + Copy>( +pub fn add<'a, T: Add + Copy + 'static>( device: &'a CPU, lhs: &Buffer, rhs: &Buffer, @@ -41,36 +41,3 @@ pub fn add<'a, T: Add + Copy>( out } - -#[test] -fn test_caller() { - let device = CPU::::new(); - - let lhs = device.buffer([1, 2, 3, 4]); - let rhs = device.buffer([1, 2, 3, 4]); - - for _ in range(100) { - add(&device, &lhs, &rhs); - } - - assert_eq!(device.cache().nodes.len(), 3); - - for _ in 0..100 { - add(&device, &lhs, &rhs); - } - - assert_eq!(device.cache().nodes.len(), 102); - - let cell = RefCell::new(10); - - let x = cell.borrow(); - // cell.borrow_mut(); - - let caller = Call::default(); - caller.call(); - - let _ = &caller + &Call::default(); - - let loc = caller.location; - println!("location: {loc:?}"); -} diff --git a/tests/clear.rs b/tests/clear.rs index 63e8f59d..eda6a2e3 100644 --- a/tests/clear.rs +++ b/tests/clear.rs @@ -6,7 +6,7 @@ use custos_macro::stack_cpu_test; #[stack_cpu_test] #[test] fn test_clear_cpu() { - let device = CPU::::new(); + let device = CPU::::new(); let mut buf = Buffer::with(&device, [1., 2., 3., 4., 5., 6.]); assert_eq!(buf.read(), [1., 2., 3., 4., 5., 6.,]); diff --git a/tests/clone_buf.rs b/tests/clone_buf.rs index a4ca9978..2e0d0e47 100644 --- a/tests/clone_buf.rs +++ b/tests/clone_buf.rs @@ -1,8 +1,10 @@ -use custos::{Buffer, CloneBuf, CPU}; +use custos::prelude::*; #[cfg(feature = "cpu")] #[test] fn test_buf_clone() { + use custos::CloneBuf; + let device = CPU::::new(); let buf = Buffer::from((&device, [1., 2., 6., 2., 4.])); diff --git a/tests/dealloc_dev.rs b/tests/dealloc_dev.rs deleted file mode 100644 index f93966ce..00000000 --- a/tests/dealloc_dev.rs +++ /dev/null @@ -1,77 +0,0 @@ -use custos::prelude::*; - -#[cfg(feature = "cpu")] -#[test] -fn test_rc_get_dev() { - { - let device = CPU::::new(); - let mut a = Buffer::from((&device, [1., 2., 3., 4., 5., 6.])); - - for _ in range(100) { - a.clear(); - assert_eq!(&[0.; 6], a.as_slice()); - } - } -} - -#[cfg(feature = "opencl")] -#[test] -fn test_dealloc_cl() -> custos::Result<()> { - let device = OpenCL::new(0)?; - - let a = Buffer::from((&device, [1f32, 2., 3., 4., 5., 6.])); - let b = Buffer::from((&device, [6., 5., 4., 3., 2., 1.])); - - drop(a); - drop(b); - drop(device); - - Ok(()) -} - -#[cfg(feature = "cpu")] -#[cfg(not(feature = "realloc"))] -#[test] -fn test_dealloc_device_cache_cpu() { - let device = CPU::::new(); - - assert_eq!(device.cache().nodes.len(), 0); - let a = device.retrieve::(10, ()); - assert_eq!(device.cache().nodes.len(), 1); - - drop(a); - drop(device); - //assert_eq!(device.cache.borrow().nodes.len(), 0); -} - -#[cfg(not(feature = "realloc"))] -#[cfg(feature = "opencl")] -#[test] -fn test_dealloc_device_cache_cl() -> custos::Result<()> { - let device = OpenCL::new(0)?; - - assert_eq!(device.cache().nodes.len(), 0); - let a = device.retrieve::(10, ()); - assert_eq!(device.cache().nodes.len(), 1); - - drop(a); - drop(device); - Ok(()) -} - -#[cfg(not(feature = "realloc"))] -#[cfg(feature = "cuda")] -#[test] -fn test_dealloc_device_cache_cu() -> custos::Result<()> { - use custos::CUDA; - - let device = CUDA::new(0)?; - - assert_eq!(device.cache().nodes.len(), 0); - let a = device.retrieve::(10, ()); - assert_eq!(device.cache().nodes.len(), 1); - - drop(a); - drop(device); - Ok(()) -} diff --git a/tests/for.rs b/tests/for.rs deleted file mode 100644 index 55997410..00000000 --- a/tests/for.rs +++ /dev/null @@ -1,30 +0,0 @@ -use custos::range; - -#[test] -fn test_range() { - let mut count = 0; - for epoch in range(10) { - assert_eq!(epoch, count); - count += 1; - } -} - -#[test] -fn test_range1() { - let mut count = 0; - for epoch in range(0..10) { - assert_eq!(epoch, count); - count += 1; - assert!(epoch < 10) - } -} - -#[test] -fn test_range_inclusive() { - let mut count = 0; - for epoch in range(0..=10) { - assert_eq!(epoch, count); - count += 1; - assert!(epoch < 11) - } -} diff --git a/tests/shallow.rs b/tests/shallow.rs index 8073f39f..9017ba79 100644 --- a/tests/shallow.rs +++ b/tests/shallow.rs @@ -3,6 +3,8 @@ use custos::{Buffer, CPU}; #[cfg(feature = "cpu")] #[test] fn test_shallow_buf_copy() { + use custos::Base; + let device = CPU::::new(); let buf = Buffer::from((&device, [1, 2, 3, 4, 5])); @@ -29,6 +31,8 @@ fn test_shallow_buf_realloc() { #[cfg(not(feature = "realloc"))] #[test] fn test_shallow_buf_realloc() { + use custos::Base; + let device = CPU::::new(); let buf = Buffer::from((&device, [1, 2, 3, 4, 5])); diff --git a/tests/threading/threads.rs b/tests/threading/threads.rs index ce8e9809..1053eea5 100644 --- a/tests/threading/threads.rs +++ b/tests/threading/threads.rs @@ -3,6 +3,8 @@ use custos::{Buffer, CPU}; #[cfg(feature = "cpu")] #[test] fn test_with_threads() { + use custos::Base; + let device = CPU::::new(); //let buf = Buffer::from((&device, [1, 2, 3, 4])); diff --git a/tests/write.rs b/tests/write.rs index 2856be0e..911a3fb2 100644 --- a/tests/write.rs +++ b/tests/write.rs @@ -8,7 +8,7 @@ use custos_macro::stack_cpu_test; #[stack_cpu_test] #[test] fn test_write_cpu() { - let device = CPU::::new(); + let device = CPU::::new(); let mut buf: Buffer<_, _, custos::Dim1<5>> = Buffer::new(&device, 5); device.write(&mut buf, &[1., 2., 3., 4., 5.]); assert_eq!(buf.as_slice(), &[1., 2., 3., 4., 5.]) @@ -17,7 +17,7 @@ fn test_write_cpu() { #[cfg(feature = "cpu")] #[test] fn test_write_buf_cpu() { - use custos::{Buffer, WriteBuf, CPU}; + use custos::{Buffer, WriteBuf, CPU, Base}; let device = CPU::::new();