From ac76de9250ab3e40d5d36818884ee4b6d02d88a2 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Mon, 14 Aug 2023 23:01:59 +0200 Subject: [PATCH] Move changes made in module_comb into main structure --- Cargo.toml | 2 +- src/buffer.rs | 281 ++++++++--------- src/buffer/impl_from.rs | 35 +-- src/buffer/impl_from_const.rs | 16 +- src/buffer/num.rs | 69 +++-- src/cache.rs | 134 ++++++++ src/cache/borrow_cache.rs | 190 ++++++++++++ src/cache/location_hasher.rs | 73 +++++ src/cache/nohasher.rs | 31 ++ src/count.rs | 173 ----------- src/device_traits.rs | 88 ++++++ src/devices.rs | 112 +++++++ src/devices/addons.rs | 124 -------- src/devices/borrowing_cache.rs | 123 -------- src/devices/cache.rs | 321 -------------------- src/devices/caller_cache.rs | 166 ---------- src/devices/cpu/cpu_device.rs | 128 ++++---- src/devices/cpu/mod.rs | 5 +- src/devices/cpu/ops.rs | 33 +- src/devices/cpu_stack_ops.rs | 11 +- src/devices/ident.rs | 61 ---- src/devices/mod.rs | 155 ---------- src/exec_on_cpu.rs | 28 +- src/features.rs | 93 ++++++ src/{module_comb => }/hooks.rs | 4 +- src/{module_comb => }/id.rs | 0 src/lib.rs | 217 +++---------- src/module_comb/buffer.rs | 33 ++ src/module_comb/devices/cpu.rs | 16 +- src/module_comb/devices/cuda.rs | 2 +- src/module_comb/devices/mod.rs | 56 ---- src/module_comb/features.rs | 76 ----- src/module_comb/mod.rs | 61 ---- src/module_comb/modules/autograd.rs | 2 +- src/module_comb/modules/autograd/tape.rs | 2 +- src/module_comb/modules/cached.rs | 2 +- src/modules/autograd.rs | 369 +++++++++++++++++++++++ src/modules/autograd/gradients.rs | 214 +++++++++++++ src/modules/autograd/tape.rs | 72 +++++ src/modules/base.rs | 59 ++++ src/modules/cached.rs | 254 ++++++++++++++++ src/modules/graph.rs | 4 + src/modules/lazy.rs | 180 +++++++++++ src/modules/mod.rs | 14 + src/op_traits.rs | 4 +- src/parents.rs | 78 +++++ src/{module_comb => }/ptr_conv.rs | 8 +- src/shape.rs | 14 +- src/two_way_ops/mod.rs | 8 +- src/unary.rs | 2 +- 50 files changed, 2394 insertions(+), 1809 deletions(-) create mode 100644 src/cache.rs create mode 100644 src/cache/borrow_cache.rs create mode 100644 src/cache/location_hasher.rs create mode 100644 src/cache/nohasher.rs delete mode 100644 src/count.rs create mode 100644 src/device_traits.rs create mode 100644 src/devices.rs delete mode 100644 src/devices/addons.rs delete mode 100644 src/devices/borrowing_cache.rs delete mode 100644 src/devices/cache.rs delete mode 100644 src/devices/caller_cache.rs delete mode 100644 src/devices/ident.rs delete mode 100644 src/devices/mod.rs create mode 100644 src/features.rs rename src/{module_comb => }/hooks.rs (55%) rename src/{module_comb => }/id.rs (100%) create mode 100644 src/modules/autograd.rs create mode 100644 src/modules/autograd/gradients.rs create mode 100644 src/modules/autograd/tape.rs create mode 100644 src/modules/base.rs create mode 100644 src/modules/cached.rs create mode 100644 src/modules/graph.rs create mode 100644 src/modules/lazy.rs create mode 100644 src/modules/mod.rs create mode 100644 src/parents.rs rename src/{module_comb => }/ptr_conv.rs (71%) diff --git a/Cargo.toml b/Cargo.toml index b643cbb5..aa49df4e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ min-cl = { version = "0.2.0", optional=true } [features] #default = ["no-std"] -default = ["blas", "cpu", "stack", "static-api", "macro", "autograd", "cuda"] +default = ["blas", "cpu", "macro"] #default = ["stack", "macro", "cpu", "blas", "opencl", "static-api", "autograd"] #default = ["stack", "cpu", "blas", "static-api", "opencl", "macro"] cpu = [] diff --git a/src/buffer.rs b/src/buffer.rs index f29cfa2f..11eca8f5 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -7,8 +7,9 @@ use crate::cpu::{CPUPtr, CPU}; use crate::CPU; use crate::{ - flag::AllocFlag, shape::Shape, Alloc, ClearBuf, CloneBuf, CommonPtrs, Device, DevicelessAble, - Ident, IsShapeIndep, MainMemory, PtrType, Read, ShallowCopy, WriteBuf, + flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, CommonPtrs, Device, + DevicelessAble, HasId, IsShapeIndep, MainMemory, OnNewBuffer, PtrType, Read, ShallowCopy, + WriteBuf, }; pub use self::num::Num; @@ -34,20 +35,13 @@ mod num; /// buffer_f32_cpu(&buf); /// buffer_generic(&buf); /// ``` -pub struct Buffer<'a, T = f32, D: Device = CPU, S: Shape = ()> { +pub struct Buffer<'a, T = f32, D: Device = CPU, S: Shape = ()> { /// the type of pointer - pub ptr: D::Ptr, + pub data: D::Data, /// A reference to the corresponding device. Mainly used for operations without a device parameter. pub device: Option<&'a D>, - /// Used as a cache and autograd identifier. - #[cfg(not(feature = "no-std"))] - pub ident: Option, } -unsafe impl<'a, T, D: Device, S: Shape> Send for Buffer<'a, T, D, S> {} - -unsafe impl<'a, T, D: Device, S: Shape> Sync for Buffer<'a, T, D, S> {} - impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { /// Creates a zeroed (or values set to default) `Buffer` with the given length on the specified device. /// This `Buffer` can't outlive the device specified as a parameter. @@ -67,25 +61,92 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { /// /// ``` #[inline] - pub fn new(device: &'a D, len: usize) -> Buffer<'a, T, D, S> + pub fn new(device: &'a D, len: usize) -> Self where - D: Alloc<'a, T, S>, /*+ GraphReturn*/ + D: OnNewBuffer + Alloc, { - let ptr = device.alloc(len, AllocFlag::None); - - #[cfg(not(feature = "no-std"))] - let ident = device.add_to_cache(&ptr); + let data = device.alloc(len, crate::flag::AllocFlag::None); + Buffer::from_new_alloc(device, data) + } - Buffer { - ptr, + #[inline] + fn from_new_alloc(device: &'a D, data: D::Data) -> Self + where + D: OnNewBuffer, + { + let buf = Buffer { + data, device: Some(device), - // TODO: enable, if leafs get more important - //node: device.graph().add_leaf(len), - #[cfg(not(feature = "no-std"))] - ident, + }; + + // mind: on_new_buffer must be called for user buffers! + device.on_new_buffer(device, &buf); + buf + } +} + +impl<'a, T, D: Device, S: Shape> HasId for Buffer<'a, T, D, S> { + #[inline] + fn id(&self) -> super::Id { + self.data.id() + } +} + +impl<'a, T, D: Device, S: Shape> Drop for Buffer<'a, T, D, S> { + #[inline] + fn drop(&mut self) { + if self.data.flag() != AllocFlag::None { + return; } + + if let Some(device) = self.device { + device.on_drop_buffer(device, self) + } + } +} + +impl<'a, T, D: Device + OnNewBuffer, S: Shape> Buffer<'a, T, D, S> { + /// Creates a new `Buffer` from a slice (&[T]). + #[inline] + pub fn from_slice(device: &'a D, slice: &[T]) -> Self + where + T: Clone, + D: Alloc, + { + let data = device.alloc_from_slice(slice); + Buffer::from_new_alloc(device, data) + } + + /// Creates a new `Buffer` from a `Vec`. + #[cfg(not(feature = "no-std"))] + #[inline] + pub fn from_vec(device: &'a D, data: Vec) -> Self + where + T: Clone, + D: Alloc, + { + let data = device.alloc_from_vec(data); + Buffer::from_new_alloc(device, data) + } + + /// Creates a new `Buffer` from an nd-array. + /// The dimension is defined by the [`Shape`]. + #[inline] + pub fn from_array(device: &'a D, array: S::ARR) -> Buffer + where + T: Clone, + D: Alloc, + { + let data = device.alloc_from_array(array); + Buffer::from_new_alloc(device, data) } +} + +unsafe impl<'a, T, D: Device, S: Shape> Send for Buffer<'a, T, D, S> {} +unsafe impl<'a, T, D: Device, S: Shape> Sync for Buffer<'a, T, D, S> {} + +impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { /// Buffers created with this method can outlive the device used to create this `Buffer`.
/// No operations can be performed on this `Buffer` without a device parameter. /// # Examples @@ -109,9 +170,7 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { D: DevicelessAble<'b, T, S>, { Buffer { - ptr: device.alloc(len, AllocFlag::None), - #[cfg(not(feature = "no-std"))] - ident: None, + data: device.alloc(len, AllocFlag::None), device: None, } } @@ -199,7 +258,7 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { /// ``` #[inline] pub fn len(&self) -> usize { - self.ptr.size() + self.data.size() } /// Creates a shallow copy of &self. @@ -211,13 +270,11 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { #[inline] pub unsafe fn shallow(&self) -> Buffer<'a, T, D, S> where - ::Ptr: ShallowCopy, + ::Data: ShallowCopy, { Buffer { - ptr: self.ptr.shallow(), + data: self.data.shallow(), device: self.device, - #[cfg(not(feature = "no-std"))] - ident: self.ident, } } @@ -230,7 +287,7 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { /// Furthermore, the resulting `Buffer` can outlive `self`. pub unsafe fn shallow_or_clone(&self) -> Buffer<'a, T, D, S> where - ::Ptr: ShallowCopy, + ::Data: ShallowCopy, T: Clone, D: CloneBuf<'a, T, S>, { @@ -243,20 +300,6 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { self.clone() } - /// Returns the [`Ident`] of a `Buffer`. - /// A `Buffer` receives an id, if it is useable for caching, graph optimization or autograd. - /// Panics, if `Buffer` hasn't an id. - #[inline] - pub fn id(&self) -> Ident { - #[cfg(feature = "no-std")] - { - unimplemented!("This buffer has no trackable id. Who?: e.g. 'Stack' Buffer, Buffers created via Buffer::from_raw_host..(..), `Num` (scalar) Buffer") - } - - #[cfg(not(feature = "no-std"))] - self.ident.expect("This buffer has no trackable id. Who?: e.g. 'Stack' Buffer, Buffers created via Buffer::from_raw_host..(..), `Num` (scalar) Buffer") - } - /// Sets all elements in `Buffer` to the default value. pub fn clear(&mut self) where @@ -266,22 +309,6 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { } } -impl<'a, T, D: Device, S: Shape> Drop for Buffer<'a, T, D, S> { - #[inline] - fn drop(&mut self) { - if self.ptr.flag() != AllocFlag::None { - return; - } - - #[cfg(not(feature = "no-std"))] - if let Some(device) = self.device { - if let Some(ident) = self.ident { - device.remove(ident) - } - } - } -} - // TODO better solution for the to_dims stack problem? impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { /// Converts a non stack allocated `Buffer` with shape `S` to a `Buffer` with shape `O`. @@ -299,17 +326,15 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { pub fn to_dims(self) -> Buffer<'a, T, D, O> where D: crate::ToDim, - D::Ptr: ShallowCopy, + D::Data: ShallowCopy, { let buf = ManuallyDrop::new(self); - let ptr = buf.device().to_dim(unsafe { buf.ptr.shallow() }); + let data = buf.device().to_dim(unsafe { buf.data.shallow() }); Buffer { - ptr, + data, device: buf.device, - #[cfg(not(feature = "no-std"))] - ident: buf.ident, } } } @@ -335,18 +360,18 @@ impl<'a, T, D: IsShapeIndep, S: Shape> Buffer<'a, T, D, S> { impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> where - D::Ptr: CommonPtrs, + D::Data: CommonPtrs, { #[inline] /// Returns all types of pointers. (host, OpenCL, CUDA) pub fn ptrs(&self) -> (*const T, *mut c_void, u64) { - self.ptr.ptrs() + self.data.ptrs() } #[inline] /// Returns all types of pointers. (host, OpenCL, CUDA) pub fn ptrs_mut(&mut self) -> (*mut T, *mut c_void, u64) { - self.ptr.ptrs_mut() + self.data.ptrs_mut() } } @@ -365,73 +390,6 @@ impl<'a, T, D: Device> Buffer<'a, T, D> { } } -impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { - /// Creates a new `Buffer` from a slice (&[T]). - /// The pointer of the allocation may be added to the cache of the device. - /// Usually, this pointer / `Buffer` is then returned by a `device.get_existing_buf(..)` (accesses the cache) call. - #[inline] - pub fn from_slice(device: &'a D, slice: &[T]) -> Self - where - T: Clone, - D: Alloc<'a, T, S>, - { - let ptr = device.with_slice(slice); - - #[cfg(not(feature = "no-std"))] - let ident = device.add_to_cache(&ptr); - - Buffer { - ptr, - #[cfg(not(feature = "no-std"))] - ident, - device: Some(device), - } - } - - /// Creates a new `Buffer` from a `Vec`. - /// The pointer of the allocation may be added to the cache of the device. - /// Usually, this pointer / `Buffer` is then returned by a `device.get_existing_buf(..)` call. - #[cfg(not(feature = "no-std"))] - #[inline] - pub fn from_vec(device: &'a D, data: Vec) -> Self - where - T: Clone, - D: Alloc<'a, T, S>, - { - let ptr = device.alloc_with_vec(data); - let ident = device.add_to_cache(&ptr); - - Buffer { - ptr, - ident, - device: Some(device), - } - } - - /// Creates a new `Buffer` from an nd-array. - /// The dimension is defined by the [`Shape`]. - /// The pointer of the allocation may be added to the cache of the device. - /// Usually, this pointer / `Buffer` is then returned by a `device.get_existing_buf(..)` call. - #[inline] - pub fn from_array(device: &'a D, array: S::ARR) -> Buffer - where - T: Clone, - D: Alloc<'a, T, S>, - { - let ptr = device.with_array(array); - - #[cfg(not(feature = "no-std"))] - let ident = device.add_to_cache(&ptr); - - Buffer { - ptr, - #[cfg(not(feature = "no-std"))] - ident, - device: Some(device), - } - } -} - #[cfg(feature = "cpu")] impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> { /// Constructs a deviceless `Buffer` out of a host pointer and a length. @@ -458,9 +416,8 @@ impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> { #[inline] pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU, S> { Buffer { - ptr: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper), + data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper), device: None, - ident: None, } } @@ -477,9 +434,8 @@ impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> { len: usize, ) -> Buffer<'a, T, CPU, S> { Buffer { - ptr: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper), + data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper), device: Some(device), - ident: None, } } } @@ -490,7 +446,7 @@ impl<'a, T, S: Shape> Buffer<'a, T, crate::OpenCL, S> { #[inline] pub fn cl_ptr(&self) -> *mut c_void { assert!( - !self.ptr.ptr.is_null(), + !self.data.ptr.is_null(), "called cl_ptr() on an invalid OpenCL buffer" ); self.ptrs().1 @@ -499,7 +455,7 @@ impl<'a, T, S: Shape> Buffer<'a, T, crate::OpenCL, S> { #[cfg(feature = "cuda")] impl<'a, T> Buffer<'a, T, crate::CUDA> { - // TODO: replace buf.ptr.2 with this fn, do the same with cl, cpu + // TODO: replace buf.data.2 with this fn, do the same with cl, cpu /// Returns a non null CUDA pointer #[inline] pub fn cu_ptr(&self) -> u64 { @@ -507,7 +463,7 @@ impl<'a, T> Buffer<'a, T, crate::CUDA> { self.ptrs().2 != 0, "called cu_ptr() on an invalid CUDA buffer" ); - self.ptr.ptr + self.data.ptr } } @@ -525,9 +481,10 @@ impl<'a, T, D: MainMemory, S: Shape> Buffer<'a, T, D, S> { } } +// custos v0.5 compatability impl<'a, T, D: MainMemory, S: Shape> Buffer<'a, T, D, S> where - D::Ptr: CommonPtrs, + D::Data: CommonPtrs, { /// Returns a non null host pointer #[inline] @@ -568,14 +525,12 @@ unsafe impl Sync for Buffer<'a, T> {}*/ impl<'a, T, D: Device, S: Shape> Default for Buffer<'a, T, D, S> where - D::Ptr: Default, + D::Data: Default, { fn default() -> Self { Self { - ptr: D::Ptr::::default(), + data: D::Data::::default(), device: None, - #[cfg(not(feature = "no-std"))] - ident: None, } } } @@ -594,7 +549,7 @@ impl AsMut<[T]> for Buffer<'_, T, D> { } } -/// A `Buffer` dereferences into a slice. +/// A main memory `Buffer` dereferences into a slice. /// /// # Examples /// @@ -623,11 +578,11 @@ impl core::ops::Deref for Buffer<'_, T, D, S> { #[inline] fn deref(&self) -> &Self::Target { - unsafe { core::slice::from_raw_parts(D::as_ptr(&self.ptr), self.len()) } + unsafe { core::slice::from_raw_parts(D::as_ptr(&self.data), self.len()) } } } -/// A `Buffer` dereferences into a mutable slice. +/// A main memory `Buffer` dereferences into a mutable slice. /// /// # Examples /// @@ -652,7 +607,7 @@ impl core::ops::Deref for Buffer<'_, T, D, S> { impl core::ops::DerefMut for Buffer<'_, T, D, S> { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { core::slice::from_raw_parts_mut(D::as_ptr_mut(&mut self.ptr), self.len()) } + unsafe { core::slice::from_raw_parts_mut(D::as_ptr_mut(&mut self.data), self.len()) } } } @@ -665,7 +620,7 @@ where T: Debug + Default + Clone + 'a, D: Read + Device + 'a, for<'b> >::Read<'b>: Debug, - D::Ptr: CommonPtrs, + D::Data: CommonPtrs, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("Buffer") @@ -726,7 +681,9 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_deref() { - let device = crate::CPU::new(); + use crate::Base; + + let device = crate::CPU::::new(); let buf: Buffer = Buffer::from((&device, [1, 2, 3, 4])); let slice = &*buf; assert_eq!(slice, &[1, 2, 3, 4]); @@ -762,7 +719,9 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_debug_print() { - let device = crate::CPU::new(); + use crate::Base; + + let device = crate::CPU::::new(); let buf = Buffer::from((&device, [1, 2, 3, 4, 5, 6])); println!("{buf:?}",); @@ -771,9 +730,9 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_to_dims() { - use crate::Dim2; + use crate::{Base, Dim2}; - let device = crate::CPU::new(); + let device = crate::CPU::::new(); let buf = Buffer::from((&device, [1, 2, 3, 4, 5, 6])); let buf_dim2 = buf.to_dims::>(); @@ -783,12 +742,12 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_id_cpu() { - use crate::{Ident, CPU}; + use crate::{Base, HasId, CPU}; - let device = CPU::new(); + let device = CPU::::new(); let buf = Buffer::from((&device, [1, 2, 3, 4])); - assert_eq!(buf.id(), Ident { idx: 0, len: 4 }) + assert_eq!(buf.id(), buf.data.id()) } #[cfg(feature = "stack")] diff --git a/src/buffer/impl_from.rs b/src/buffer/impl_from.rs index 0005ac48..80e319d6 100644 --- a/src/buffer/impl_from.rs +++ b/src/buffer/impl_from.rs @@ -1,6 +1,6 @@ use core::ops::Range; -use crate::{number::Number, shape::Shape, Alloc, Buffer}; +use crate::{number::Number, shape::Shape, Alloc, Buffer, Device, OnNewBuffer, Retriever}; #[cfg(feature = "cpu")] use crate::{WriteBuf, CPU}; @@ -9,7 +9,7 @@ impl<'a, T, D, const N: usize> From<(&'a D, [T; N])> for Buffer<'a, T, D> where T: Clone, // TODO: IsShapeIndep ... find way to include Stack - D: Alloc<'a, T>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, array): (&'a D, [T; N])) -> Self { @@ -19,7 +19,7 @@ where impl<'a, T, D> From<(&'a D, usize)> for Buffer<'a, T, D> where - D: Alloc<'a, T>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, len): (&'a D, usize)) -> Self { @@ -30,7 +30,7 @@ where /*impl<'a, T, D> Buffer<'a, T, D> where T: Clone, - D: Alloc<'a, T> + D: Alloc+ OnNewBuffer { #[inline] pub fn from_iter>(device: &'a D, iter: I) -> Self { @@ -42,7 +42,7 @@ where impl<'a, T, D> From<(&'a D, Range)> for Buffer<'a, T, D> where T: Number, - D: Alloc<'a, T>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, range): (&'a D, Range)) -> Self { @@ -55,7 +55,7 @@ where impl<'a, T, D, I> From<(&'a D, I)> for Buffer<'a, T, D> where T: Number, - D: Alloc<'a, T>, + D: Alloc+ OnNewBuffer, I: IntoIterator, { #[inline] @@ -67,7 +67,7 @@ where /*impl<'a, T, D, const N: usize> From<(&'a D, [T; N])> for Buffer<'a, T, D> where T: Clone, - D: Alloc<'a, T> + IsShapeIndep, + D: Alloc+ OnNewBuffer + IsShapeIndep, { fn from((device, array): (&'a D, [T; N])) -> Self { Buffer { @@ -84,7 +84,7 @@ impl<'a, T, D, const N: usize> From<(&'a D, &[T; N])> for Buffer<'a, T, D> where T: Clone, // TODO: IsShapeIndep ... find way to include Stack - D: Alloc<'a, T>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, array): (&'a D, &[T; N])) -> Self { @@ -95,7 +95,7 @@ where /*impl<'a, T, D, const N: usize> From<(&'a D, &[T; N])> for Buffer<'a, T, D> where T: Clone, - D: Alloc<'a, T> + IsShapeIndep, + D: Alloc+ OnNewBuffer + IsShapeIndep, { fn from((device, array): (&'a D, &[T; N])) -> Self { Buffer { @@ -112,7 +112,7 @@ impl<'a, T, D, S: Shape> From<(&'a D, &[T])> for Buffer<'a, T, D, S> where T: Clone, // TODO: IsShapeIndep ... find way to include Stack - D: Alloc<'a, T, S>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, slice): (&'a D, &[T])) -> Self { @@ -123,7 +123,7 @@ where /*impl<'a, T, D, S: Shape> From<(&'a D, &[T])> for Buffer<'a, T, D, S> where T: Clone, - D: Alloc<'a, T, S> + IsShapeIndep, + D: Alloc+ OnNewBuffer + IsShapeIndep, { fn from((device, slice): (&'a D, &[T])) -> Self { Buffer { @@ -140,7 +140,7 @@ impl<'a, T, D, S: Shape> From<(&'a D, Vec)> for Buffer<'a, T, D, S> where T: Clone, // TODO: IsShapeIndep ... find way to include Stack - D: Alloc<'a, T, S>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, vec): (&'a D, Vec)) -> Self { @@ -153,7 +153,7 @@ impl<'a, T, D, S: Shape> From<(&'a D, &Vec)> for Buffer<'a, T, D, S> where T: Clone, // TODO: IsShapeIndep ... find way to include Stack - D: Alloc<'a, T, S>, + D: Alloc + OnNewBuffer, { #[inline] fn from((device, vec): (&'a D, &Vec)) -> Self { @@ -164,8 +164,9 @@ where #[cfg(feature = "cpu")] impl<'a, 'b, T, S, D> From<(&'a D, Buffer<'b, T, CPU, S>)> for Buffer<'a, T, D, S> where + T: 'static, S: Shape, - D: WriteBuf + for<'c> Alloc<'c, T, S>, + D: WriteBuf + Device + Retriever, { fn from((device, buf): (&'a D, Buffer<'b, T, CPU, S>)) -> Self { let mut out = device.retrieve(buf.len(), &buf); @@ -180,11 +181,11 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_buf_device_conversion_cpu() { - use crate::{Buffer, Read, CPU}; + use crate::{Base, Buffer, Read, CPU}; - let device = CPU::new(); + let device = CPU::::new(); - let cpu = CPU::new(); + let cpu = CPU::::new(); let cpu_buf = Buffer::from((&cpu, [1, 2, 4, 5])); let out = Buffer::from((&device, cpu_buf)); diff --git a/src/buffer/impl_from_const.rs b/src/buffer/impl_from_const.rs index 4ba6724d..4fd0bca2 100644 --- a/src/buffer/impl_from_const.rs +++ b/src/buffer/impl_from_const.rs @@ -1,4 +1,4 @@ -use crate::{prelude::Number, shape::Shape, Alloc, Buffer, Dim1, Dim2}; +use crate::{prelude::Number, shape::Shape, Alloc, Buffer, Device, Dim1, Dim2, OnNewBuffer}; /// Trait for creating [`Buffer`]s with a [`Shape`]. The [`Shape`] is inferred from the array. pub trait WithShape { @@ -20,7 +20,7 @@ pub trait WithShape { impl<'a, T, D, const N: usize> WithShape<&'a D, [T; N]> for Buffer<'a, T, D, Dim1> where T: Number, // using Number here, because T could be an array type - D: Alloc<'a, T, Dim1>, + D: Alloc + OnNewBuffer>, { #[inline] fn with(device: &'a D, array: [T; N]) -> Self { @@ -31,7 +31,7 @@ where impl<'a, T, D, const N: usize> WithShape<&'a D, &[T; N]> for Buffer<'a, T, D, Dim1> where T: Number, - D: Alloc<'a, T, Dim1>, + D: Alloc + OnNewBuffer>, { #[inline] fn with(device: &'a D, array: &[T; N]) -> Self { @@ -43,7 +43,7 @@ impl<'a, T, D, const B: usize, const A: usize> WithShape<&'a D, [[T; A]; B]> for Buffer<'a, T, D, Dim2> where T: Number, - D: Alloc<'a, T, Dim2>, + D: Alloc + OnNewBuffer>, { #[inline] fn with(device: &'a D, array: [[T; A]; B]) -> Self { @@ -55,7 +55,7 @@ impl<'a, T, D, const B: usize, const A: usize> WithShape<&'a D, &[[T; A]; B]> for Buffer<'a, T, D, Dim2> where T: Number, - D: Alloc<'a, T, Dim2>, + D: Alloc + OnNewBuffer>, { #[inline] fn with(device: &'a D, array: &[[T; A]; B]) -> Self { @@ -65,7 +65,7 @@ where impl<'a, T, D, S: Shape> WithShape<&'a D, ()> for Buffer<'a, T, D, S> where - D: Alloc<'a, T, S>, + D: Alloc + OnNewBuffer, { fn with(device: &'a D, _: ()) -> Self { Buffer::new(device, S::LEN) @@ -77,9 +77,9 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_with_const_dim2_cpu() { - use crate::{Buffer, WithShape, CPU}; + use crate::{Base, Buffer, WithShape, CPU}; - let device = CPU::new(); + let device = CPU::::new(); let buf = Buffer::with(&device, [[1.0, 2.0], [3.0, 4.0]]); diff --git a/src/buffer/num.rs b/src/buffer/num.rs index f64f9e60..3b9afda3 100644 --- a/src/buffer/num.rs +++ b/src/buffer/num.rs @@ -1,11 +1,15 @@ use core::{ + convert::Infallible, ffi::c_void, ops::{Deref, DerefMut}, ptr::null_mut, }; -use crate::{shape::Shape, Buffer, CloneBuf, CommonPtrs, Device, PtrType}; +use crate::{ + flag::AllocFlag, Alloc, Buffer, CloneBuf, CommonPtrs, Device, HasId, OnDropBuffer, PtrType, +}; +#[derive(Debug, Default)] /// Makes it possible to use a single number in a [`Buffer`]. pub struct Num { /// The stored number. @@ -36,25 +40,53 @@ impl CommonPtrs for Num { } } +impl HasId for Num { + fn id(&self) -> crate::Id { + todo!() + } +} + +impl From for Num { + #[inline] + fn from(num: T) -> Self { + Num { num } + } +} + impl Device for () { - type Ptr = Num; - type Cache = (); + type Data = Num; + type Error = Infallible; - fn new() -> crate::Result { + fn new() -> Result { Ok(()) } } +impl Alloc for () { + #[inline] + fn alloc(&self, _len: usize, _flag: AllocFlag) -> Self::Data { + Num::default() + } + + #[inline] + fn alloc_from_slice(&self, data: &[T]) -> Self::Data + where + T: Clone, + { + data[0].clone().into() + } +} + +impl OnDropBuffer for () {} + impl<'a, T: Clone> CloneBuf<'a, T> for () { #[inline] fn clone_buf(&self, buf: &Buffer<'a, T, Self>) -> Buffer<'a, T, Self> { Buffer { - ptr: Num { - num: buf.ptr.num.clone(), + data: Num { + num: buf.data.num.clone(), }, device: buf.device, - #[cfg(not(feature = "no-std"))] - ident: buf.ident, } } } @@ -63,10 +95,8 @@ impl From for Buffer<'_, T, ()> { #[inline] fn from(ptr: T) -> Self { Buffer { - ptr: Num { num: ptr }, + data: Num { num: ptr }, device: None, - #[cfg(not(feature = "no-std"))] - ident: None, } } } @@ -80,10 +110,8 @@ impl<'a, T> Buffer<'a, T, ()> { T: Copy, { Buffer { - ptr: Num { num: self.ptr.num }, + data: Num { num: self.data.num }, device: self.device, - #[cfg(not(feature = "no-std"))] - ident: self.ident, } } @@ -105,7 +133,7 @@ impl<'a, T> Buffer<'a, T, ()> { where T: Copy, { - self.ptr.num + self.data.num } } @@ -114,14 +142,14 @@ impl<'a, T> Deref for Buffer<'a, T, ()> { #[inline] fn deref(&self) -> &Self::Target { - &self.ptr.num + &self.data.num } } impl<'a, T> DerefMut for Buffer<'a, T, ()> { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.ptr.num + &mut self.data.num } } @@ -144,4 +172,11 @@ mod tests { *a += 10; assert_eq!(*a, 15); } + + #[test] + fn test_num_device() { + use crate::Device; + + let _device = <()>::new().unwrap(); + } } diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 00000000..16b0b37c --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,134 @@ +use core::{hash::BuildHasherDefault, panic::Location}; +use std::collections::HashMap; +use std::rc::Rc; + +use crate::{flag::AllocFlag, Device, Shape}; + +use super::{Alloc, PtrConv}; + +mod location_hasher; +pub use location_hasher::*; + +mod nohasher; +pub use nohasher::*; + +mod borrow_cache; +pub use borrow_cache::*; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Cache { + pub nodes: + HashMap, Rc>, BuildHasherDefault>, +} + +impl Default for Cache { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Cache { + #[inline] + pub fn new() -> Self { + Self { + nodes: Default::default(), + } + } + + #[track_caller] + #[inline] + pub fn get>( + &mut self, + device: &D, + len: usize, + callback: fn(), + ) -> D::Data + where + SD: PtrConv, + D: PtrConv, + { + let maybe_allocated = self.nodes.get(&Location::caller().into()); + match maybe_allocated { + Some(data) => unsafe { SD::convert(&data, AllocFlag::Wrapper) }, + None => self.add_node(device, len, callback), + } + } + + #[track_caller] + pub fn add_node>( + &mut self, + device: &D, + len: usize, + callback: fn(), + ) -> D::Data + where + D: PtrConv, + { + let data = device.alloc::(len, AllocFlag::Wrapper); + + let untyped_ptr = unsafe { D::convert(&data, AllocFlag::None) }; + self.nodes + .insert(Location::caller().into(), Rc::new(untyped_ptr)); + + callback(); + + data + } +} + +#[cfg(test)] +mod tests { + use super::Cache; + use crate::{Base, CPU}; + + #[test] + fn test_cache_add_node() { + let mut cache = Cache::>::default(); + let device = CPU::::new(); + + assert_eq!(cache.nodes.len(), 0); + + let out = cache.add_node::(&device, 10, || ()); + + assert_eq!(cache.nodes.len(), 1); + assert_eq!(out.len, 10); + + let out1 = cache.get::(&device, 10, || ()); + assert_ne!(out.ptr, out1.ptr); + } + + #[test] + fn test_cache_get_at_different_locations() { + let mut cache = Cache::>::default(); + let device = CPU::::new(); + + assert_eq!(cache.nodes.len(), 0); + + let out1 = cache.get::(&device, 10, || ()); + assert_eq!(cache.nodes.len(), 1); + + let out2 = cache.get::(&device, 10, || ()); + + assert_ne!(out1.ptr, out2.ptr); + assert_eq!(cache.nodes.len(), 2); + } + + #[test] + fn test_cache_get_reuse_based_on_location() { + let mut cache = Cache::>::default(); + let device = CPU::::new(); + + let mut prev = None; + for _ in 0..1000 { + let out3 = cache.get::(&device, 10, || ()); + if prev.is_none() { + prev = Some(out3.ptr); + } + assert_eq!(prev.unwrap(), out3.ptr); + assert_eq!(cache.nodes.len(), 1); + prev = Some(out3.ptr); + } + assert_eq!(cache.nodes.len(), 1); + } +} diff --git a/src/cache/borrow_cache.rs b/src/cache/borrow_cache.rs new file mode 100644 index 00000000..e19a8659 --- /dev/null +++ b/src/cache/borrow_cache.rs @@ -0,0 +1,190 @@ +use core::{ + any::Any, + fmt::{Debug, Display}, + hash::BuildHasherDefault, + mem::transmute, +}; +use std::collections::HashMap; + +use crate::{flag::AllocFlag, Alloc, Buffer, Device, Id, Shape}; + +use super::NoHasher; + +pub type UniqueId = u64; + +#[derive(Clone, Copy)] +pub enum CachingError { + InvalidId, + InvalidTypeInfo, +} + +impl CachingError { + pub fn as_str(&self) -> &'static str { + match self { + CachingError::InvalidId => "InvalidId: Invalid Buffer identifier.", + CachingError::InvalidTypeInfo => "InvalidTypeInfo: Invalid type information provided for allocated Buffer. Does your specific operation use mixed types?", + } + } +} + +impl Debug for CachingError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + Display::fmt(&self, f) + } +} + +impl Display for CachingError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl std::error::Error for CachingError {} + +#[derive(Debug, Default)] +pub struct BorrowCache { + pub cache: HashMap, BuildHasherDefault>, +} + +// TODO: make BorrowedCache unuseable without device (=> Static get methods with D: CacheReturn) +impl BorrowCache { + pub fn add_or_get<'a, T, D, S>(&mut self, device: &'a D, id: Id) -> &Buffer<'a, T, D, S> + where + T: 'static, + D: Alloc + 'static, + S: Shape, + { + self.add_buf_once::(device, id); + + let buf_any = self.cache.get(&id).unwrap(); + buf_any.downcast_ref().unwrap() + } + + pub fn add_or_get_mut<'a, T, D, S>(&mut self, device: &D, id: Id) -> &mut Buffer<'a, T, D, S> + where + T: 'static, + D: Alloc + 'static, + S: Shape, + { + self.add_buf_once::(device, id); + self.get_buf_mut(id).unwrap() + } + + pub fn add_buf_once<'a, T, D, S>(&mut self, device: &'a D, id: Id) + where + T: 'static, + D: Alloc + 'static, + S: Shape, + { + if self.cache.get(&id).is_some() { + return; + } + + self.add_buf::(device, id) + } + + pub fn add_buf<'a, T, D, S>(&mut self, device: &'a D, id: Id) + where + T: 'static, + D: Alloc + 'static, + S: Shape, + { + // not using ::new, because this buf would get added to the cache of the device. + // not anymore ? + let buf = Buffer { + data: device.alloc::(id.len, AllocFlag::BorrowedCache), + device: Some(device), + }; + + let buf = unsafe { transmute::<_, Buffer<'static, T, D, S>>(buf) }; + self.cache.insert(*id, Box::new(buf)); + } + + #[inline] + pub fn get_buf_with_dev<'a, 'b, T, D, S>( + &'b self, + id: Id, + _dev: &'a D, + ) -> Option<&'b Buffer<'a, T, D, S>> + where + T: 'static, + D: Alloc + 'static, + S: Shape, + { + self.cache.get(&id)?.downcast_ref() + } + + #[inline] + pub fn get_buf<'a, T, D, S>(&self, id: Id) -> Result<&Buffer<'a, T, D, S>, CachingError> + where + T: 'static, + D: Device + 'static, + S: Shape, + { + self.cache + .get(&id) + .ok_or(CachingError::InvalidId)? + .downcast_ref() + .ok_or(CachingError::InvalidTypeInfo) + } + + #[inline] + pub fn get_buf_mut<'a, T, D, S>( + &mut self, + id: Id, + ) -> Result<&mut Buffer<'a, T, D, S>, CachingError> + where + T: 'static, + D: Device + 'static, + S: Shape, + { + unsafe { + transmute( + self.cache + .get_mut(&id) + .ok_or(CachingError::InvalidId)? + .downcast_mut::>() + .ok_or(CachingError::InvalidTypeInfo), + ) + } + } +} + +#[cfg(test)] +mod tests { + use crate::{Base, CPU}; + + use super::BorrowCache; + + /*#[test] + fn test_comp_error() { + let device = CPU::new(); + + + let a = { + let mut cache = BorrowingCache::default(); + cache.add_or_get::(&device, Id::new(10)) + }; + }*/ + + /*#[test] + fn test_get_borrowed() { + let device = CPU::::default(); + let mut cache = BorrowCache::default(); + + let (fid, sid, tid) = ( + Id::new_bumped(10), + Id::new_bumped(10), + Id::new_bumped(10), + ); + + cache.add_buf_once::(&device, fid); + cache.add_buf_once::(&device, sid); + cache.add_buf_once::(&device, tid); + + let a = cache.get_buf::(fid).unwrap(); + let b = cache.get_buf::(fid).unwrap(); + + assert_eq!(a.ptr, b.ptr); + }*/ +} diff --git a/src/cache/location_hasher.rs b/src/cache/location_hasher.rs new file mode 100644 index 00000000..5ea29811 --- /dev/null +++ b/src/cache/location_hasher.rs @@ -0,0 +1,73 @@ +use core::{ops::BitXor, panic::Location}; + +#[derive(Default)] +pub struct LocationHasher { + hash: u64, +} + +const K: u64 = 0x517cc1b727220a95; + +impl std::hash::Hasher for LocationHasher { + #[inline] + fn finish(&self) -> u64 { + self.hash + } + + #[inline] + fn write(&mut self, _bytes: &[u8]) { + unimplemented!("LocationHasher only hashes u64, (u32 and usize as u64 cast).") + } + + #[inline] + fn write_u64(&mut self, i: u64) { + self.hash = self.hash.rotate_left(5).bitxor(i).wrapping_mul(K); + } + + #[inline] + fn write_u32(&mut self, i: u32) { + self.write_u64(i as u64); + } + + #[inline] + fn write_usize(&mut self, i: usize) { + self.write_u64(i as u64); + } +} + +#[derive(Debug, Clone, Copy, Eq)] +pub struct HashLocation<'a> { + pub file: &'a str, + pub line: u32, + pub col: u32, +} + +impl PartialEq for HashLocation<'_> { + #[inline] + fn eq(&self, other: &Self) -> bool { + // if filename pointer is actually actually unique, then this works (added units tests to check this... still not sure) + if self.file.as_ptr() != other.file.as_ptr() { + return false; + } + self.line == self.line && self.col == self.col + } +} + +impl<'a> std::hash::Hash for HashLocation<'a> { + #[inline] + fn hash(&self, state: &mut H) { + self.file.as_ptr().hash(state); + let line_col = (self.line as u64) << 9 | self.col as u64; + line_col.hash(state); + } +} + +impl<'a> From<&'a Location<'a>> for HashLocation<'a> { + #[inline] + fn from(loc: &'a Location<'a>) -> Self { + Self { + file: loc.file(), + line: loc.line(), + col: loc.column(), + } + } +} diff --git a/src/cache/nohasher.rs b/src/cache/nohasher.rs new file mode 100644 index 00000000..f6facf62 --- /dev/null +++ b/src/cache/nohasher.rs @@ -0,0 +1,31 @@ +#[derive(Default)] +pub struct NoHasher { + hash: u64, +} + +impl std::hash::Hasher for NoHasher { + #[inline] + fn finish(&self) -> u64 { + self.hash + } + + #[inline] + fn write(&mut self, _bytes: &[u8]) { + unimplemented!("NoHasher only hashes u64, (u32 and usize as u64 cast).") + } + + #[inline] + fn write_u64(&mut self, i: u64) { + self.hash = i; + } + + #[inline] + fn write_u32(&mut self, i: u32) { + self.write_u64(i as u64); + } + + #[inline] + fn write_usize(&mut self, i: usize) { + self.write_u64(i as u64); + } +} diff --git a/src/count.rs b/src/count.rs deleted file mode 100644 index 46b0a93d..00000000 --- a/src/count.rs +++ /dev/null @@ -1,173 +0,0 @@ -use core::ops::{Range, RangeInclusive}; - -/// Converts ranges into a start and end index. -pub trait AsRangeArg { - /// Returns the start index of the range. - fn start(&self) -> usize; - /// Returns the end index of the range. - fn end(&self) -> usize; -} - -impl AsRangeArg for Range { - #[inline] - fn start(&self) -> usize { - self.start - } - - #[inline] - fn end(&self) -> usize { - self.end - } -} - -impl AsRangeArg for RangeInclusive { - #[inline] - fn start(&self) -> usize { - *self.start() - } - - #[inline] - fn end(&self) -> usize { - *self.end() + 1 - } -} - -impl AsRangeArg for usize { - #[inline] - fn start(&self) -> usize { - 0 - } - - #[inline] - fn end(&self) -> usize { - *self - } -} - -impl AsRangeArg for (usize, usize) { - #[inline] - fn start(&self) -> usize { - self.0 - } - - #[inline] - fn end(&self) -> usize { - self.1 - } -} - -/// `range` resets the cache count in every iteration. -/// The cache count is used to retrieve the same allocation in each iteration. -/// Not adding `range` results in allocating new memory in each iteration, -/// which is only freed when the device is dropped.
-/// To disable this caching behaviour, enable the `realloc` feature. -/// -/// # Example -#[cfg_attr(not(feature = "no-std"), doc = "```")] -#[cfg_attr(feature = "no-std", doc = "```ignore")] -/// use custos::{get_count, range, Ident, bump_count}; -/// -/// for _ in range(100) { // using only one usize: exclusive range -/// Ident::new(10); // an 'Ident' is created if a Buffer is retrieved from cache. -/// bump_count(); -/// assert!(get_count() == 1); -/// } -/// assert!(get_count() == 0); -/// ``` -#[inline] -pub fn range(range: R) -> Count { - Count(range.start(), range.end()) -} - -/// used to reset the cache count -#[derive(Debug, Clone, Copy)] -pub struct Count(pub(super) usize, pub(super) usize); - -/// The iterator used for setting the cache count. -#[derive(Debug)] -pub struct CountIntoIter { - epoch: usize, - #[cfg(not(feature = "no-std"))] - idx: usize, - end: usize, -} - -impl Iterator for CountIntoIter { - type Item = usize; - - fn next(&mut self) -> Option { - #[cfg(not(feature = "no-std"))] - unsafe { - crate::set_count(self.idx) - }; - if self.epoch >= self.end { - return None; - } - let epoch = Some(self.epoch); - self.epoch += 1; - epoch - } -} - -impl IntoIterator for Count { - type Item = usize; - - type IntoIter = CountIntoIter; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - CountIntoIter { - epoch: self.0, - #[cfg(not(feature = "no-std"))] - idx: crate::get_count(), - end: self.1, - } - } -} - -#[cfg(test)] -mod tests { - use crate::{range, Count, CountIntoIter}; - - fn count_iter(iter: &mut CountIntoIter) { - iter.next(); - assert_eq!(iter.epoch, 1); - #[cfg(not(feature = "no-std"))] - assert_eq!(iter.idx, 0); - assert_eq!(iter.end, 10); - - iter.next(); - assert_eq!(iter.epoch, 2); - #[cfg(not(feature = "no-std"))] - assert_eq!(iter.idx, 0); - assert_eq!(iter.end, 10); - } - - #[test] - fn test_count_into_iter() { - let mut iter = CountIntoIter { - epoch: 0, - #[cfg(not(feature = "no-std"))] - idx: 0, - end: 10, - }; - - count_iter(&mut iter); - } - - #[test] - fn test_count() { - let count: Count = Count(0, 10); - count_iter(&mut count.into_iter()); - } - - #[test] - fn test_range_inclusive() { - let count: Count = range(0..=9); - count_iter(&mut count.into_iter()); - - for (idx, other) in count.into_iter().zip(0..=9) { - assert_eq!(idx, other) - } - } -} diff --git a/src/device_traits.rs b/src/device_traits.rs new file mode 100644 index 00000000..41543c30 --- /dev/null +++ b/src/device_traits.rs @@ -0,0 +1,88 @@ +// TODO: move to devices folder ig + +use crate::{flag::AllocFlag, prelude::Device, Buffer, HasId, Parents, PtrType, Shape, StackArray}; + +pub trait Alloc: Device + Sized { + /// Allocate memory on the implemented device. + /// # Example + #[cfg_attr(feature = "cpu", doc = "```")] + #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] + /// use custos::{CPU, Alloc, Buffer, Read, flag::AllocFlag, GraphReturn, cpu::CPUPtr}; + /// + /// let device = CPU::new(); + /// let ptr = Alloc::::alloc(&device, 12, AllocFlag::None); + /// + /// let buf: Buffer = Buffer { + /// ident: None, + /// ptr, + /// device: Some(&device), + /// }; + /// assert_eq!(vec![0.; 12], device.read(&buf)); + /// ``` + fn alloc(&self, len: usize, flag: AllocFlag) -> Self::Data; + + /// Allocate new memory with data + /// # Example + #[cfg_attr(feature = "cpu", doc = "```")] + #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] + /// use custos::{CPU, Alloc, Buffer, Read, GraphReturn, cpu::CPUPtr}; + /// + /// let device = CPU::new(); + /// let ptr = Alloc::::with_slice(&device, &[1, 5, 4, 3, 6, 9, 0, 4]); + /// + /// let buf: Buffer = Buffer { + /// ident: None, + /// ptr, + /// device: Some(&device), + /// }; + /// assert_eq!(vec![1, 5, 4, 3, 6, 9, 0, 4], device.read(&buf)); + /// ``` + fn alloc_from_slice(&self, data: &[T]) -> Self::Data + where + T: Clone; + + /// If the vector `vec` was allocated previously, this function can be used in order to reduce the amount of allocations, which may be faster than using a slice of `vec`. + #[inline] + #[cfg(not(feature = "no-std"))] + fn alloc_from_vec(&self, vec: Vec) -> Self::Data + where + T: Clone, + { + self.alloc_from_slice(&vec) + } + + /// Allocates a pointer with the array provided by the `S:`[`Shape`] generic. + /// By default, the array is flattened and then passed to [`Alloc::alloc_from_slice`]. + #[inline] + fn alloc_from_array(&self, array: S::ARR) -> Self::Data + where + T: Clone, + { + let stack_array = StackArray::::from_array(array); + self.alloc_from_slice(stack_array.flatten()) + } +} + +pub trait Module { + type Module; + + fn new() -> Self::Module; +} + +/// Used for modules that should affect the device. +pub trait Setup { + #[inline] + fn setup(_device: &mut D) {} +} + +pub trait Retriever: Device { + #[track_caller] + fn retrieve( + &self, + len: usize, + parents: impl Parents, + ) -> Buffer + where + T: 'static, + S: Shape; +} diff --git a/src/devices.rs b/src/devices.rs new file mode 100644 index 00000000..2996ee41 --- /dev/null +++ b/src/devices.rs @@ -0,0 +1,112 @@ +//! This module defines all available compute devices + +mod generic_blas; +pub use generic_blas::*; + +#[cfg(feature = "cpu")] +pub mod cpu; + +#[cfg(feature = "cuda")] +pub mod cuda; + +#[cfg(feature = "opencl")] +pub mod opencl; + +#[cfg(feature = "stack")] +pub mod stack; + +#[cfg(feature = "wgpu")] +pub mod wgpu; + +#[cfg(feature = "network")] +pub mod network; + +mod stack_array; +pub use stack_array::*; + +mod cdatatype; +pub use cdatatype::*; + +#[cfg(all(any(feature = "cpu", feature = "stack"), feature = "macro"))] +mod cpu_stack_ops; + +use crate::{Alloc, Buffer, HasId, OnDropBuffer, PtrType, Shape}; + +pub trait Device: OnDropBuffer + Sized { + type Data: HasId + PtrType; + + type Error; + + #[inline] + fn new() -> Result { + todo!() + } + + /// Creates a new [`Buffer`] using `A`. + /// + /// # Example + #[cfg_attr(feature = "cpu", doc = "```")] + #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] + /// use custos::{CPU, Device}; + /// + /// let device = CPU::new(); + /// let buf = device.buffer([5, 4, 3]); + /// + /// assert_eq!(buf.read(), [5, 4, 3]); + /// ``` + fn buffer<'a, T, S: Shape, A>(&'a self, arr: A) -> Buffer<'a, T, Self, S> + where + Buffer<'a, T, Self, S>: From<(&'a Self, A)>, + { + Buffer::from((self, arr)) + } +} + +#[macro_export] +macro_rules! impl_buffer_hook_traits { + ($device:ident) => { + impl> OnNewBuffer + for $device + { + #[inline] + fn on_new_buffer(&self, device: &D, new_buf: &Buffer) { + self.modules.on_new_buffer(device, new_buf) + } + } + + impl OnDropBuffer for $device { + #[inline] + fn on_drop_buffer<'a, T, D: Device, S: Shape>( + &self, + device: &'a D, + buf: &Buffer, + ) { + self.modules.on_drop_buffer(device, buf) + } + } + }; +} + +#[macro_export] +macro_rules! impl_retriever { + ($device:ident) => { + impl> Retriever for $device { + #[inline] + fn retrieve( + &self, + len: usize, + parents: impl crate::Parents, + ) -> Buffer { + let data = self + .modules + .retrieve::(self, len, parents); + let buf = Buffer { + data, + device: Some(self), + }; + self.modules.on_retrieve_finish(&buf); + buf + } + } + }; +} diff --git a/src/devices/addons.rs b/src/devices/addons.rs deleted file mode 100644 index b8450bda..00000000 --- a/src/devices/addons.rs +++ /dev/null @@ -1,124 +0,0 @@ -use core::{cell::RefCell, fmt::Debug}; - -use crate::{Cache, CacheReturn, Device, GlobalCount, Graph, GraphReturn, NodeIdx, PtrConv}; - -use super::caller_cache::{CallerCacheReturn, TrackCallerCache}; - -/// Provides several addons for a device. -/// - `graph`: An optimizeable graph. -/// - `cache`: A cache for allocations. -/// - `tape`: A (gradient) tape. -pub struct Addons { - /// An optimizeable graph. - pub graph: RefCell>, - /// A cache for allocations. - pub cache: RefCell>, - /// A (gradient) tape. - #[cfg(feature = "autograd")] - pub tape: RefCell>, - pub caller_cache: RefCell>, -} - -impl Debug for Addons -where - D::Ptr: Debug, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - #[cfg(feature = "autograd")] - { - f.debug_struct("Addons") - .field("graph", &self.graph) - .field("cache", &self.cache) - .field("tape", &self.tape) - .finish() - } - - #[cfg(not(feature = "autograd"))] - f.debug_struct("Addons") - .field("graph", &self.graph) - .field("cache", &self.cache) - .finish() - } -} - -impl Default for Addons -where - D::Ptr: Default, -{ - fn default() -> Self { - Self { - graph: Default::default(), - cache: Default::default(), - #[cfg(feature = "autograd")] - tape: Default::default(), - caller_cache: Default::default(), - } - } -} - -/// `AddonsReturn` is probably implemented for all devices that have an [`Addons`] field. -pub trait AddonsReturn: Device { - /// Returns a reference to [`Addons`]. - fn addons(&self) -> &Addons; -} - -impl GraphReturn for D { - #[inline] - fn graph(&self) -> std::cell::Ref> { - self.addons().graph.borrow() - } - - #[inline] - fn graph_mut(&self) -> std::cell::RefMut> { - self.addons().graph.borrow_mut() - } -} - -impl CacheReturn for D { - #[inline] - fn cache(&self) -> core::cell::Ref> - where - Self: PtrConv, - { - self.addons().cache.borrow() - } - - #[inline] - fn cache_mut(&self) -> core::cell::RefMut> - where - Self: PtrConv, - { - self.addons().cache.borrow_mut() - } -} - -impl CallerCacheReturn for D { - #[inline] - fn cache(&self) -> core::cell::Ref> - where - Self: PtrConv, - { - self.addons().caller_cache.borrow() - } - - #[inline] - fn cache_mut(&self) -> core::cell::RefMut> - where - Self: PtrConv, - { - self.addons().caller_cache.borrow_mut() - } -} - -#[cfg(feature = "autograd")] -impl crate::TapeReturn for D { - #[inline] - fn tape(&self) -> core::cell::Ref> { - self.addons().tape.borrow() - } - - #[inline] - fn tape_mut(&self) -> core::cell::RefMut> { - self.addons().tape.borrow_mut() - } -} diff --git a/src/devices/borrowing_cache.rs b/src/devices/borrowing_cache.rs deleted file mode 100644 index 0ed76a20..00000000 --- a/src/devices/borrowing_cache.rs +++ /dev/null @@ -1,123 +0,0 @@ -use core::{any::Any, hash::BuildHasherDefault, mem::transmute}; -use std::collections::HashMap; - -use crate::{flag::AllocFlag, Alloc, Buffer, Device, Ident, IdentHasher, Shape}; - -#[derive(Debug, Default)] -pub struct BorrowingCache { - pub cache: HashMap, BuildHasherDefault>, -} - -// TODO: make BorrowedCache unuseable without device (=> Static get methods with D: CacheReturn) -impl BorrowingCache { - pub fn add_or_get<'a, T, D, S>(&mut self, device: &'a D, id: Ident) -> &Buffer<'a, T, D, S> - where - T: 'static, - D: Alloc<'a, T, S> + 'static, - S: Shape, - { - self.add_buf_once(device, id); - - let buf_any = self.cache.get(&id).unwrap(); - buf_any.downcast_ref().unwrap() - } - - pub fn add_or_get_mut<'a, T, D, S>(&mut self, device: &D, id: Ident) -> &mut Buffer<'a, T, D, S> - where - T: 'static, - D: for<'b> Alloc<'b, T, S> + 'static, - S: Shape, - { - self.add_buf_once(device, id); - self.get_buf_mut(id).unwrap() - } - - pub fn add_buf_once<'a, T, D, S>(&mut self, device: &'a D, ident: Ident) - where - T: 'static, - D: Alloc<'a, T, S> + 'static, - S: Shape, - { - if self.cache.get(&ident).is_some() { - return; - } - - self.add_buf(device, ident) - } - - pub fn add_buf<'a, T, D, S>(&mut self, device: &'a D, ident: Ident) - where - T: 'static, - D: Alloc<'a, T, S> + 'static, - S: Shape, - { - // not using ::new, because this buf would get added to the cache of the device. - let buf = Buffer { - ptr: device.alloc(ident.len, AllocFlag::BorrowedCache), - device: Some(device), - ident: Some(ident), - }; - - let buf = unsafe { transmute::<_, Buffer<'static, T, D, S>>(buf) }; - self.cache.insert(ident, Box::new(buf)); - } - - #[inline] - pub fn get_buf<'a, T, D, S>(&self, id: Ident) -> Option<&Buffer<'a, T, D, S>> - where - T: 'static, - D: Device + 'static, - S: Shape, - { - self.cache.get(&id)?.downcast_ref() - } - - #[inline] - pub fn get_buf_mut<'a, T, D, S>(&mut self, id: Ident) -> Option<&mut Buffer<'a, T, D, S>> - where - T: 'static, - D: Device + 'static, - S: Shape, - { - unsafe { transmute(self.cache.get_mut(&id)?.downcast_mut::>()) } - } -} - -#[cfg(test)] -mod tests { - use crate::{Ident, CPU}; - - use super::BorrowingCache; - - /*#[test] - fn test_comp_error() { - let device = CPU::new(); - - - let a = { - let mut cache = BorrowingCache::default(); - cache.add_or_get::(&device, Ident::new(10)) - }; - }*/ - - #[test] - fn test_get_borrowed() { - let device = CPU::new(); - let mut cache = BorrowingCache::default(); - - let (fid, sid, tid) = ( - Ident::new_bumped(10), - Ident::new_bumped(10), - Ident::new_bumped(10), - ); - - cache.add_buf_once::(&device, fid); - cache.add_buf_once::(&device, sid); - cache.add_buf_once::(&device, tid); - - let a = cache.get_buf::(fid).unwrap(); - let b = cache.get_buf::(fid).unwrap(); - - assert_eq!(a.ptr, b.ptr); - } -} diff --git a/src/devices/cache.rs b/src/devices/cache.rs deleted file mode 100644 index ecf91c9a..00000000 --- a/src/devices/cache.rs +++ /dev/null @@ -1,321 +0,0 @@ -//! Contains the [`Cache`]ing logic. - -use core::{cell::RefMut, fmt::Debug, hash::BuildHasherDefault, ops::BitXor}; -use std::collections::HashMap; - -use std::rc::Rc; - -use crate::{ - flag::AllocFlag, shape::Shape, Alloc, Buffer, CacheAble, Device, GlobalCount, GraphReturn, - Ident, PtrConv, PtrType, -}; - -/// This trait makes a device's [`Cache`] accessible and is implemented for all compute devices. -pub trait CacheReturn: GraphReturn { - /// Returns a reference to a device's [`Cache`]. - fn cache(&self) -> core::cell::Ref> - where - Self: PtrConv; - - /// Returns a mutable reference to a device's [`Cache`]. - fn cache_mut(&self) -> RefMut> - where - Self: PtrConv; -} - -const K: usize = 0x517cc1b727220a95; - -/// A low-overhead [`Ident`] hasher using "FxHasher". -#[derive(Default)] -pub struct IdentHasher { - hash: usize, -} - -impl std::hash::Hasher for IdentHasher { - #[inline] - fn finish(&self) -> u64 { - self.hash as u64 - } - - #[inline] - fn write(&mut self, _bytes: &[u8]) { - unimplemented!("IdentHasher only hashes usize.") - } - - #[inline] - fn write_usize(&mut self, i: usize) { - self.hash = self.hash.rotate_left(5).bitxor(i).wrapping_mul(K); - } -} - -impl CacheAble for Cache -where - D: PtrConv + CacheReturn, -{ - #[cfg(not(feature = "realloc"))] - #[inline] - fn retrieve( - device: &D, - len: usize, - add_node: impl crate::AddGraph, - ) -> Buffer - where - for<'b> D: Alloc<'b, T, S>, - { - device - .cache_mut() - .get(device, Ident::new(len), add_node, crate::bump_count) - } - - #[cfg(feature = "realloc")] - #[inline] - fn retrieve( - device: &D, - len: usize, - _add_node: impl crate::AddGraph, - ) -> Buffer - where - for<'b> D: Alloc<'b, T, S>, - { - Buffer::new(device, len) - } - - #[inline] - unsafe fn get_existing_buf(device: &D, ident: Ident) -> Option> { - let ptr = D::convert(device.cache().nodes.get(&ident)?, AllocFlag::Wrapper); - - Some(Buffer { - ptr, - device: Some(device), - ident: Some(ident), - }) - } - - #[inline] - fn remove(device: &D, ident: Ident) { - device.cache_mut().nodes.remove(&ident); - } - - fn add_to_cache(device: &D, ptr: &::Ptr) -> Option { - device.graph_mut().add_leaf(ptr.size()); - let ident = Ident::new_bumped(ptr.size()); - let raw_ptr = unsafe { std::rc::Rc::new(D::convert(ptr, AllocFlag::Wrapper)) }; - device.cache_mut().nodes.insert(ident, raw_ptr); - Some(ident) - } -} - -/// A cache for 'no-generic' raw pointers. -pub struct Cache { - /// A map of all cached buffers using a custom hash function. - pub nodes: HashMap>, BuildHasherDefault>, -} - -impl Debug for Cache -where - D::Ptr: Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Cache2") - .field("cache", &self.nodes) - .finish() - } -} - -impl Default for Cache -where - D::Ptr: Default, -{ - #[inline] - fn default() -> Self { - Self { - nodes: Default::default(), - } - } -} - -impl Cache { - /// Adds a new cache entry to the cache. - /// The next get call will return this entry if the [Ident] is correct. - /// # Example - #[cfg_attr(feature = "cpu", doc = "```")] - #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] - /// use custos::prelude::*; - /// use custos::{Ident, bump_count}; - /// - /// let device = CPU::new(); - /// let cache: Buffer = device - /// .cache_mut() - /// .add_node(&device, Ident { idx: 0, len: 7 }, (), bump_count); - /// - /// let ptr = device - /// .cache() - /// .nodes - /// .get(&Ident { idx: 0, len: 7 }) - /// .unwrap() - /// .clone(); - /// - /// assert_eq!(cache.host_ptr(), ptr.ptr as *mut f32); - /// ``` - pub fn add_node<'a, T, S: Shape>( - &mut self, - device: &'a D, - ident: Ident, - _add_node: impl crate::AddGraph, - callback: fn(), - ) -> Buffer<'a, T, D, S> - where - D: Alloc<'a, T, S>, - { - let ptr = device.alloc(ident.len, AllocFlag::Wrapper); - - #[cfg(feature = "opt-cache")] - let graph_node = device.graph_mut().add(ident.len, _add_node); - - #[cfg(not(feature = "opt-cache"))] - let graph_node = crate::Node { - idx: ident.idx, - deps: [0; 2], - len: ident.len, - }; - - let untyped_ptr = unsafe { D::convert(&ptr, AllocFlag::None) }; - self.nodes.insert(ident, Rc::new(untyped_ptr)); - - callback(); - - Buffer { - ptr, - device: Some(device), - ident: Some(Ident { - idx: graph_node.idx, - len: ident.len, - }), - } - } - - /// Retrieves cached pointers and constructs a [`Buffer`] with the pointers and the given `len`gth. - /// If a cached pointer doesn't exist, a new `Buffer` will be added to the cache and returned. - /// - /// # Example - #[cfg_attr(feature = "cpu", doc = "```")] - #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] - /// use custos::prelude::*; - /// use custos::bump_count; - /// - /// let device = CPU::new(); - /// - /// let cache_entry: Buffer = device.cache_mut().get(&device, Ident::new(10), (), bump_count); - /// let new_cache_entry: Buffer = device.cache_mut().get(&device, Ident::new(10), (), bump_count); - /// - /// assert_ne!(cache_entry.ptrs(), new_cache_entry.ptrs()); - /// - /// unsafe { set_count(0) }; - /// - /// let first_entry: Buffer = device.cache_mut().get(&device, Ident::new(10), (), bump_count); - /// assert_eq!(cache_entry.ptrs(), first_entry.ptrs()); - /// ``` - pub fn get<'a, T, S: Shape>( - &mut self, - device: &'a D, - ident: Ident, - add_node: impl crate::AddGraph, - callback: fn(), - ) -> Buffer<'a, T, D, S> - where - D: Alloc<'a, T, S>, - { - let may_allocated = self.nodes.get(&ident); - - match may_allocated { - Some(ptr) => { - callback(); - let typed_ptr = unsafe { D::convert(ptr, AllocFlag::Wrapper) }; - - Buffer { - ptr: typed_ptr, - device: Some(device), - ident: Some(ident), - } - } - None => self.add_node(device, ident, add_node, callback), - } - } -} - -#[cfg(test)] -mod tests { - use core::hash::Hasher; - use std::collections::HashSet; - - //#[cfg(not(feature = "realloc"))] - //use crate::set_count; - //use crate::{bump_count, Buffer, CacheReturn, Ident, IdentHasher}; - - #[test] - #[cfg_attr(miri, ignore)] - fn test_ident_hasher() { - use crate::IdentHasher; - - let mut hashed_items = HashSet::new(); - let mut hasher = IdentHasher::default(); - - for item in 0..2500000 { - hasher.write_usize(item); - hasher.write_usize(100000); - let hashed_item = hasher.finish(); - assert!(!hashed_items.contains(&hashed_item)); - - hashed_items.insert(hashed_item); - } - } - - #[cfg(feature = "cpu")] - #[test] - fn test_add_node() { - use crate::{bump_count, Buffer, CacheReturn, Ident}; - - let device = crate::CPU::new(); - let cache: Buffer = - device - .cache_mut() - .add_node(&device, Ident { idx: 0, len: 7 }, (), bump_count); - - let ptr = device - .cache() - .nodes - .get(&Ident { idx: 0, len: 7 }) - .unwrap() - .clone(); - - assert_eq!(cache.host_ptr(), ptr.ptr as *mut f32); - } - - #[cfg(feature = "cpu")] - #[cfg(not(feature = "realloc"))] - #[test] - fn test_get() { - // for: cargo test -- --test-threads=1 - - use crate::{bump_count, set_count, Buffer, CacheReturn, Ident}; - unsafe { set_count(0) }; - let device = crate::CPU::new(); - - let cache_entry: Buffer = device - .cache_mut() - .get(&device, Ident::new(10), (), bump_count); - let new_cache_entry: Buffer = - device - .cache_mut() - .get(&device, Ident::new(10), (), bump_count); - - assert_ne!(cache_entry.ptrs(), new_cache_entry.ptrs()); - - unsafe { set_count(0) }; - - let first_entry: Buffer = device - .cache_mut() - .get(&device, Ident::new(10), (), bump_count); - assert_eq!(cache_entry.ptrs(), first_entry.ptrs()); - } -} diff --git a/src/devices/caller_cache.rs b/src/devices/caller_cache.rs deleted file mode 100644 index c6442609..00000000 --- a/src/devices/caller_cache.rs +++ /dev/null @@ -1,166 +0,0 @@ -use core::{cell::RefMut, panic::Location}; -use std::collections::HashMap; - -use std::rc::Rc; - -use crate::{bump_count, flag::AllocFlag, Alloc, Buffer, Device, Ident, PtrConv, Shape, CPU}; - -pub trait CallerCacheReturn { - /// Returns a reference to a device's [`Cache`]. - fn cache(&self) -> core::cell::Ref> - where - Self: PtrConv; - - /// Returns a mutable reference to a device's [`Cache`]. - fn cache_mut(&self) -> RefMut> - where - Self: PtrConv; -} -pub trait Cache: Device { - type CallerCache: TrackCallerCacheAble; - - #[track_caller] - #[inline] - fn call(&self, len: usize) -> Buffer - where - for<'b> Self: Alloc<'b, T, S>, - { - Self::CallerCache::get(self, len) - } -} - -impl Cache for CPU { - type CallerCache = TrackCallerCache; -} - -#[derive(Debug, Default)] -pub struct TrackCallerCache { - nodes: HashMap<&'static std::panic::Location<'static>, Rc>>, -} - -pub trait TrackCallerCacheAble { - #[track_caller] - fn get(device: &D, len: usize) -> Buffer - where - for<'b> D: Alloc<'b, T, S>; -} - -impl TrackCallerCacheAble for () { - #[inline] - fn get(device: &D, len: usize) -> Buffer - where - for<'b> D: Alloc<'b, T, S>, - { - Buffer::new(device, len) - } -} - -impl TrackCallerCacheAble for TrackCallerCache -where - D: PtrConv + CallerCacheReturn, -{ - #[track_caller] - fn get(device: &D, len: usize) -> Buffer - where - D: for<'a> Alloc<'a, T, S>, - { - device.cache_mut().get(device, Ident::new(len), bump_count) - } -} - -impl TrackCallerCache { - #[track_caller] - pub fn get<'a, T, S>( - &mut self, - device: &'a D, - ident: Ident, - callback: fn(), - ) -> Buffer<'a, T, D, S> - where - D: Alloc<'a, T, S>, - S: Shape, - { - let maybe_allocated = self.nodes.get(Location::caller()); - - match maybe_allocated { - Some(ptr) => { - callback(); - let typed_ptr = unsafe { D::convert(ptr, AllocFlag::Wrapper) }; - - Buffer { - ptr: typed_ptr, - device: Some(device), - ident: Some(ident), - } - } - None => self.add_node(device, ident, callback), - } - } - - #[track_caller] - fn add_node<'a, T, S>( - &mut self, - device: &'a D, - ident: Ident, - callback: fn(), - ) -> Buffer<'a, T, D, S> - where - D: Alloc<'a, T, S>, - S: Shape, - { - let ptr = device.alloc(ident.len, AllocFlag::Wrapper); - - let untyped_ptr = unsafe { D::convert(&ptr, AllocFlag::None) }; - self.nodes.insert(Location::caller(), Rc::new(untyped_ptr)); - - callback(); - - Buffer { - ptr, - device: Some(device), - ident: Some(Ident { - idx: ident.idx, - len: ident.len, - }), - } - } -} - -#[cfg(test)] -mod tests { - use core::ops::Add; - - use crate::{devices::caller_cache::CallerCacheReturn, Buffer, Device, CPU}; - - use super::Cache; - - #[track_caller] - fn add<'a, T: Add + Copy>( - device: &'a CPU, - lhs: &Buffer, - rhs: &Buffer, - ) -> Buffer<'a, T> { - let len = std::cmp::min(lhs.len(), rhs.len()); - - let mut out = device.call::(len); - - for idx in 0..len { - out[idx] = lhs[idx] + rhs[idx]; - } - - out - } - #[test] - fn test_caller_cache() { - let device = CPU::new(); - - let lhs = device.buffer([1, 2, 3, 4]); - let rhs = device.buffer([1, 2, 3, 4]); - - for _i in 0..100 { - add(&device, &lhs, &rhs); - } - - assert_eq!(device.cache().nodes.len(), 1); - } -} diff --git a/src/devices/cpu/cpu_device.rs b/src/devices/cpu/cpu_device.rs index 842f8d61..668c0528 100644 --- a/src/devices/cpu/cpu_device.rs +++ b/src/devices/cpu/cpu_device.rs @@ -1,16 +1,13 @@ -use crate::{ - devices::cache::Cache, flag::AllocFlag, shape::Shape, Addons, AddonsReturn, Alloc, Buffer, - CloneBuf, Device, DevicelessAble, MainMemory, PtrConv, -}; +use core::convert::Infallible; -use core::{ - fmt::Debug, - mem::{align_of, size_of}, +use crate::{ + cpu::CPUPtr, flag::AllocFlag, impl_buffer_hook_traits, impl_retriever, Alloc, Base, Buffer, + Cached, CachedModule, CloneBuf, Device, HasModules, LazySetup, MainMemory, Module, + OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, Shape, TapeActions, }; -use super::CPUPtr; +pub trait IsCPU {} -#[derive(Debug, Default)] /// A CPU is used to perform calculations on the host CPU. /// To make new operations invocable, a trait providing new functions should be implemented for [CPU]. /// @@ -25,41 +22,69 @@ use super::CPUPtr; /// /// assert_eq!(out, vec![1, 2, 3]); /// ``` -pub struct CPU { - /// Provides additional functionality for the CPU. e.g. a cache, a gradient [`Tape`](crate::Tape), an optimizeable [`Graph`](crate::Graph) and a [`Cache`](crate::Cache). - pub addons: Addons, +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct CPU { + pub modules: Mods, } -impl CPU { - /// Creates an [CPU] with default addons. - #[must_use] - pub fn new() -> CPU { - CPU { - addons: Addons::default(), - } +impl_retriever!(CPU); +impl_buffer_hook_traits!(CPU); + +impl IsCPU for CPU {} + +// maybe +impl CPU { + pub fn default() -> CPU>>> { + CPU::>::new() + } +} + +impl Device for CPU { + type Error = Infallible; + type Data = CPUPtr; + + fn new() -> Result { + todo!() + // Ok(CPU::new()) } } -impl Device for CPU { - type Ptr = CPUPtr; - type Cache = Cache; //::CT +impl MainMemory for CPU { + #[inline] + fn as_ptr(ptr: &Self::Data) -> *const T { + ptr.ptr + } - fn new() -> crate::Result { - Ok(Self::new()) + #[inline] + fn as_ptr_mut(ptr: &mut Self::Data) -> *mut T { + ptr.ptr } } -impl AddonsReturn for CPU { +impl HasModules for CPU { #[inline] - fn addons(&self) -> &Addons { - &self.addons + fn modules(&self) -> &Mods { + &self.modules } } -impl<'a, T> DevicelessAble<'a, T> for CPU {} +impl CPU { + #[inline] + pub fn new() -> CPU + where + SimpleMods: Module, Module = NewMods>, + NewMods: Setup>, + { + let mut cpu = CPU { + modules: SimpleMods::new(), + }; + NewMods::setup(&mut cpu); + cpu + } +} -impl Alloc<'_, T, S> for CPU { - fn alloc(&self, mut len: usize, flag: AllocFlag) -> CPUPtr { +impl Alloc for CPU { + fn alloc(&self, mut len: usize, flag: AllocFlag) -> Self::Data { assert!(len > 0, "invalid buffer len: 0"); if S::LEN > len { @@ -69,8 +94,9 @@ impl Alloc<'_, T, S> for CPU { CPUPtr::new_initialized(len, flag) } - fn with_slice(&self, data: &[T]) -> CPUPtr + fn alloc_from_slice(&self, data: &[T]) -> Self::Data where + S: Shape, T: Clone, { assert!(!data.is_empty(), "invalid buffer len: 0"); @@ -82,7 +108,11 @@ impl Alloc<'_, T, S> for CPU { cpu_ptr } - fn alloc_with_vec(&self, mut vec: Vec) -> CPUPtr { + + fn alloc_from_vec(&self, mut vec: Vec) -> Self::Data + where + T: Clone, + { assert!(!vec.is_empty(), "invalid buffer len: 0"); let ptr = vec.as_mut_ptr(); @@ -93,37 +123,25 @@ impl Alloc<'_, T, S> for CPU { } } -impl PtrConv for CPU { +impl TapeActions for CPU { #[inline] - unsafe fn convert( - ptr: &Self::Ptr, - flag: AllocFlag, - ) -> Self::Ptr { - CPUPtr { - ptr: ptr.ptr as *mut Conv, - len: ptr.len, - flag, - align: Some(align_of::()), - size: Some(size_of::()), - } - } -} - -impl MainMemory for CPU { - #[inline] - fn as_ptr(ptr: &Self::Ptr) -> *const T { - ptr.ptr + fn tape(&self) -> Option> { + self.modules.tape() } #[inline] - fn as_ptr_mut(ptr: &mut Self::Ptr) -> *mut T { - ptr.ptr + fn tape_mut(&self) -> Option> { + self.modules.tape_mut() } } -impl<'a, T: Clone, S: Shape> CloneBuf<'a, T, S> for CPU { +impl LazySetup for CPU {} + +impl<'a, Mods: OnDropBuffer + OnNewBuffer, T: Clone, S: Shape> CloneBuf<'a, T, S> + for CPU +{ #[inline] - fn clone_buf(&'a self, buf: &Buffer<'a, T, CPU, S>) -> Buffer<'a, T, CPU, S> { + fn clone_buf(&'a self, buf: &Buffer<'a, T, CPU, S>) -> Buffer<'a, T, CPU, S> { let mut cloned = Buffer::new(self, buf.len()); cloned.clone_from_slice(buf); cloned diff --git a/src/devices/cpu/mod.rs b/src/devices/cpu/mod.rs index ce150ca4..21639fa8 100644 --- a/src/devices/cpu/mod.rs +++ b/src/devices/cpu/mod.rs @@ -1,9 +1,6 @@ //! The CPU module provides the CPU backend for custos. -use crate::{ - module_comb::{HasId, Id}, - CommonPtrs, PtrType, ShallowCopy, -}; +use crate::{CommonPtrs, HasId, Id, PtrType, ShallowCopy}; #[cfg(feature = "blas")] pub use blas::*; use core::{ diff --git a/src/devices/cpu/ops.rs b/src/devices/cpu/ops.rs index 33b6f65a..6064d20c 100644 --- a/src/devices/cpu/ops.rs +++ b/src/devices/cpu/ops.rs @@ -1,6 +1,33 @@ -use core::ops::{Index, Range, RangeBounds}; +use core::{ + any::Any, + ops::{Index, Range, RangeBounds}, +}; -use crate::{bounds_to_range, Buffer, ClearBuf, CopySlice, MainMemory, Read, Shape, WriteBuf, CPU}; +use crate::{ + bounds_to_range, AddOperation, Buffer, ClearBuf, CopySlice, Device, MainMemory, OnDropBuffer, + Operation, Read, Shape, WriteBuf, CPU, +}; + +impl AddOperation for CPU { + #[inline] + unsafe fn add_operation( + &self, + out: &mut Buffer, + operation: impl Fn(&mut dyn Any), + ) { + self.modules.add_operation(out, operation) + } + + #[inline] + fn add_operation2(&self, operation: impl Operation) { + self.modules.add_operation2(operation) + } + + #[inline] + fn call_lazily(&self) { + self.modules.call_lazily() + } +} impl Read for CPU { type Read<'a> = &'a [T] where T: 'a, D: 'a, S: 'a; @@ -19,7 +46,7 @@ impl Read for CPU { } } -impl WriteBuf for CPU { +impl WriteBuf for CPU { #[inline] fn write(&self, buf: &mut Buffer, data: &[T]) { buf.copy_from_slice(data) diff --git a/src/devices/cpu_stack_ops.rs b/src/devices/cpu_stack_ops.rs index 90e7f5f8..e1e01011 100644 --- a/src/devices/cpu_stack_ops.rs +++ b/src/devices/cpu_stack_ops.rs @@ -4,9 +4,9 @@ use core::ops::AddAssign; //#[cfg(any(feature = "cpu", feature = "stack"))] use custos_macro::impl_stack; -use crate::MayToCLSource; #[cfg(any(feature = "cpu", feature = "stack"))] -use crate::{ApplyFunction, Buffer, Device, Eval, MainMemory, Resolve, Shape, ToVal, UnaryGrad}; +use crate::{ApplyFunction, Buffer, Eval, MainMemory, Resolve, Shape, ToVal, UnaryGrad}; +use crate::{MayToCLSource, OnDropBuffer, Retrieve, Retriever}; #[cfg(feature = "cpu")] use crate::CPU; @@ -15,9 +15,10 @@ use crate::CPU; use crate::Stack; #[impl_stack] -impl ApplyFunction for CPU +impl ApplyFunction for CPU where - T: Copy + Default + ToVal, + Mods: OnDropBuffer + Retrieve, + T: Copy + Default + ToVal + 'static, D: crate::MainMemory, S: Shape, { @@ -25,7 +26,7 @@ where where F: Eval + MayToCLSource, { - let mut out = self.retrieve::(buf.len(), buf); + let mut out = self.retrieve(buf.len(), buf); for (value, x) in out.iter_mut().zip(buf.iter()) { *value = f((*x).to_val()).eval() diff --git a/src/devices/ident.rs b/src/devices/ident.rs deleted file mode 100644 index dd1806cf..00000000 --- a/src/devices/ident.rs +++ /dev/null @@ -1,61 +0,0 @@ -use core::cell::Cell; -use std::thread_local; - -thread_local! { - pub(crate) static COUNT: Cell = Cell::new(0); -} - -/// Sets current cache identifier / index. -/// This function is usually called after an iteration in a loop -> [Count](crate::Count) or [range](crate::range) -/// # Safety -/// Manually setting the count may yield multiple `Buffer` pointing two the same data. -#[inline] -pub unsafe fn set_count(count: usize) { - COUNT.with(|c| c.set(count)); -} - -/// Returns current cache identifier / index -#[inline] -pub fn get_count() -> usize { - COUNT.with(|c| c.get()) -} - -#[inline] -/// Increases the cache identifier / index by 1. -pub fn bump_count() { - COUNT.with(|c| { - let count = c.get(); - c.set(count + 1); - }) -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] -/// An `Ident` is used to identify a cached pointer. -pub struct Ident { - /// The index of the `Ident`. - pub idx: usize, - /// The amount of elements a corresponding [`Buffer`](crate::Buffer) has. - pub len: usize, -} - -impl Ident { - /// Returns a new `Ident` with the current cache identifier / index. - #[inline] - pub fn new(len: usize) -> Ident { - Ident { - idx: get_count(), - len, - } - } - - /// Returns a new `Ident` with the current cache identifier / index and increases the cache identifier / index by 1. - #[inline] - pub fn new_bumped(len: usize) -> Ident { - let id = Ident { - idx: get_count(), - len, - }; - bump_count(); - id - } -} diff --git a/src/devices/mod.rs b/src/devices/mod.rs deleted file mode 100644 index 842a915c..00000000 --- a/src/devices/mod.rs +++ /dev/null @@ -1,155 +0,0 @@ -//! This module defines all available compute devices - -mod generic_blas; -pub use generic_blas::*; - -#[cfg(feature = "no-std")] -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] -/// Dummy Ident -pub struct Ident { - /// unused - pub idx: usize, - /// unused - pub len: usize, -} - -#[cfg(not(feature = "no-std"))] -mod addons; -#[cfg(not(feature = "no-std"))] -pub use addons::*; - -use crate::{flag::AllocFlag, shape::Shape, AddGraph, Alloc, Buffer, Device}; - -#[cfg(not(feature = "no-std"))] -pub mod cache; - -#[cfg(not(feature = "no-std"))] -#[cfg(feature = "autograd")] -pub mod borrowing_cache; - -//pub mod cache; -#[cfg(not(feature = "no-std"))] -pub use cache::*; - -//pub use cache::{Cache, CacheReturn}; - -#[cfg(feature = "cpu")] -pub mod cpu; - -#[cfg(feature = "cuda")] -pub mod cuda; - -#[cfg(feature = "opencl")] -pub mod opencl; - -#[cfg(feature = "stack")] -pub mod stack; - -#[cfg(feature = "wgpu")] -pub mod wgpu; - -#[cfg(feature = "network")] -pub mod network; - -mod stack_array; -pub use stack_array::*; - -mod cdatatype; -pub use cdatatype::*; - -mod caller_cache; - -#[cfg(all(any(feature = "cpu", feature = "stack"), feature = "macro"))] -mod cpu_stack_ops; - -#[cfg(not(feature = "no-std"))] -mod ident; -#[cfg(not(feature = "no-std"))] -pub use ident::*; - -/// Used to convert a device pointer to the a pointer of a different type. -pub trait PtrConv: Device { - /// Converts a pointer to a pointer with a different type. - /// # Safety - /// Prone to double frees. Make sure that the pointer is not freed twice. - /// `custos` solves this by using fitting [`AllocFlag`]s. - unsafe fn convert( - ptr: &Self::Ptr, - flag: AllocFlag, - ) -> Self::Ptr; -} - -/// Implementors of this trait can be used as cache for a device. -pub trait CacheAble { - /// May allocate a new buffer or return an existing one. - /// It may use the cache count provided by the cache count ([Ident]). - /// This depends on the type of cache. - /// - /// # Example - #[cfg_attr(all(feature = "cpu", not(feature = "realloc")), doc = "```")] - #[cfg_attr(all(not(feature = "cpu"), feature = "realloc"), doc = "```ignore")] - /// use custos::{Device, CPU, set_count}; - /// - /// let device = CPU::new(); - /// - /// let buf = device.retrieve::(10, ()); - /// - /// // unsafe, because the next .retrieve call will tehn return the same buffer - /// unsafe { set_count(0) } - /// - /// let buf_2 = device.retrieve::(10, ()); - /// - /// assert_eq!(buf.ptr.ptr, buf_2.ptr.ptr); - /// - /// ``` - fn retrieve(device: &D, len: usize, add_node: impl AddGraph) -> Buffer - where - for<'a> D: Alloc<'a, T, S>; - - /// May return an existing buffer using the provided [`Ident`]. - /// This function panics if no buffer with the provided [`Ident`] exists. - /// - /// # Safety - /// This function is unsafe because it is possible to return multiple `Buffer` with `Ident` that share the same memory. - /// If this function is called twice with the same `Ident`, the returned `Buffer` will be the same. - /// Even though the return `Buffer`s are owned, this does not lead to double-frees (see [`AllocFlag`]). - #[cfg(not(feature = "no-std"))] - unsafe fn get_existing_buf(device: &D, id: Ident) -> Option>; - - /// Removes a `Buffer` with the provided [`Ident`] from the cache. - /// This function is internally called when a `Buffer` with [`AllocFlag`] `None` is dropped. - #[cfg(not(feature = "no-std"))] - fn remove(device: &D, ident: Ident); - - /// Adds a pointer that was allocated by [`Alloc`] to the cache and returns a new corresponding [`Ident`]. - /// This function is internally called when a `Buffer` with [`AllocFlag`] `None` is created. - #[cfg(not(feature = "no-std"))] - fn add_to_cache(device: &D, ptr: &D::Ptr) -> Option; -} - -// TODO: Mind num implement? -impl CacheAble for () { - #[inline] - fn retrieve(device: &D, len: usize, _add_node: impl AddGraph) -> Buffer - where - for<'a> D: Alloc<'a, T, S>, - { - Buffer::new(device, len) - } - - #[cfg(not(feature = "no-std"))] - #[inline] - fn remove(_device: &D, _ident: Ident) {} - - #[cfg(not(feature = "no-std"))] - #[inline] - fn add_to_cache(_device: &D, _ptr: &::Ptr) -> Option { - None - } - - #[cfg(not(feature = "no-std"))] - #[inline] - unsafe fn get_existing_buf(_device: &D, _id: Ident) -> Option> { - None - } -} diff --git a/src/exec_on_cpu.rs b/src/exec_on_cpu.rs index 5bdcac86..02aa4128 100644 --- a/src/exec_on_cpu.rs +++ b/src/exec_on_cpu.rs @@ -8,7 +8,7 @@ mod cl_may_unified; #[cfg(feature = "opencl")] pub use cl_may_unified::*; -use crate::{Alloc, Buffer, Device, Read, WriteBuf, CPU}; +use crate::{Alloc, Base, Buffer, Device, Read, Retriever, WriteBuf, CPU}; /// Moves a `Buffer` stored on device `D` to a `CPU` `Buffer` /// and executes the unary operation `F` with a `CPU` on the newly created `CPU` `Buffer`. @@ -45,11 +45,11 @@ pub fn cpu_exec_unary<'a, T, D, F>( f: F, ) -> crate::Result> where - T: Clone + Default, + T: Clone + Default + 'static, F: for<'b> Fn(&'b CPU, &Buffer<'_, T, CPU>) -> Buffer<'b, T, CPU>, - D: Device + Read + WriteBuf + for<'c> Alloc<'c, T>, + D: Device + Read + WriteBuf + Alloc + Retriever, { - let cpu = CPU::new(); + let cpu = CPU::::new(); let cpu_buf = Buffer::::from((&cpu, x.read_to_vec())); Ok(Buffer::from((device, f(&cpu, &cpu_buf)))) // TODO add new node to graph @@ -67,7 +67,7 @@ where F: for<'b> Fn(&'b CPU, &mut Buffer<'_, T, CPU>), D: Read + WriteBuf, { - let cpu = CPU::new(); + let cpu = CPU::::new(); let mut cpu_buf = Buffer::::from((&cpu, x.read_to_vec())); f(&cpu, &mut cpu_buf); @@ -111,11 +111,11 @@ pub fn cpu_exec_binary<'a, T, D, F>( f: F, ) -> Buffer<'a, T, D> where - T: Clone + Default, + T: Clone + Default + 'static, F: for<'b> Fn(&'b CPU, &Buffer<'_, T, CPU>, &Buffer<'_, T, CPU>) -> Buffer<'b, T, CPU>, - D: Device + Read + WriteBuf + for<'c> Alloc<'c, T>, + D: Device + Read + WriteBuf + Alloc + Retriever, { - let cpu = CPU::new(); + let cpu = CPU::::new(); let cpu_lhs = Buffer::::from((&cpu, lhs.read_to_vec())); let cpu_rhs = Buffer::::from((&cpu, rhs.read_to_vec())); Buffer::from((device, f(&cpu, &cpu_lhs, &cpu_rhs))) @@ -134,7 +134,7 @@ where F: for<'b> Fn(&'b CPU, &mut Buffer<'_, T, CPU>, &Buffer<'_, T, CPU>), D: Read + WriteBuf, { - let cpu = CPU::new(); + let cpu = CPU::::new(); let mut cpu_lhs = Buffer::::from((&cpu, lhs.read_to_vec())); let cpu_rhs = Buffer::::from((&cpu, rhs.read_to_vec())); f(&cpu, &mut cpu_lhs, &cpu_rhs); @@ -155,7 +155,7 @@ where /// /// let device = OpenCL::new(0).unwrap(); /// -/// let cpu = CPU::new(); +/// let cpu = CPU::::new(); /// /// let lhs = Buffer::from((&device, [1, 2, 3])); /// let rhs = Buffer::from((&device, [1, 2, 3])); @@ -184,7 +184,7 @@ macro_rules! to_cpu_mut { /// /// let device = OpenCL::new(0).unwrap(); /// -/// let cpu = CPU::new(); +/// let cpu = CPU::::new(); /// /// let lhs = Buffer::from((&device, [1, 2, 3])); /// let rhs = Buffer::from((&device, [1, 2, 3])); @@ -237,7 +237,7 @@ macro_rules! to_raw_host_mut { /// let b = Buffer::new(&device, 10); /// let c = Buffer::new(&device, 10); /// -/// let cpu = CPU::new(); +/// let cpu = CPU::::new(); /// /// ``` */ @@ -271,7 +271,7 @@ where D: Read, F: Fn(&CPU, &Buffer) -> T, { - let cpu = CPU::new(); + let cpu = CPU::::new(); let cpu_x = Buffer::from((&cpu, x.read_to_vec())); f(&cpu, &cpu_x) } @@ -285,7 +285,7 @@ mod tests { let device = crate::OpenCL::new(0).unwrap(); - let cpu = CPU::new(); + let cpu = CPU::::new(); let lhs = Buffer::from((&device, [1, 2, 3])); let rhs = Buffer::from((&device, [1, 2, 3])); diff --git a/src/features.rs b/src/features.rs new file mode 100644 index 00000000..6a517e91 --- /dev/null +++ b/src/features.rs @@ -0,0 +1,93 @@ +use core::{ + any::Any, + cell::{Ref, RefMut}, +}; + +use crate::{Parents, Shape}; + +use super::{Alloc, Buffer, Device, OnDropBuffer}; + +pub trait Feature: OnDropBuffer {} + +// is a cached module is placed before Autograd results a problem +// -> the retrieved buffer is not added to the no grads pool of the autograd module +// let device = CPU::>>::new(); +// +// how to fix this: +// add retrieved buffer to no grads pool at the end of the chain (at device level (Retriever trait)) +// => "generator", "actor" +pub trait Retrieve: OnDropBuffer { + // "generator" + #[track_caller] + fn retrieve( + &self, + device: &D, + len: usize, + parents: impl Parents, + ) -> D::Data + where + T: 'static, // if 'static causes any problems -> put T to => Retrieve? + S: Shape, + D: Device + Alloc; + + // "actor" + #[inline] + fn on_retrieve_finish(&self, _retrieved_buf: &Buffer) + where + T: 'static, + D: Alloc, + { + } +} + +pub trait HasModules { + fn modules(&self) -> &Mods; +} + +pub trait TapeActions { + // "generator" - do not forget to pass down + #[inline] + fn tape(&self) -> Option> { + None + } + // "generator" - do not forget to pass down + #[inline] + fn tape_mut(&self) -> Option> { + None + } + + // use track caller to identify a specific grad function + //-> if backward is not called (.drain()), the grad fn vector will gradually fill up + #[track_caller] + fn add_grad_fn( + &self, + // ids: impl AllocGradsFrom, + grad_fn: impl Fn(&mut crate::Gradients) + 'static, + ) where + T: 'static, + Self: Device + 'static, + { + if let Some(mut tape) = self.tape_mut() { + // the type T must match for every Id! + // for id in ids.ids() { + // tape.grads.grads_pool.add_buf_once::(self, id) + // } + + tape.add_grad_fn(grad_fn) + } + } +} + +pub trait Operation { + fn forward(&mut self); +} + +pub trait AddOperation { + fn add_operation2(&self, operation: impl Operation) {} + unsafe fn add_operation( + &self, + out: &mut Buffer, + operation: impl Fn(&mut dyn Any), + ); + fn call_lazily(&self) {} +} diff --git a/src/module_comb/hooks.rs b/src/hooks.rs similarity index 55% rename from src/module_comb/hooks.rs rename to src/hooks.rs index 5def93f7..65b64a54 100644 --- a/src/module_comb/hooks.rs +++ b/src/hooks.rs @@ -1,9 +1,9 @@ use crate::Shape; -use super::{Buffer, Device, Module}; +use super::{Buffer, Device}; pub trait OnDropBuffer { - fn on_drop_buffer<'a, T, D: Device, S: Shape>(&self, device: &'a D, buf: &Buffer) {} + fn on_drop_buffer<'a, T, D: Device, S: Shape>(&self, _device: &'a D, _buf: &Buffer) {} } pub trait OnNewBuffer { diff --git a/src/module_comb/id.rs b/src/id.rs similarity index 100% rename from src/module_comb/id.rs rename to src/id.rs diff --git a/src/lib.rs b/src/lib.rs index 837cacff..3da476b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,13 +55,11 @@ use core::ffi::c_void; //pub use libs::*; pub use buffer::*; -pub use count::*; pub use devices::*; pub use error::*; use flag::AllocFlag; -pub use graph::*; #[cfg(feature = "cpu")] pub use devices::cpu::CPU; @@ -92,22 +90,38 @@ pub mod exec_on_cpu; pub mod devices; mod buffer; -mod count; mod error; +mod cache; +mod device_traits; +mod features; pub mod flag; -mod graph; +// mod graph; +mod hooks; +mod id; +mod modules; mod op_traits; +mod parents; +mod ptr_conv; mod shape; mod two_way_ops; mod unary; +pub use cache::*; +pub use device_traits::*; +pub use features::*; +pub use hooks::*; +pub use id::*; +pub use modules::*; +pub use parents::*; +pub use ptr_conv::*; + #[cfg(feature = "static-api")] pub mod static_api; #[cfg(feature = "autograd")] pub mod autograd; -pub mod module_comb; +// pub mod module_comb; pub mod number; pub use op_traits::*; pub use shape::*; @@ -126,6 +140,11 @@ The automatic differentiation system requires caching of buffers, which is deact #[cfg(all(feature = "realloc", feature = "opt-cache"))] compile_error!("A typical 'cache' does not exist when the `realloc` feature is enabled."); +#[cfg(test)] +pub fn location() -> &'static core::panic::Location<'static> { + core::panic::Location::caller() +} + /// This trait is implemented for every pointer type. pub trait PtrType { /// Returns the element count. @@ -150,183 +169,15 @@ pub trait CommonPtrs { fn ptrs_mut(&mut self) -> (*mut T, *mut c_void, u64); } -/// This trait is the base trait for every device. -pub trait Device: Sized + 'static { - /// The type of the pointer that is used for `Buffer`. - type Ptr: PtrType; - /// The type of the cache. - type Cache: CacheAble; - //type Tape: ; - - /// Creates a new device. - fn new() -> crate::Result; - - /// Creates a new [`Buffer`] using `A`. - /// - /// # Example - #[cfg_attr(feature = "cpu", doc = "```")] - #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] - /// use custos::{CPU, Device}; - /// - /// let device = CPU::new(); - /// let buf = device.buffer([5, 4, 3]); - /// - /// assert_eq!(buf.read(), [5, 4, 3]); - /// ``` - fn buffer<'a, T, S: Shape, A>(&'a self, arr: A) -> Buffer<'a, T, Self, S> - where - Buffer<'a, T, Self, S>: From<(&'a Self, A)>, - { - Buffer::from((self, arr)) - } - - /// May allocate a new [`Buffer`] or return an existing one. - /// It may use the cache count provided by the cache count (identified by [`Ident`]).
- /// This depends on the type of cache and enabled features.
- /// With the `realloc` feature enabled, it is guaranteed that the returned `Buffer` is newly allocated and freed every time. - /// - /// # Example - #[cfg_attr(all(feature = "cpu", not(feature = "realloc")), doc = "```")] - #[cfg_attr(all(not(feature = "cpu"), feature = "realloc"), doc = "```ignore")] - /// use custos::{Device, CPU, set_count}; - /// - /// let device = CPU::new(); - /// - /// let buf = device.retrieve::(10, ()); - /// - /// // unsafe, because the next .retrieve call will then return the same buffer - /// unsafe { set_count(0) } - /// - /// let buf_2 = device.retrieve::(10, ()); - /// - /// assert_eq!(buf.ptr.ptr, buf_2.ptr.ptr); - /// - /// ``` - #[inline] - fn retrieve(&self, len: usize, add_node: impl AddGraph) -> Buffer - where - for<'a> Self: Alloc<'a, T, S>, - { - Self::Cache::retrieve(self, len, add_node) - } - - /// May return an existing buffer using the provided [`Ident`]. - /// This function panics if no buffer with the provided `Ident` exists. - /// - /// # Safety - /// This function is unsafe because it is possible to return multiple [`Buffer`] with `Ident` that share the same memory. - /// If this function is called twice with the same `Ident`, the returned `Buffer` will be the same. - /// Even though the return `Buffer`s are owned, this does not lead to double-frees (see [`AllocFlag`]). - #[cfg(feature = "autograd")] - #[inline] - unsafe fn get_existing_buf(&self, ident: Ident) -> Buffer { - Self::Cache::get_existing_buf(self, ident).expect("A matching Buffer does not exist.") - } - - /// Removes a `Buffer` with the provided [`Ident`] from the cache. - /// This function is internally called when a `Buffer` with [`AllocFlag`] `None` is dropped. - #[cfg(not(feature = "no-std"))] - #[inline] - fn remove(&self, ident: Ident) { - Self::Cache::remove(self, ident); - } - - /// Adds a pointer that was allocated by [`Alloc`] to the cache and returns a new corresponding [`Ident`]. - /// This function is internally called when a `Buffer` with [`AllocFlag`] `None` is created. - #[cfg(not(feature = "no-std"))] - #[inline] - fn add_to_cache(&self, ptr: &Self::Ptr) -> Option { - Self::Cache::add_to_cache(self, ptr) - } -} - /// All type of devices that can create [`Buffer`]s -pub trait DevicelessAble<'a, T, S: Shape = ()>: Alloc<'a, T, S> {} +pub trait DevicelessAble<'a, T, S: Shape = ()>: Alloc {} /// Devices that can access the main memory / RAM of the host. pub trait MainMemory: Device { /// Returns the respective immutable host memory pointer - fn as_ptr(ptr: &Self::Ptr) -> *const T; + fn as_ptr(ptr: &Self::Data) -> *const T; /// Returns the respective mutable host memory pointer - fn as_ptr_mut(ptr: &mut Self::Ptr) -> *mut T; -} - -/// This trait is for allocating memory on the implemented device. -/// -/// # Example -#[cfg_attr(feature = "cpu", doc = "```")] -#[cfg_attr(not(feature = "cpu"), doc = "```ignore")] -/// use custos::{CPU, Alloc, Buffer, Read, flag::AllocFlag, GraphReturn, cpu::CPUPtr}; -/// -/// let device = CPU::new(); -/// let ptr = Alloc::::alloc(&device, 12, AllocFlag::None); -/// -/// let buf: Buffer = Buffer { -/// ident: None, -/// ptr, -/// device: Some(&device), -/// }; -/// assert_eq!(vec![0.; 12], device.read(&buf)); -/// ``` -pub trait Alloc<'a, T, S: Shape = ()>: Device { - /// Allocate memory on the implemented device. - /// # Example - #[cfg_attr(feature = "cpu", doc = "```")] - #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] - /// use custos::{CPU, Alloc, Buffer, Read, flag::AllocFlag, GraphReturn, cpu::CPUPtr}; - /// - /// let device = CPU::new(); - /// let ptr = Alloc::::alloc(&device, 12, AllocFlag::None); - /// - /// let buf: Buffer = Buffer { - /// ident: None, - /// ptr, - /// device: Some(&device), - /// }; - /// assert_eq!(vec![0.; 12], device.read(&buf)); - /// ``` - fn alloc(&'a self, len: usize, flag: AllocFlag) -> ::Ptr; - - /// Allocate new memory with data - /// # Example - #[cfg_attr(feature = "cpu", doc = "```")] - #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] - /// use custos::{CPU, Alloc, Buffer, Read, GraphReturn, cpu::CPUPtr}; - /// - /// let device = CPU::new(); - /// let ptr = Alloc::::with_slice(&device, &[1, 5, 4, 3, 6, 9, 0, 4]); - /// - /// let buf: Buffer = Buffer { - /// ident: None, - /// ptr, - /// device: Some(&device), - /// }; - /// assert_eq!(vec![1, 5, 4, 3, 6, 9, 0, 4], device.read(&buf)); - /// ``` - fn with_slice(&'a self, data: &[T]) -> ::Ptr - where - T: Clone; - - /// If the vector `vec` was allocated previously, this function can be used in order to reduce the amount of allocations, which may be faster than using a slice of `vec`. - #[inline] - #[cfg(not(feature = "no-std"))] - fn alloc_with_vec(&'a self, vec: Vec) -> ::Ptr - where - T: Clone, - { - self.with_slice(&vec) - } - - /// Allocates a pointer with the array provided by the `S:`[`Shape`] generic. - /// By default, the array is flattened and then passed to [`Alloc::with_slice`]. - #[inline] - fn with_array(&'a self, array: S::ARR) -> ::Ptr - where - T: Clone, - { - let stack_array = StackArray::::from_array(array); - self.with_slice(stack_array.flatten()) - } + fn as_ptr_mut(ptr: &mut Self::Data) -> *mut T; } /// If the `autograd` feature is enabled, then this will be implemented for all types that implement [`TapeReturn`]. @@ -384,16 +235,16 @@ pub mod prelude { //! Typical imports for using custos. pub use crate::{ - number::*, range, shape::*, Alloc, Buffer, CDatatype, ClearBuf, CopySlice, Device, - GraphReturn, Ident, MainMemory, MayTapeReturn, MayToCLSource, Read, ShallowCopy, WithShape, - WriteBuf, + number::*, shape::*, Alloc, Buffer, CDatatype, ClearBuf, CopySlice, Device, MainMemory, + MayTapeReturn, MayToCLSource, Read, ShallowCopy, WithShape, WriteBuf, }; #[cfg(feature = "cpu")] pub use crate::{exec_on_cpu::*, CPU}; - #[cfg(not(feature = "no-std"))] - pub use crate::{cache::CacheReturn, get_count, set_count, Cache}; + // TODO + // #[cfg(not(feature = "no-std"))] + // pub use crate::{cache::CacheReturn, get_count, set_count, Cache}; #[cfg(feature = "opencl")] pub use crate::opencl::{enqueue_kernel, CLBuffer, OpenCL, CL}; @@ -422,9 +273,9 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_buffer_from_device() { - use crate::{Device, CPU}; + use crate::{Base, Device, CPU}; - let device = CPU::new(); + let device = CPU::::new(); let buf = device.buffer([1, 2, 3]); assert_eq!(buf.read(), [1, 2, 3]) diff --git a/src/module_comb/buffer.rs b/src/module_comb/buffer.rs index 9aa08018..23467ddf 100644 --- a/src/module_comb/buffer.rs +++ b/src/module_comb/buffer.rs @@ -1,6 +1,22 @@ use super::{Alloc, Base, Device, HasId, MainMemory, OnNewBuffer, WriteBuf, CPU}; use crate::{flag::AllocFlag, PtrType, Shape}; +/// The underlying non-growable array structure of `custos`. A `Buffer` may be encapsulated in other data structures. +/// By default, the `Buffer` is a f32 CPU Buffer with no statically known shape. +/// # Example +#[cfg_attr(feature = "cpu", doc = "```")] +#[cfg_attr(not(feature = "cpu"), doc = "```ignore")] +/// use custos::prelude::*; +/// +/// fn buffer_f32_cpu(buf: &Buffer) {} +/// fn buffer_generic(buf: &Buffer) {} +/// +/// let device = CPU::new(); +/// let buf = Buffer::from((&device, [0.5, 1.3, 3.2, 2.43])); +/// +/// buffer_f32_cpu(&buf); +/// buffer_generic(&buf); +/// ``` pub struct Buffer<'a, T = f32, D: Device = CPU, S: Shape = ()> { /// the type of pointer pub data: D::Data, @@ -9,6 +25,23 @@ pub struct Buffer<'a, T = f32, D: Device = CPU, S: Shape = ()> { } impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> { + /// Creates a zeroed (or values set to default) `Buffer` with the given length on the specified device. + /// This `Buffer` can't outlive the device specified as a parameter. + #[cfg_attr(feature = "cpu", doc = "```")] + #[cfg_attr(not(feature = "cpu"), doc = "```ignore")] + /// use custos::{CPU, Buffer}; + /// + /// let device = CPU::new(); + /// let mut buffer = Buffer::::new(&device, 6); + /// + /// // this only works with CPU or unified memory buffers (this creates a slice with the host pointer) + /// for value in &mut buffer { + /// *value = 2; + /// } + /// + /// assert_eq!(buffer.as_slice(), &[2; 6]); + /// + /// ``` #[inline] pub fn new(device: &'a D, len: usize) -> Self where diff --git a/src/module_comb/devices/cpu.rs b/src/module_comb/devices/cpu.rs index 6620634a..7bc64f5b 100644 --- a/src/module_comb/devices/cpu.rs +++ b/src/module_comb/devices/cpu.rs @@ -16,6 +16,20 @@ use crate::{ pub trait IsCPU {} +/// A CPU is used to perform calculations on the host CPU. +/// To make new operations invocable, a trait providing new functions should be implemented for [CPU]. +/// +/// # Example +/// ``` +/// use custos::{CPU, Read, Buffer}; +/// +/// let device = CPU::new(); +/// let a = Buffer::from((&device, [1, 2, 3])); +/// +/// let out = device.read(&a); +/// +/// assert_eq!(out, vec![1, 2, 3]); +/// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub struct CPU { pub modules: Mods, @@ -76,7 +90,7 @@ impl CPU { } } -impl Alloc for CPU { +impl Allocfor CPU { type Data = CPUPtr; fn alloc(&self, mut len: usize, flag: AllocFlag) -> Self::Data { diff --git a/src/module_comb/devices/cuda.rs b/src/module_comb/devices/cuda.rs index b927570d..c8f6d206 100644 --- a/src/module_comb/devices/cuda.rs +++ b/src/module_comb/devices/cuda.rs @@ -133,7 +133,7 @@ impl Drop for CUDA { } } -impl Alloc for CUDA { +impl Allocfor CUDA { type Data = CUDAPtr; fn alloc(&self, len: usize, flag: crate::flag::AllocFlag) -> Self::Data { diff --git a/src/module_comb/devices/mod.rs b/src/module_comb/devices/mod.rs index 3002fb3e..942188f3 100644 --- a/src/module_comb/devices/mod.rs +++ b/src/module_comb/devices/mod.rs @@ -8,60 +8,4 @@ pub use cuda::*; use super::{Alloc, OnDropBuffer}; -pub trait Device: Alloc + OnDropBuffer { - type Error; - #[inline] - fn new() -> Result { - todo!() - } -} - -#[macro_export] -macro_rules! impl_buffer_hook_traits { - ($device:ident) => { - impl> OnNewBuffer - for $device - { - #[inline] - fn on_new_buffer(&self, device: &D, new_buf: &Buffer) { - self.modules.on_new_buffer(device, new_buf) - } - } - - impl OnDropBuffer for $device { - #[inline] - fn on_drop_buffer<'a, T, D: Device, S: Shape>( - &self, - device: &'a D, - buf: &Buffer, - ) { - self.modules.on_drop_buffer(device, buf) - } - } - }; -} - -#[macro_export] -macro_rules! impl_retriever { - ($device:ident) => { - impl> Retriever for $device { - #[inline] - fn retrieve( - &self, - len: usize, - parents: impl crate::module_comb::Parents, - ) -> Buffer { - let data = self - .modules - .retrieve::(self, len, parents); - let buf = Buffer { - data, - device: Some(self), - }; - self.modules.on_retrieve_finish(&buf); - buf - } - } - }; -} diff --git a/src/module_comb/features.rs b/src/module_comb/features.rs index 70ae020e..881ebab5 100644 --- a/src/module_comb/features.rs +++ b/src/module_comb/features.rs @@ -75,82 +75,6 @@ pub trait TapeActions { } } -pub trait Parents { - fn ids(self) -> [Id; N]; -} - -impl Parents<0> for () { - #[inline] - fn ids(self) -> [Id; 0] { - [] - } -} - -impl Parents<1> for Id { - #[inline] - fn ids(self) -> [Id; 1] { - [self] - } -} - -impl Parents<2> for (Id, Id) { - #[inline] - fn ids(self) -> [Id; 2] { - [self.0, self.1] - } -} - -impl Parents<3> for (Id, Id, Id) { - #[inline] - fn ids(self) -> [Id; 3] { - [self.0, self.1, self.2] - } -} - -impl Parents for [Id; N] { - #[inline] - fn ids(self) -> [Id; N] { - self - } -} - -impl Parents<1> for &Buffer<'_, T, D, S> { - #[inline] - fn ids(self) -> [Id; 1] { - [self.id()] - } -} - -impl Parents<2> - for (&Buffer<'_, T, D, S>, &Buffer<'_, T1, D1, S1>) -{ - #[inline] - fn ids(self) -> [Id; 2] { - let (lhs, rhs) = self; - [lhs.id(), rhs.id()] - } -} - -impl Parents<3> - for ( - &Buffer<'_, T, D, S>, - &Buffer<'_, T1, D1, S1>, - &Buffer<'_, T2, D2, S2>, - ) -{ - #[inline] - fn ids(self) -> [Id; 3] { - let (buf, buf1, buf2) = self; - [buf.id(), buf1.id(), buf2.id()] - } -} - -impl Parents for [&Buffer<'_, T, D, S>; N] { - #[inline] - fn ids(self) -> [Id; N] { - self.map(|buf| buf.id()) - } -} pub trait Operation { fn forward(&mut self); diff --git a/src/module_comb/mod.rs b/src/module_comb/mod.rs index ebaa19f3..947c146f 100644 --- a/src/module_comb/mod.rs +++ b/src/module_comb/mod.rs @@ -38,67 +38,6 @@ pub fn location() -> &'static core::panic::Location<'static> { core::panic::Location::caller() } -pub trait Alloc: Sized { - type Data: HasId + PtrType; - - fn alloc(&self, len: usize, flag: AllocFlag) -> Self::Data; - fn alloc_from_slice(&self, data: &[T]) -> Self::Data - where - T: Clone; - - /// If the vector `vec` was allocated previously, this function can be used in order to reduce the amount of allocations, which may be faster than using a slice of `vec`. - #[inline] - #[cfg(not(feature = "no-std"))] - fn alloc_from_vec(&self, vec: Vec) -> Self::Data - where - T: Clone, - { - self.alloc_from_slice(&vec) - } - - /// Allocates a pointer with the array provided by the `S:`[`Shape`] generic. - /// By default, the array is flattened and then passed to [`Alloc::with_slice`]. - #[inline] - fn alloc_from_array(&self, array: S::ARR) -> Self::Data - where - T: Clone, - { - let stack_array = StackArray::::from_array(array); - self.alloc_from_slice(stack_array.flatten()) - } -} - -pub trait Module { - type Module; - - fn new() -> Self::Module; -} - -/// Used for modules that should affect the device. -pub trait Setup { - #[inline] - fn setup(_device: &mut D) {} -} - -pub trait Retriever: Device { - #[track_caller] - fn retrieve( - &self, - len: usize, - parents: impl Parents, - ) -> Buffer - where - T: 'static, - S: Shape; -} - -/// Devices that can access the main memory / RAM of the host. -pub trait MainMemory: Device { - /// Returns the respective immutable host memory pointer - fn as_ptr(ptr: &Self::Data) -> *const T; - /// Returns the respective mutable host memory pointer - fn as_ptr_mut(ptr: &mut Self::Data) -> *mut T; -} #[cfg(test)] mod tests { diff --git a/src/module_comb/modules/autograd.rs b/src/module_comb/modules/autograd.rs index ba83a416..5497845c 100644 --- a/src/module_comb/modules/autograd.rs +++ b/src/module_comb/modules/autograd.rs @@ -179,7 +179,7 @@ const AUTOGRAD_NOT_AVAILABLE: &'static str = "Autograd<> is not available."; impl<'a, T, D, S> Buffer<'a, T, D, S> where T: Clone + One + 'static, - D: TapeActions + WriteBuf + Alloc + 'static, + D: TapeActions + WriteBuf + Alloc+ 'static, S: Shape, { /// Calls `.backward_seeded` on the [`Tape`]. diff --git a/src/module_comb/modules/autograd/tape.rs b/src/module_comb/modules/autograd/tape.rs index 8940e575..bfbbb488 100644 --- a/src/module_comb/modules/autograd/tape.rs +++ b/src/module_comb/modules/autograd/tape.rs @@ -64,7 +64,7 @@ impl Tape { pub fn backward_seeded(&mut self, buf: &Buffer) where T: Clone + One + 'static, - D: Alloc + WriteBuf + 'static, + D: Alloc+ WriteBuf + 'static, { let out = self.grads.get_mut::(buf.device(), buf.id()); out.write(&vec![T::one(); out.len()]); diff --git a/src/module_comb/modules/cached.rs b/src/module_comb/modules/cached.rs index e971df80..84a05401 100644 --- a/src/module_comb/modules/cached.rs +++ b/src/module_comb/modules/cached.rs @@ -64,7 +64,7 @@ impl OnDropBuffer for CachedModule { } // TODO: a more general OnDropBuffer => "Module" -impl, D: Alloc + PtrConv, SimpleDevice: Alloc + PtrConv> +impl, D: Alloc+ PtrConv, SimpleDevice: Alloc+ PtrConv> Retrieve for CachedModule { #[inline] diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs new file mode 100644 index 00000000..a285ab6d --- /dev/null +++ b/src/modules/autograd.rs @@ -0,0 +1,369 @@ +mod gradients; +mod tape; + +pub use gradients::*; +pub use tape::*; + +use core::{ + any::Any, + cell::{Ref, RefCell, RefMut}, + hash::BuildHasher, + mem::transmute, +}; +use std::collections::HashMap; + +use crate::{ + flag::AllocFlag, prelude::One, Alloc, Buffer, Device, HasId, Id, Module, OnDropBuffer, + OnNewBuffer, Parents, PtrConv, Retrieve, Setup, Shape, TapeActions, UniqueId, WriteBuf, +}; + +use super::{Cached, CachedModule}; + +#[derive(Debug, Default)] +pub struct Autograd { + modules: Mods, + tape: RefCell, +} + +#[inline] +pub unsafe fn register_buf<'a, T, D, S>( + cache: &mut HashMap, impl BuildHasher>, + buf: &'a Buffer, +) where + T: 'static, + D: Device + PtrConv + 'static, + S: Shape, +{ + let wrapped_data = D::convert::(&buf.data, AllocFlag::Wrapper); + let buf = Buffer { + data: wrapped_data, + device: buf.device, + }; + let buf: Buffer<'static, T, D, S> = transmute(buf); + cache.insert(*buf.id(), Box::new(buf)); +} + +#[inline] +pub fn unregister_buf(cache: &mut HashMap, impl BuildHasher>, id: Id) { + cache.remove(&id); +} + +impl Autograd { + #[inline] + pub fn register_no_grad_buf(&self, buf: &Buffer) + where + T: 'static, + D: Device + PtrConv + 'static, + S: Shape, + { + let no_grads_pool = &mut self.tape.borrow_mut().grads.no_grads_pool.cache; + + if no_grads_pool.get(&buf.id()).is_some() { + return; + } + + unsafe { register_buf(no_grads_pool, buf) }; + } +} + +impl OnNewBuffer for Autograd +where + T: 'static, + D: Alloc + PtrConv + 'static, + S: Shape, + Mods: OnNewBuffer, +{ + #[inline] + fn on_new_buffer(&self, device: &D, new_buf: &Buffer) { + self.register_no_grad_buf(new_buf); + + // allocates gradient memory for the corresponding buffer id + self.tape + .borrow_mut() + .grads + .grads_pool + .add_buf_once::(device, new_buf.id()); + + // pass down + self.modules.on_new_buffer(device, new_buf) + } +} + +impl OnDropBuffer for Autograd { + #[inline] + fn on_drop_buffer<'a, T, D: Device, S: Shape>(&self, device: &'a D, buf: &Buffer) { + unregister_buf( + &mut self.tape.borrow_mut().grads.no_grads_pool.cache, + buf.id(), + ); + self.modules.on_drop_buffer(device, buf) + } +} + +impl, D: Device> Module for Autograd { + type Module = Autograd>; + + #[inline] + fn new() -> Self::Module { + Autograd { + modules: Cached::::new(), + tape: Default::default(), + } + } +} + +impl, NewDev> Setup for Autograd { + #[inline] + fn setup(device: &mut NewDev) { + Mods::setup(device) + } +} + +impl, D> Retrieve for Autograd +where + D: PtrConv + Device + 'static, +{ + #[inline] + fn retrieve( + &self, + device: &D, + len: usize, + parents: impl Parents, + ) -> ::Data + where + D: Alloc, + T: 'static, + S: crate::Shape, + { + self.modules.retrieve(device, len, parents) + } + + #[inline] + fn on_retrieve_finish(&self, retrieved_buf: &Buffer) + where + T: 'static, + D: Alloc, + { + self.register_no_grad_buf(retrieved_buf); + + // allocates gradients + self.tape + .borrow_mut() + .grads + .grads_pool + .add_buf_once::(retrieved_buf.device(), retrieved_buf.id()); + + self.modules.on_retrieve_finish(retrieved_buf) + } +} + +impl TapeActions for Autograd { + #[inline] + fn tape(&self) -> Option> { + Some(self.tape.borrow()) + } + + #[inline] + fn tape_mut(&self) -> Option> { + Some(self.tape.borrow_mut()) + } +} + +const AUTOGRAD_NOT_AVAILABLE: &'static str = "Autograd<> is not available."; + +impl<'a, T, D, S> Buffer<'a, T, D, S> +where + T: Clone + One + 'static, + D: TapeActions + WriteBuf + Alloc + 'static, + S: Shape, +{ + /// Calls `.backward_seeded` on the [`Tape`]. + /// The seed of the gradient is set to `1` and contains `self.len()` elements. + #[inline] + pub fn backward(&self) { + if let Some(mut tape) = self.device().tape_mut() { + tape.backward_seeded(self) + } + } + + /// Returns a reference to the gradient of this buffer. + /// The lifetime is bound to the lifetime of self, which is more strict and catches some mistakes at compile-time. + /// However, If the borrow checker complains and you are sure that everything should be fine, use `grad_unbound` instead. + /// + /// Panics if the gradient was not allocated. + #[inline] + pub fn grad(&self) -> Ref { + self.grad_unbound() + } + + /// Returns a reference to the gradient of this buffer. + /// Lifetimes are checked during runtime with `RefCell`. + /// Panics if the gradient was not allocated. + // TODO: Maybe return Result with two error variants? + #[inline] + pub fn grad_unbound(&self) -> Ref<'a, Self> { + Ref::map( + self.device().tape().expect(AUTOGRAD_NOT_AVAILABLE), + |tape| { + tape.grads.may_get_ref(self.id()).expect( + "Gradient was not allocated for this buffer. Did you forget to call `backward`?", + ) + }, + ) + } + + /// Returns a mutable reference to the gradient of this buffer. + /// The lifetime is bound to the lifetime of self, which is more strict. + /// If the borrow checker complains, use `grad_mut_unbound` instead. + /// Panics if the gradient was not allocated. + // TODO: Maybe return Result with two error variants? + #[inline] + pub fn grad_mut(&mut self) -> RefMut { + self.grad_mut_unbound() + } + + /// Returns a mutable reference to the gradient of this buffer. + /// Lifetimes are checked during runtime. + /// Panics if the gradient was not allocated. + // TODO: Maybe return Result with two error variants? + #[inline] + pub fn grad_mut_unbound(&mut self) -> RefMut<'a, Self> { + RefMut::map( + self.device().tape_mut().expect(AUTOGRAD_NOT_AVAILABLE), + |tape| { + tape.grads.may_get_mut(self.id()).expect( + "Gradient was not allocated for this buffer. Did you forget to call `backward`?", + ) + }, + ) + } +} + +#[cfg(test)] +mod tests { + use core::any::Any; + + use crate::{Base, Buffer, Cached, Device, HasId, Retriever, Shape, CPU}; + + use super::Autograd; + + #[inline] + pub fn downcast_val<'a, 'b, T: 'static, D: Device + 'static, S: Shape>( + buf_any: &'b Box, + _device: &'a D, + ) -> Option<&'b Buffer<'a, T, D, S>> { + buf_any.downcast_ref::>() + } + + #[test] + fn test_buffer_creation_autograd_register_manual() { + let device = CPU::>::new(); + let buf: Buffer = Buffer::::new(&device, 10); + + let autograd = &device.modules; + { + let no_grads_pool = &mut autograd.tape.borrow_mut().grads.no_grads_pool; + let buf_any = no_grads_pool.cache.get(&buf.id()).unwrap(); + + let buf1 = downcast_val::(buf_any, &device).unwrap(); + assert_eq!(buf1.data.ptr, buf.data.ptr); + } + } + + #[test] + fn test_buffer_creation_autograd_get_buf() { + let device = CPU::>::new(); + let buf: Buffer = Buffer::::new(&device, 10); + + let autograd = &device.modules; + { + let no_grads_pool = &mut autograd.tape.borrow_mut().grads.no_grads_pool; + let buf1 = no_grads_pool + .get_buf_with_dev::(buf.id(), &device) + .unwrap(); + assert_eq!(buf1.data.ptr, buf.data.ptr); + } + } + + #[test] + fn test_buffer_creation_autograd_unregister() { + let device = CPU::>::new(); + let buf: Buffer = Buffer::::new(&device, 10); + let id = buf.id(); + let autograd = &device.modules; + + drop(buf); + + { + let no_grads_pool = &autograd.tape.borrow_mut().grads.no_grads_pool; + assert!(no_grads_pool.cache.get(&id).is_none()); + } + } + + #[test] + fn test_buffer_new_and_retrieve() { + let device = CPU::>::new(); + let _lhs = Buffer::::new(&device, 10); + + for _ in 0..100 { + let x = device.retrieve::(100, ()); + assert_eq!(x.len(), 100) + } + + let no_grads_pool = &device.modules.tape.borrow().grads.no_grads_pool; + assert_eq!(no_grads_pool.cache.len(), 2); + } + + #[test] + fn test_cached_before_autograd() { + // is a cached module is placed before Autograd results a problem + // -> the retrieved buffer is not added to the no grads pool of the autograd module + let device = CPU::>>::new(); + + // how to fix this: + // add retrieved buffer to no grads pool at the end of the chain (at device level (Retriever trait)) + // => "generator", "actor" + + let _lhs = Buffer::::new(&device, 10); + + for _ in 0..100 { + let x = device.retrieve::(100, ()); + assert_eq!(x.len(), 100) + } + + let no_grads_pool = &device.modules.modules.tape.borrow().grads.no_grads_pool; + assert_eq!(no_grads_pool.cache.len(), 2); + } + + #[test] + #[should_panic] + fn test_tape_return_without_autograd() { + let device = CPU::::new(); + let buf = Buffer::::new(&device, 10); + buf.grad(); + } + + #[test] + #[should_panic] + fn test_tape_return_without_grad_allocation() { + let device = CPU::>::new(); + let buf = Buffer::::new(&device, 10); + buf.grad(); + } + + #[test] + fn test_tape_return_with_grad_allocation() { + let device = CPU::>::new(); + let buf = Buffer::::new(&device, 10); + + // allocates a new gradient buffer if none exists for the specified id + device + .modules + .tape + .borrow_mut() + .grads + .get_mut::(&device, buf.id()); + + buf.grad(); + } +} diff --git a/src/modules/autograd/gradients.rs b/src/modules/autograd/gradients.rs new file mode 100644 index 00000000..a84a70a7 --- /dev/null +++ b/src/modules/autograd/gradients.rs @@ -0,0 +1,214 @@ +use crate::{ + Alloc, Base, BorrowCache, Buffer, CachingError, Device, HasId, Id, Parents, Shape, CPU, +}; + +const INVALID_ID: &'static str = "A matching Buffer does not exist."; + +/// A cache for gradients. +/// The cache is populated by `get_ref`, `get_like` or `get_mut_ref` calls. +#[derive(Default)] +pub struct Gradients { + pub grads_pool: BorrowCache, + pub no_grads_pool: BorrowCache, +} + +impl core::fmt::Debug for Gradients { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Gradients") + .field("cache", &self.grads_pool) + .finish() + } +} + +type LhsRhsOut<'a, 'b, T, D, S> = ( + &'b Buffer<'a, T, D, S>, + &'b Buffer<'a, T, D, S>, + &'b mut Buffer<'a, T, D, S>, + &'b mut Buffer<'a, T, D, S>, + &'b Buffer<'a, T, D, S>, +); + +impl Gradients { + /// Clears the cache. + #[inline] + pub fn zero_grad(&mut self) { + self.grads_pool.cache.clear(); + } + + /// May get a reference to a gradient [`Buffer`]. + #[inline] + pub fn may_get_ref<'a, T, S, D>(&self, ident: Id) -> Result<&Buffer<'a, T, D, S>, CachingError> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + { + self.grads_pool.get_buf(ident) + } + + /// May get a mutable reference to a gradient [`Buffer`]. + #[inline] + pub fn may_get_mut<'a, T, S, D>( + &mut self, + id: Id, + ) -> Result<&mut Buffer<'a, T, D, S>, CachingError> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + { + self.grads_pool.get_buf_mut(id) + } + + /// Returns a reference to a gradient [`Buffer`]. + /// Allocates a gradient [`Buffer`] if it does not exist. + #[inline] + pub fn get_ref<'a, T, S, D>(&mut self, device: &'a D, id: Id) -> &Buffer<'a, T, D, S> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + { + self.grads_pool.add_or_get(device, id) + } + + /// Returns a mutable reference to a gradient [`Buffer`]. + /// Allocates a gradient [`Buffer`] if it does not exist. + #[inline] + pub fn get_mut<'a, T, S, D>(&mut self, device: &'a D, id: Id) -> &mut Buffer<'a, T, D, S> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + { + self.grads_pool.add_or_get_mut(device, id) + } + + /// Returns a reference to a gradient [`Buffer`] using information from `buf`. + #[inline] + pub fn get_like<'a, T, S, D>(&mut self, buf: &Buffer<'a, T, D, S>) -> &Buffer<'a, T, D, S> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + D::Data: HasId, + { + self.get_ref(buf.device(), buf.id()) + } + + #[inline] + pub fn get_buf_from_no_grad_pool<'a, T, S, D>(&self, id: Id) -> &Buffer<'a, T, D, S> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + { + self.no_grads_pool.get_buf::(id).expect(INVALID_ID) + } + + /// Returns the forward [`Buffer`]s lhs and and rhs, and the gradient `Buffer`s lhs_grad, rhs_grad and out_grad. + /// Usefull for binary operations. + #[inline] + pub fn get_triple<'a, T, S, D>( + &mut self, + device: &'a D, + (lid, rid, oid): (Id, Id, Id), + ) -> LhsRhsOut<'a, '_, T, D, S> + where + T: 'static, + S: Shape, + D: Alloc + 'static, + { + self.grads_pool.add_buf_once::(device, rid); + self.grads_pool.add_buf_once::(device, oid); + let lhs_grad_ptr = self.get_mut(device, lid) as *mut _; + let lhs_grad = unsafe { &mut *lhs_grad_ptr }; + + let rhs_grad_ptr = self.get_mut(device, rid) as *mut _; + let rhs_grad = unsafe { &mut *rhs_grad_ptr }; + ( + self.get_buf_from_no_grad_pool(lid), + self.get_buf_from_no_grad_pool(rid), + lhs_grad, + rhs_grad, + self.may_get_ref(oid).unwrap(), + ) + } + + /// Returns the forward [`Buffer`] x and the gradient `Buffer`s x_grad and out_grad. + /// Useful for unary operations. + /// + #[inline] + pub fn get_double<'a, T, IS, OS, D>( + &mut self, + // device: &'a D, + parents: impl Parents<2>, + // (xid, oid): (Id, Id), + ) -> ( + &Buffer<'a, T, D, IS>, + &mut Buffer<'a, T, D, IS>, + &Buffer<'a, T, D, OS>, + ) + where + T: 'static, + IS: Shape, + OS: Shape, + D: Alloc + 'static, + { + let [xid, oid] = parents.ids(); + // self.grads_pool.add_buf_once::(device, oid); + + // let x_grad_ptr = self.get_mut(device, xid) as *mut _; + let x_grad_ptr = self.may_get_mut(xid).unwrap() as *mut _; + let x_grad_mut = unsafe { &mut *x_grad_ptr }; + let o_grad = self.may_get_ref(oid).unwrap(); + + (self.get_buf_from_no_grad_pool(xid), x_grad_mut, o_grad) + } +} + +#[cfg(test)] +mod tests { + use crate::{Autograd, Base, Buffer, HasId, Retriever, CPU}; + + #[test] + fn test_same_types_get_double_return() { + let device = CPU::>::new(); + + // let mut gradients = Gradients::default(); + + let buf = Buffer::::new(&device, 10); + // unsafe { register_buf(&mut gradients.no_grads_pool.borrow_mut().cache, &buf) } + + let out = device.retrieve::(buf.len(), ()); + // unsafe { register_buf(&mut gradients.no_grads_pool.borrow_mut().cache, &out) } + + device + .modules + .tape + .borrow_mut() + .grads + .get_double::>>>>>((buf.id(), out.id())); + } + + #[test] + #[should_panic] + fn test_different_types_get_double_return() { + let device = CPU::>::new(); + + // let mut gradients = Gradients::default(); + + let buf = Buffer::::new(&device, 10); + // unsafe { register_buf(&mut gradients.no_grads_pool.borrow_mut().cache, &buf) } + + let out = device.retrieve::(buf.len(), ()); + // unsafe { register_buf(&mut gradients.no_grads_pool.borrow_mut().cache, &out) } + + device + .modules + .tape + .borrow_mut() + .grads + .get_double::>>>>>((buf.id(), out.id())); + } +} diff --git a/src/modules/autograd/tape.rs b/src/modules/autograd/tape.rs new file mode 100644 index 00000000..bc327fa1 --- /dev/null +++ b/src/modules/autograd/tape.rs @@ -0,0 +1,72 @@ +use core::{fmt::Debug, hash::BuildHasherDefault, panic::Location}; +use std::collections::HashMap; + +use crate::{ + prelude::One, Alloc, Buffer, Device, HasId, HashLocation, LocationHasher, Shape, WriteBuf, +}; + +use super::Gradients; + +// does not require the device param ??? +pub type GradFn = Box; + +/// Stores the grad functions and gradient cache. +#[derive(Default)] +pub struct Tape { + /// Caches gradients for each [`Buffer`]'s id ([`Ident`]). + pub grads: Gradients, + grad_fns: Vec, + grad_fns_loc: HashMap, GradFn, BuildHasherDefault>, + grad_fn_order: Vec>, +} + +impl Debug for Tape { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Tape") + .field("grads", &self.grads) + .field("grad_fns", &self.grad_fns.len()) + .finish() + } +} + +impl Tape { + /// Adds a gradient function to the tape. + #[inline] + #[track_caller] + pub fn add_grad_fn(&mut self, grad_fn: F) { + let hash_location = Location::caller().into(); + + if self.grad_fns_loc.contains_key(&hash_location) { + return; + } + + self.grad_fns_loc.insert(hash_location, Box::new(grad_fn)); + self.grad_fn_order.push(hash_location) + + // self.grad_fns.push(Box::new(grad_fn)) + } + + /// Calls all gradient functions in reverse order. + pub fn backward(&mut self, device: &D) { + for grad_fn_id in self.grad_fn_order.iter().rev() { + let grad_fn = self.grad_fns_loc.get(grad_fn_id).unwrap(); + grad_fn(&mut self.grads); + } + /*for grad_fn in self.grad_fns.drain(..).rev() { + grad_fn(&mut self.grads); + }*/ + } + + /// Backward pass with seeded gradient. + /// The seed of the gradient contains `buf.len()` elements, all of them are set to 1. + pub fn backward_seeded(&mut self, buf: &Buffer) + where + T: Clone + One + 'static, + D: Alloc + WriteBuf + 'static, + { + let out = self.grads.get_mut::(buf.device(), buf.id()); + out.write(&vec![T::one(); out.len()]); + + self.backward(buf.device()) + } +} diff --git a/src/modules/base.rs b/src/modules/base.rs new file mode 100644 index 00000000..0d329a87 --- /dev/null +++ b/src/modules/base.rs @@ -0,0 +1,59 @@ +use core::any::Any; + +use crate::{ + flag::AllocFlag, AddOperation, Alloc, Buffer, Device, Module, OnDropBuffer, OnNewBuffer, + Parents, Retrieve, Setup, Shape, TapeActions, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct Base; + +impl Module for Base { + type Module = Base; + + #[inline] + fn new() -> Self::Module { + Base + } +} + +impl AddOperation for Base { + #[inline] + unsafe fn add_operation( + &self, + out: &mut Buffer, + operation: impl Fn(&mut dyn Any), + ) { + let out: &mut Buffer = unsafe { std::mem::transmute(out) }; + operation(out); + } + + #[inline] + fn add_operation2(&self, mut operation: impl crate::Operation) { + operation.forward() + } +} + +impl Setup for Base {} + +impl OnNewBuffer for Base {} + +impl OnDropBuffer for Base {} + +impl Retrieve for Base { + #[inline] + fn retrieve( + &self, + device: &D, + len: usize, + _parents: impl Parents, + ) -> ::Data + where + S: crate::Shape, + D: Alloc, + { + device.alloc(len, AllocFlag::None) + } +} + +impl TapeActions for Base {} diff --git a/src/modules/cached.rs b/src/modules/cached.rs new file mode 100644 index 00000000..832314e9 --- /dev/null +++ b/src/modules/cached.rs @@ -0,0 +1,254 @@ +use core::{cell::RefCell, marker::PhantomData}; + +use crate::{ + Alloc, Buffer, Cache, Device, Module, OnDropBuffer, OnNewBuffer, Parents, PtrConv, Retrieve, + Setup, Shape, TapeActions, +}; + +// creator struct +#[derive(Debug, PartialEq, Eq, Default)] +pub struct Cached { + pd: PhantomData, +} + +/*impl Retrieve for Cached { + fn retrieve(&self, device: &D, len: usize) -> ::Data + where + D: Alloc + { + todo!() + } +}*/ + +impl, D: Device> Module for Cached { + type Module = CachedModule; + + fn new() -> Self::Module { + CachedModule { + modules: Mods::new(), + cache: RefCell::new(Cache { + nodes: Default::default(), + }), + } + } +} + +impl OnDropBuffer for Cached {} + +pub struct CachedModule { + pub modules: Mods, + cache: RefCell>, +} + +impl, D: Device, NewDev> Setup for CachedModule { + #[inline] + fn setup(device: &mut NewDev) { + Mods::setup(device) + } +} + +impl, SD: Device> OnNewBuffer + for CachedModule +{ + fn on_new_buffer(&self, device: &D, new_buf: &Buffer) { + self.modules.on_new_buffer(device, new_buf) + } +} + +impl OnDropBuffer for CachedModule { + #[inline] + fn on_drop_buffer<'a, T, D: Device, S: Shape>(&self, device: &'a D, buf: &Buffer) { + self.modules.on_drop_buffer(device, buf) + } +} + +// TODO: a more general OnDropBuffer => "Module" +impl, D: Device + PtrConv, SimpleDevice: Device + PtrConv> + Retrieve for CachedModule +{ + #[inline] + fn retrieve( + &self, + device: &D, + len: usize, + _parents: impl Parents, + ) -> D::Data + where + D: Alloc, + { + self.cache.borrow_mut().get(device, len, || ()) + } + + #[inline] + fn on_retrieve_finish(&self, retrieved_buf: &Buffer) + where + T: 'static, + D: Alloc, + { + self.modules.on_retrieve_finish(retrieved_buf) + } +} + +impl TapeActions for CachedModule { + #[inline] + fn tape(&self) -> Option> { + self.modules.tape() + } + + #[inline] + fn tape_mut(&self) -> Option> { + self.modules.tape_mut() + } +} + +#[macro_export] +macro_rules! debug_assert_tracked { + () => { + #[cfg(debug_assertions)] + { + let location = core::panic::Location::caller(); + assert_ne!( + (file!(), line!(), column!()), + (location.file(), location.line(), location.column()), + "Function and operation must be annotated with `#[track_caller]`." + ); + } + }; +} + +/// This macro is nothing but a mechanism to ensure that the specific operation is annotated with `#[track_caller]`. +/// If the operation is not annotated with `#[track_caller]`, then the macro will cause a panic (in debug mode). +/// +/// This macro turns the device, length and optionally type information into the following line of code: +/// ## From: +/// ```ignore +/// retrieve!(device, 10, f32) +/// ``` +/// ## To: +/// ```ignore +/// custos::debug_assert_tracked!(); +/// device.retrieve::(10) +/// ``` +/// +/// If you ensure that the operation is annotated with `#[track_caller]`, then you can just write the following: +/// ```ignore +/// device.retrieve::(10) +/// ``` +/// +/// # Example +/// Operation is not annotated with `#[track_caller]` and therefore will panic: +/// ```should_panic +/// use custos::{retrieve, module_comb::{CPU, Retriever, Buffer, Retrieve, Cached, Base}}; +/// +/// fn add_bufs>>(device: &CPU) -> Buffer, ()> { +/// retrieve!(device, 10, f32) +/// } +/// +/// let device = CPU::>::new(); +/// add_bufs(&device); +/// ``` +/// Operation is annotated with `#[track_caller]`: +/// ``` +/// use custos::{Dim1, retrieve, module_comb::{CPU, Retriever, Buffer, Retrieve, Cached, Base}}; +/// +/// #[track_caller] +/// fn add_bufs>>(device: &CPU) -> Buffer, Dim1<30>> { +/// retrieve!(device, 10, f32, Dim1<30>); // you can also specify the shape +/// retrieve!(device, 10) // or infer the type and shape from the output type +/// } +/// +/// let device = CPU::>::new(); +/// add_bufs(&device); +/// ``` +#[macro_export] +macro_rules! retrieve { + ($device:ident, $len:expr, $parents:expr) => {{ + $crate::debug_assert_tracked!(); + $device.retrieve($len, $parents) + }}; /*($device:ident, $len:expr, $dtype:ty, ) => {{ + $crate::debug_assert_tracked!(); + $device.retrieve::<$dtype, ()>($len) + }}; + ($device:ident, $len:expr, $dtype:ty, $shape:ty) => {{ + $crate::debug_assert_tracked!(); + $device.retrieve::<$dtype, $shape>($len) + }};*/ +} + +#[cfg(test)] +mod tests { + use core::{panic::Location, ptr::addr_of}; + + // crate::modules + use crate::{location, Base, Buffer, Retrieve, Retriever, CPU}; + + use super::Cached; + + // forgot to add track_caller + #[cfg(debug_assertions)] + fn add_bufs>>(device: &CPU) -> Buffer, ()> { + retrieve!(device, 10, ()) + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic] + fn test_forgot_track_caller_runtime_detection() { + let device = CPU::>::new(); + + let _out = add_bufs(&device); + let _out = add_bufs(&device); + } + + #[track_caller] + fn add_bufs_tracked>>( + device: &CPU, + ) -> Buffer, ()> { + retrieve!(device, 10, ()) + } + + #[test] + fn test_added_track_caller() { + let device = CPU::>::new(); + + let _out = add_bufs_tracked(&device); + let _out = add_bufs_tracked(&device); + } + + #[test] + fn test_location_ref_unique() { + let ptr = location(); + let ptr1 = location(); + // bad + assert_ne!(addr_of!(ptr), addr_of!(ptr1)); + } + + #[track_caller] + fn location_tracked() -> &'static Location<'static> { + Location::caller() + } + + #[test] + fn test_location_file_ptr_unique() { + let ptr = location(); + let ptr1 = location(); + // good + assert_eq!(ptr.file().as_ptr(), ptr1.file().as_ptr()); + } + + #[test] + fn test_location_file_tracked_ptr_unique() { + let ptr = location_tracked(); + let ptr1 = location_tracked(); + // good + assert_eq!(ptr.file().as_ptr(), ptr1.file().as_ptr()); + } + + #[test] + fn test_location_with_different_file_location_ptr_unique() { + let ptr = location_tracked(); + let ptr1 = location(); + // good + assert_ne!(ptr.file().as_ptr(), ptr1.file().as_ptr()); + } +} diff --git a/src/modules/graph.rs b/src/modules/graph.rs new file mode 100644 index 00000000..7b59e432 --- /dev/null +++ b/src/modules/graph.rs @@ -0,0 +1,4 @@ +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] +pub struct Graph { + modules: Mods, +} diff --git a/src/modules/lazy.rs b/src/modules/lazy.rs new file mode 100644 index 00000000..f1b159a9 --- /dev/null +++ b/src/modules/lazy.rs @@ -0,0 +1,180 @@ +use core::{any::Any, cell::RefCell, fmt::Debug, hash::BuildHasherDefault}; +use std::collections::HashMap; + +use crate::{ + AddOperation, Alloc, Buffer, Device, HasId, Id, Module, NoHasher, OnDropBuffer, OnNewBuffer, + Operation, Parents, PtrConv, Retrieve, Setup, Shape, TapeActions, UniqueId, +}; + +use super::register_buf; + +#[derive(Default)] +pub struct Lazy { + mods: Mods, + outs: RefCell, BuildHasherDefault>>, + ops: RefCell>>, + out_ids: RefCell>, + ops2: RefCell>>, +} + +impl Debug for Lazy { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Lazy") + .field("mods", &self.mods) + .field("ops_count", &self.ops.borrow().len()) + .finish() + } +} + +pub trait LazySetup { + fn lazy_setup(&mut self) {} +} + +impl, D: LazySetup> Module for Lazy { + type Module = Lazy; + + #[inline] + fn new() -> Self::Module { + Lazy { + mods: Mods::new(), + outs: Default::default(), + ops: Default::default(), + out_ids: Default::default(), + ops2: Default::default(), + } + } +} + +impl AddOperation for Lazy { + #[inline] + unsafe fn add_operation( + &self, + out: &mut Buffer, + operation: impl Fn(&mut dyn Any), + ) { + // operation(out); + self.out_ids.borrow_mut().push(out.id()); + let operation: Box = Box::new(operation); + let operation: Box = + unsafe { std::mem::transmute(operation) }; + self.ops.borrow_mut().push(operation) + } + + #[inline] + fn add_operation2(&self, operation: impl Operation) { + let operation: Box = Box::new(operation); + let operation: Box = unsafe { std::mem::transmute(operation) }; + self.ops2.borrow_mut().push(operation) + } + + #[inline] + fn call_lazily(&self) { + for (op, out_id) in self.ops.borrow().iter().zip(self.out_ids.borrow().iter()) { + let mut outs = self.outs.borrow_mut(); + let out = &mut **outs.get_mut(out_id).unwrap(); + op(out) + } + } +} + +impl> Setup for Lazy { + #[inline] + fn setup(device: &mut D) { + device.lazy_setup(); + Mods::setup(device) + } +} + +impl OnDropBuffer for Lazy { + #[inline] + fn on_drop_buffer<'a, T, D: Device, S: Shape>(&self, device: &'a D, buf: &Buffer) { + super::unregister_buf(&mut self.outs.borrow_mut(), buf.id()); + self.mods.on_drop_buffer(device, buf) + } +} + +impl> + OnNewBuffer for Lazy +{ + #[inline] + fn on_new_buffer(&self, device: &D, new_buf: &Buffer) { + unsafe { super::register_buf(&mut self.outs.borrow_mut(), new_buf) }; + self.mods.on_new_buffer(device, new_buf) + } +} + +impl TapeActions for Lazy { + #[inline] + fn tape(&self) -> Option> { + self.mods.tape() + } + + #[inline] + fn tape_mut(&self) -> Option> { + self.mods.tape_mut() + } +} + +impl, D: PtrConv + 'static> Retrieve for Lazy { + #[inline] + fn retrieve( + &self, + device: &D, + len: usize, + parents: impl Parents, + ) -> ::Data + where + T: 'static, + S: Shape, + D: Alloc, + { + self.mods.retrieve(device, len, parents) + } + + #[inline] + fn on_retrieve_finish(&self, retrieved_buf: &Buffer) + where + T: 'static, + D: Alloc, + { + unsafe { register_buf(&mut self.outs.borrow_mut(), retrieved_buf) }; + + // pass down + self.mods.on_retrieve_finish(retrieved_buf) + } +} + +#[cfg(test)] +mod tests { + use crate::{AddOperation, Alloc, Base, Buffer, Combiner, CPU}; + + use super::Lazy; + + #[test] + fn test_lazy_device_use() { + // let device = CPU::>::new(); + // let data = device.alloc::(10, crate::flag::AllocFlag::None); + } + + #[test] + fn test_lazy_device_use_cuda() { + // let device = CUDA::>::new(); + // let data = device.alloc::(10, crate::flag::AllocFlag::None); + } + + use crate::ApplyFunction; + + #[test] + fn test_lazy_execution() { + let device = CPU::>::new(); + + let buf = Buffer::::new(&device, 10); + let out = device.apply_fn(&buf, |x| x.add(3.)); + + device.call_lazily(); + println!("out: {:?}", &*out); + + drop(out); + drop(buf); + } +} diff --git a/src/modules/mod.rs b/src/modules/mod.rs new file mode 100644 index 00000000..8477fdcf --- /dev/null +++ b/src/modules/mod.rs @@ -0,0 +1,14 @@ +mod autograd; +pub use autograd::*; + +mod base; +pub use base::*; + +mod cached; +pub use cached::*; + +mod graph; +pub use graph::*; + +mod lazy; +pub use lazy::*; diff --git a/src/op_traits.rs b/src/op_traits.rs index 2a2451ad..74abcb6a 100644 --- a/src/op_traits.rs +++ b/src/op_traits.rs @@ -1,6 +1,6 @@ use core::ops::{Bound, Range, RangeBounds}; -use crate::{shape::Shape, Alloc, Buffer, Device}; +use crate::{shape::Shape, Alloc, Buffer, Device, OnDropBuffer, OnNewBuffer}; /// Trait for implementing the clear() operation for the compute devices. pub trait ClearBuf { @@ -40,7 +40,7 @@ pub trait CopySlice: Sized + Device { range: R, ) -> Buffer<'a, T, Self> where - Self: for<'b> Alloc<'b, T>, + Self: Alloc + OnDropBuffer + OnNewBuffer, { let range = bounds_to_range(range, buf.len()); let mut copied = Buffer::new(self, range.end - range.start); diff --git a/src/parents.rs b/src/parents.rs new file mode 100644 index 00000000..741608a7 --- /dev/null +++ b/src/parents.rs @@ -0,0 +1,78 @@ +use crate::{Buffer, Device, HasId, Id, Shape}; + +pub trait Parents { + fn ids(self) -> [Id; N]; +} + +impl Parents<0> for () { + #[inline] + fn ids(self) -> [Id; 0] { + [] + } +} + +impl Parents<1> for Id { + #[inline] + fn ids(self) -> [Id; 1] { + [self] + } +} + +impl Parents<2> for (Id, Id) { + #[inline] + fn ids(self) -> [Id; 2] { + [self.0, self.1] + } +} + +impl Parents<3> for (Id, Id, Id) { + #[inline] + fn ids(self) -> [Id; 3] { + [self.0, self.1, self.2] + } +} + +impl Parents for [Id; N] { + #[inline] + fn ids(self) -> [Id; N] { + self + } +} + +impl Parents<1> for &Buffer<'_, T, D, S> { + #[inline] + fn ids(self) -> [Id; 1] { + [self.id()] + } +} + +impl Parents<2> + for (&Buffer<'_, T, D, S>, &Buffer<'_, T1, D1, S1>) +{ + #[inline] + fn ids(self) -> [Id; 2] { + let (lhs, rhs) = self; + [lhs.id(), rhs.id()] + } +} + +impl Parents<3> + for ( + &Buffer<'_, T, D, S>, + &Buffer<'_, T1, D1, S1>, + &Buffer<'_, T2, D2, S2>, + ) +{ + #[inline] + fn ids(self) -> [Id; 3] { + let (buf, buf1, buf2) = self; + [buf.id(), buf1.id(), buf2.id()] + } +} + +impl Parents for [&Buffer<'_, T, D, S>; N] { + #[inline] + fn ids(self) -> [Id; N] { + self.map(|buf| buf.id()) + } +} diff --git a/src/module_comb/ptr_conv.rs b/src/ptr_conv.rs similarity index 71% rename from src/module_comb/ptr_conv.rs rename to src/ptr_conv.rs index dfeec922..dbf410eb 100644 --- a/src/module_comb/ptr_conv.rs +++ b/src/ptr_conv.rs @@ -1,10 +1,10 @@ use core::mem::{align_of, size_of}; -use crate::{cpu::CPUPtr, flag::AllocFlag, Shape}; +use crate::{cpu::CPUPtr, flag::AllocFlag, Device, OnDropBuffer, Shape}; -use super::{Alloc, CPU}; +use super::CPU; -pub trait PtrConv: Alloc { +pub trait PtrConv: Device { unsafe fn convert( data: &Self::Data, flag: AllocFlag, @@ -12,7 +12,7 @@ pub trait PtrConv: Alloc { } // impl for all devices -impl PtrConv> for CPU { +impl PtrConv> for CPU { #[inline] unsafe fn convert( data: &CPUPtr, diff --git a/src/shape.rs b/src/shape.rs index 109520f4..051c78ae 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -24,7 +24,7 @@ impl Shape for () { pub trait IsShapeIndep: Device {} #[cfg(not(feature = "no-std"))] -impl IsShapeIndep for D {} +impl IsShapeIndep for D {} /// If the [`Shape`] is provides a fixed size, than this trait should be implemented. /// Forgot how this is useful. @@ -90,16 +90,16 @@ impl Shape for Dim3 { pub trait ToDim: crate::Device { /// Converts a pointer to a different [`Shape`]. /// This is only possible for [`Buffer`](crate::Buffer)s that are not allocated on the stack. - fn to_dim(&self, ptr: Self::Ptr) -> Self::Ptr; + fn to_dim(&self, ptr: Self::Data) -> Self::Data; } #[cfg(not(feature = "no-std"))] -impl ToDim for D +impl ToDim for D where - Self::Ptr: crate::PtrType, + Self::Data: crate::PtrType, { #[inline] - fn to_dim(&self, ptr: Self::Ptr) -> D::Ptr { + fn to_dim(&self, ptr: Self::Data) -> D::Data { // resources are now mananged by the destructed raw pointer (prevents double free). let ptr = core::mem::ManuallyDrop::new(ptr); @@ -166,9 +166,9 @@ mod tests { #[cfg(feature = "cpu")] #[test] fn test_transmute_of_stackless_buf() { - use crate::{Buffer, CPU}; + use crate::{Base, Buffer, CPU}; - let device = CPU::new(); + let device = CPU::::new(); let buf = Buffer::>::new(&device, 10); let other_buf = unsafe { diff --git a/src/two_way_ops/mod.rs b/src/two_way_ops/mod.rs index 4b2908e2..9f12db43 100644 --- a/src/two_way_ops/mod.rs +++ b/src/two_way_ops/mod.rs @@ -322,9 +322,9 @@ mod tests { #[cfg(all(feature = "cpu", feature = "macro"))] #[test] fn test_apply_fn_cpu() { - use crate::{ApplyFunction, Buffer, Combiner, CPU}; + use crate::{ApplyFunction, Base, Buffer, Combiner, CPU}; - let device = CPU::new(); + let device = CPU::::new(); let buf = Buffer::from((&device, &[3, 3, 4, 5, 3, 2])); @@ -350,9 +350,9 @@ mod tests { #[cfg(all(feature = "cpu", feature = "macro"))] #[test] fn test_run_apply_fn_cpu_more_complex() { - use crate::{ApplyFunction, Buffer, CPU}; + use crate::{ApplyFunction, Base, Buffer, CPU}; - let device = CPU::new(); + let device = CPU::::new(); let buf = Buffer::from((&device, &[3., 3., 4., 5., 3., 2.])); diff --git a/src/unary.rs b/src/unary.rs index d6b09150..77310566 100644 --- a/src/unary.rs +++ b/src/unary.rs @@ -90,7 +90,7 @@ impl UnaryElementWiseMayGrad for D where T: 'static, D: ApplyFunction + UnaryGrad + MayTapeReturn, - D: for<'b> Alloc<'b, T, S> + 'static, + D: Alloc + 'static, S: Shape, { #[inline(always)]