From ac76de9250ab3e40d5d36818884ee4b6d02d88a2 Mon Sep 17 00:00:00 2001
From: elftausend <76885970+elftausend@users.noreply.github.com>
Date: Mon, 14 Aug 2023 23:01:59 +0200
Subject: [PATCH] Move changes made in module_comb into main structure
---
Cargo.toml | 2 +-
src/buffer.rs | 281 ++++++++---------
src/buffer/impl_from.rs | 35 +--
src/buffer/impl_from_const.rs | 16 +-
src/buffer/num.rs | 69 +++--
src/cache.rs | 134 ++++++++
src/cache/borrow_cache.rs | 190 ++++++++++++
src/cache/location_hasher.rs | 73 +++++
src/cache/nohasher.rs | 31 ++
src/count.rs | 173 -----------
src/device_traits.rs | 88 ++++++
src/devices.rs | 112 +++++++
src/devices/addons.rs | 124 --------
src/devices/borrowing_cache.rs | 123 --------
src/devices/cache.rs | 321 --------------------
src/devices/caller_cache.rs | 166 ----------
src/devices/cpu/cpu_device.rs | 128 ++++----
src/devices/cpu/mod.rs | 5 +-
src/devices/cpu/ops.rs | 33 +-
src/devices/cpu_stack_ops.rs | 11 +-
src/devices/ident.rs | 61 ----
src/devices/mod.rs | 155 ----------
src/exec_on_cpu.rs | 28 +-
src/features.rs | 93 ++++++
src/{module_comb => }/hooks.rs | 4 +-
src/{module_comb => }/id.rs | 0
src/lib.rs | 217 +++----------
src/module_comb/buffer.rs | 33 ++
src/module_comb/devices/cpu.rs | 16 +-
src/module_comb/devices/cuda.rs | 2 +-
src/module_comb/devices/mod.rs | 56 ----
src/module_comb/features.rs | 76 -----
src/module_comb/mod.rs | 61 ----
src/module_comb/modules/autograd.rs | 2 +-
src/module_comb/modules/autograd/tape.rs | 2 +-
src/module_comb/modules/cached.rs | 2 +-
src/modules/autograd.rs | 369 +++++++++++++++++++++++
src/modules/autograd/gradients.rs | 214 +++++++++++++
src/modules/autograd/tape.rs | 72 +++++
src/modules/base.rs | 59 ++++
src/modules/cached.rs | 254 ++++++++++++++++
src/modules/graph.rs | 4 +
src/modules/lazy.rs | 180 +++++++++++
src/modules/mod.rs | 14 +
src/op_traits.rs | 4 +-
src/parents.rs | 78 +++++
src/{module_comb => }/ptr_conv.rs | 8 +-
src/shape.rs | 14 +-
src/two_way_ops/mod.rs | 8 +-
src/unary.rs | 2 +-
50 files changed, 2394 insertions(+), 1809 deletions(-)
create mode 100644 src/cache.rs
create mode 100644 src/cache/borrow_cache.rs
create mode 100644 src/cache/location_hasher.rs
create mode 100644 src/cache/nohasher.rs
delete mode 100644 src/count.rs
create mode 100644 src/device_traits.rs
create mode 100644 src/devices.rs
delete mode 100644 src/devices/addons.rs
delete mode 100644 src/devices/borrowing_cache.rs
delete mode 100644 src/devices/cache.rs
delete mode 100644 src/devices/caller_cache.rs
delete mode 100644 src/devices/ident.rs
delete mode 100644 src/devices/mod.rs
create mode 100644 src/features.rs
rename src/{module_comb => }/hooks.rs (55%)
rename src/{module_comb => }/id.rs (100%)
create mode 100644 src/modules/autograd.rs
create mode 100644 src/modules/autograd/gradients.rs
create mode 100644 src/modules/autograd/tape.rs
create mode 100644 src/modules/base.rs
create mode 100644 src/modules/cached.rs
create mode 100644 src/modules/graph.rs
create mode 100644 src/modules/lazy.rs
create mode 100644 src/modules/mod.rs
create mode 100644 src/parents.rs
rename src/{module_comb => }/ptr_conv.rs (71%)
diff --git a/Cargo.toml b/Cargo.toml
index b643cbb5..aa49df4e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,7 +35,7 @@ min-cl = { version = "0.2.0", optional=true }
[features]
#default = ["no-std"]
-default = ["blas", "cpu", "stack", "static-api", "macro", "autograd", "cuda"]
+default = ["blas", "cpu", "macro"]
#default = ["stack", "macro", "cpu", "blas", "opencl", "static-api", "autograd"]
#default = ["stack", "cpu", "blas", "static-api", "opencl", "macro"]
cpu = []
diff --git a/src/buffer.rs b/src/buffer.rs
index f29cfa2f..11eca8f5 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -7,8 +7,9 @@ use crate::cpu::{CPUPtr, CPU};
use crate::CPU;
use crate::{
- flag::AllocFlag, shape::Shape, Alloc, ClearBuf, CloneBuf, CommonPtrs, Device, DevicelessAble,
- Ident, IsShapeIndep, MainMemory, PtrType, Read, ShallowCopy, WriteBuf,
+ flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, CommonPtrs, Device,
+ DevicelessAble, HasId, IsShapeIndep, MainMemory, OnNewBuffer, PtrType, Read, ShallowCopy,
+ WriteBuf,
};
pub use self::num::Num;
@@ -34,20 +35,13 @@ mod num;
/// buffer_f32_cpu(&buf);
/// buffer_generic(&buf);
/// ```
-pub struct Buffer<'a, T = f32, D: Device = CPU, S: Shape = ()> {
+pub struct Buffer<'a, T = f32, D: Device = CPU, S: Shape = ()> {
/// the type of pointer
- pub ptr: D::Ptr,
+ pub data: D::Data,
/// A reference to the corresponding device. Mainly used for operations without a device parameter.
pub device: Option<&'a D>,
- /// Used as a cache and autograd identifier.
- #[cfg(not(feature = "no-std"))]
- pub ident: Option,
}
-unsafe impl<'a, T, D: Device, S: Shape> Send for Buffer<'a, T, D, S> {}
-
-unsafe impl<'a, T, D: Device, S: Shape> Sync for Buffer<'a, T, D, S> {}
-
impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// Creates a zeroed (or values set to default) `Buffer` with the given length on the specified device.
/// This `Buffer` can't outlive the device specified as a parameter.
@@ -67,25 +61,92 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
///
/// ```
#[inline]
- pub fn new(device: &'a D, len: usize) -> Buffer<'a, T, D, S>
+ pub fn new(device: &'a D, len: usize) -> Self
where
- D: Alloc<'a, T, S>, /*+ GraphReturn*/
+ D: OnNewBuffer + Alloc,
{
- let ptr = device.alloc(len, AllocFlag::None);
-
- #[cfg(not(feature = "no-std"))]
- let ident = device.add_to_cache(&ptr);
+ let data = device.alloc(len, crate::flag::AllocFlag::None);
+ Buffer::from_new_alloc(device, data)
+ }
- Buffer {
- ptr,
+ #[inline]
+ fn from_new_alloc(device: &'a D, data: D::Data) -> Self
+ where
+ D: OnNewBuffer,
+ {
+ let buf = Buffer {
+ data,
device: Some(device),
- // TODO: enable, if leafs get more important
- //node: device.graph().add_leaf(len),
- #[cfg(not(feature = "no-std"))]
- ident,
+ };
+
+ // mind: on_new_buffer must be called for user buffers!
+ device.on_new_buffer(device, &buf);
+ buf
+ }
+}
+
+impl<'a, T, D: Device, S: Shape> HasId for Buffer<'a, T, D, S> {
+ #[inline]
+ fn id(&self) -> super::Id {
+ self.data.id()
+ }
+}
+
+impl<'a, T, D: Device, S: Shape> Drop for Buffer<'a, T, D, S> {
+ #[inline]
+ fn drop(&mut self) {
+ if self.data.flag() != AllocFlag::None {
+ return;
}
+
+ if let Some(device) = self.device {
+ device.on_drop_buffer(device, self)
+ }
+ }
+}
+
+impl<'a, T, D: Device + OnNewBuffer, S: Shape> Buffer<'a, T, D, S> {
+ /// Creates a new `Buffer` from a slice (&[T]).
+ #[inline]
+ pub fn from_slice(device: &'a D, slice: &[T]) -> Self
+ where
+ T: Clone,
+ D: Alloc,
+ {
+ let data = device.alloc_from_slice(slice);
+ Buffer::from_new_alloc(device, data)
+ }
+
+ /// Creates a new `Buffer` from a `Vec`.
+ #[cfg(not(feature = "no-std"))]
+ #[inline]
+ pub fn from_vec(device: &'a D, data: Vec) -> Self
+ where
+ T: Clone,
+ D: Alloc,
+ {
+ let data = device.alloc_from_vec(data);
+ Buffer::from_new_alloc(device, data)
+ }
+
+ /// Creates a new `Buffer` from an nd-array.
+ /// The dimension is defined by the [`Shape`].
+ #[inline]
+ pub fn from_array(device: &'a D, array: S::ARR) -> Buffer
+ where
+ T: Clone,
+ D: Alloc,
+ {
+ let data = device.alloc_from_array(array);
+ Buffer::from_new_alloc(device, data)
}
+}
+
+unsafe impl<'a, T, D: Device, S: Shape> Send for Buffer<'a, T, D, S> {}
+unsafe impl<'a, T, D: Device, S: Shape> Sync for Buffer<'a, T, D, S> {}
+
+impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// Buffers created with this method can outlive the device used to create this `Buffer`.
/// No operations can be performed on this `Buffer` without a device parameter.
/// # Examples
@@ -109,9 +170,7 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
D: DevicelessAble<'b, T, S>,
{
Buffer {
- ptr: device.alloc(len, AllocFlag::None),
- #[cfg(not(feature = "no-std"))]
- ident: None,
+ data: device.alloc(len, AllocFlag::None),
device: None,
}
}
@@ -199,7 +258,7 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// ```
#[inline]
pub fn len(&self) -> usize {
- self.ptr.size()
+ self.data.size()
}
/// Creates a shallow copy of &self.
@@ -211,13 +270,11 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
#[inline]
pub unsafe fn shallow(&self) -> Buffer<'a, T, D, S>
where
- ::Ptr: ShallowCopy,
+ ::Data: ShallowCopy,
{
Buffer {
- ptr: self.ptr.shallow(),
+ data: self.data.shallow(),
device: self.device,
- #[cfg(not(feature = "no-std"))]
- ident: self.ident,
}
}
@@ -230,7 +287,7 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// Furthermore, the resulting `Buffer` can outlive `self`.
pub unsafe fn shallow_or_clone(&self) -> Buffer<'a, T, D, S>
where
- ::Ptr: ShallowCopy,
+ ::Data: ShallowCopy,
T: Clone,
D: CloneBuf<'a, T, S>,
{
@@ -243,20 +300,6 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
self.clone()
}
- /// Returns the [`Ident`] of a `Buffer`.
- /// A `Buffer` receives an id, if it is useable for caching, graph optimization or autograd.
- /// Panics, if `Buffer` hasn't an id.
- #[inline]
- pub fn id(&self) -> Ident {
- #[cfg(feature = "no-std")]
- {
- unimplemented!("This buffer has no trackable id. Who?: e.g. 'Stack' Buffer, Buffers created via Buffer::from_raw_host..(..), `Num` (scalar) Buffer")
- }
-
- #[cfg(not(feature = "no-std"))]
- self.ident.expect("This buffer has no trackable id. Who?: e.g. 'Stack' Buffer, Buffers created via Buffer::from_raw_host..(..), `Num` (scalar) Buffer")
- }
-
/// Sets all elements in `Buffer` to the default value.
pub fn clear(&mut self)
where
@@ -266,22 +309,6 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
}
}
-impl<'a, T, D: Device, S: Shape> Drop for Buffer<'a, T, D, S> {
- #[inline]
- fn drop(&mut self) {
- if self.ptr.flag() != AllocFlag::None {
- return;
- }
-
- #[cfg(not(feature = "no-std"))]
- if let Some(device) = self.device {
- if let Some(ident) = self.ident {
- device.remove(ident)
- }
- }
- }
-}
-
// TODO better solution for the to_dims stack problem?
impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// Converts a non stack allocated `Buffer` with shape `S` to a `Buffer` with shape `O`.
@@ -299,17 +326,15 @@ impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
pub fn to_dims(self) -> Buffer<'a, T, D, O>
where
D: crate::ToDim,
- D::Ptr: ShallowCopy,
+ D::Data: ShallowCopy,
{
let buf = ManuallyDrop::new(self);
- let ptr = buf.device().to_dim(unsafe { buf.ptr.shallow() });
+ let data = buf.device().to_dim(unsafe { buf.data.shallow() });
Buffer {
- ptr,
+ data,
device: buf.device,
- #[cfg(not(feature = "no-std"))]
- ident: buf.ident,
}
}
}
@@ -335,18 +360,18 @@ impl<'a, T, D: IsShapeIndep, S: Shape> Buffer<'a, T, D, S> {
impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S>
where
- D::Ptr: CommonPtrs,
+ D::Data: CommonPtrs,
{
#[inline]
/// Returns all types of pointers. (host, OpenCL, CUDA)
pub fn ptrs(&self) -> (*const T, *mut c_void, u64) {
- self.ptr.ptrs()
+ self.data.ptrs()
}
#[inline]
/// Returns all types of pointers. (host, OpenCL, CUDA)
pub fn ptrs_mut(&mut self) -> (*mut T, *mut c_void, u64) {
- self.ptr.ptrs_mut()
+ self.data.ptrs_mut()
}
}
@@ -365,73 +390,6 @@ impl<'a, T, D: Device> Buffer<'a, T, D> {
}
}
-impl<'a, T, D: Device, S: Shape> Buffer<'a, T, D, S> {
- /// Creates a new `Buffer` from a slice (&[T]).
- /// The pointer of the allocation may be added to the cache of the device.
- /// Usually, this pointer / `Buffer` is then returned by a `device.get_existing_buf(..)` (accesses the cache) call.
- #[inline]
- pub fn from_slice(device: &'a D, slice: &[T]) -> Self
- where
- T: Clone,
- D: Alloc<'a, T, S>,
- {
- let ptr = device.with_slice(slice);
-
- #[cfg(not(feature = "no-std"))]
- let ident = device.add_to_cache(&ptr);
-
- Buffer {
- ptr,
- #[cfg(not(feature = "no-std"))]
- ident,
- device: Some(device),
- }
- }
-
- /// Creates a new `Buffer` from a `Vec`.
- /// The pointer of the allocation may be added to the cache of the device.
- /// Usually, this pointer / `Buffer` is then returned by a `device.get_existing_buf(..)` call.
- #[cfg(not(feature = "no-std"))]
- #[inline]
- pub fn from_vec(device: &'a D, data: Vec) -> Self
- where
- T: Clone,
- D: Alloc<'a, T, S>,
- {
- let ptr = device.alloc_with_vec(data);
- let ident = device.add_to_cache(&ptr);
-
- Buffer {
- ptr,
- ident,
- device: Some(device),
- }
- }
-
- /// Creates a new `Buffer` from an nd-array.
- /// The dimension is defined by the [`Shape`].
- /// The pointer of the allocation may be added to the cache of the device.
- /// Usually, this pointer / `Buffer` is then returned by a `device.get_existing_buf(..)` call.
- #[inline]
- pub fn from_array(device: &'a D, array: S::ARR) -> Buffer
- where
- T: Clone,
- D: Alloc<'a, T, S>,
- {
- let ptr = device.with_array(array);
-
- #[cfg(not(feature = "no-std"))]
- let ident = device.add_to_cache(&ptr);
-
- Buffer {
- ptr,
- #[cfg(not(feature = "no-std"))]
- ident,
- device: Some(device),
- }
- }
-}
-
#[cfg(feature = "cpu")]
impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> {
/// Constructs a deviceless `Buffer` out of a host pointer and a length.
@@ -458,9 +416,8 @@ impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> {
#[inline]
pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU, S> {
Buffer {
- ptr: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
+ data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
device: None,
- ident: None,
}
}
@@ -477,9 +434,8 @@ impl<'a, T, S: Shape> Buffer<'a, T, CPU, S> {
len: usize,
) -> Buffer<'a, T, CPU, S> {
Buffer {
- ptr: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
+ data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
device: Some(device),
- ident: None,
}
}
}
@@ -490,7 +446,7 @@ impl<'a, T, S: Shape> Buffer<'a, T, crate::OpenCL, S> {
#[inline]
pub fn cl_ptr(&self) -> *mut c_void {
assert!(
- !self.ptr.ptr.is_null(),
+ !self.data.ptr.is_null(),
"called cl_ptr() on an invalid OpenCL buffer"
);
self.ptrs().1
@@ -499,7 +455,7 @@ impl<'a, T, S: Shape> Buffer<'a, T, crate::OpenCL, S> {
#[cfg(feature = "cuda")]
impl<'a, T> Buffer<'a, T, crate::CUDA> {
- // TODO: replace buf.ptr.2 with this fn, do the same with cl, cpu
+ // TODO: replace buf.data.2 with this fn, do the same with cl, cpu
/// Returns a non null CUDA pointer
#[inline]
pub fn cu_ptr(&self) -> u64 {
@@ -507,7 +463,7 @@ impl<'a, T> Buffer<'a, T, crate::CUDA> {
self.ptrs().2 != 0,
"called cu_ptr() on an invalid CUDA buffer"
);
- self.ptr.ptr
+ self.data.ptr
}
}
@@ -525,9 +481,10 @@ impl<'a, T, D: MainMemory, S: Shape> Buffer<'a, T, D, S> {
}
}
+// custos v0.5 compatibility
impl<'a, T, D: MainMemory, S: Shape> Buffer<'a, T, D, S>
where
- D::Ptr: CommonPtrs,
+ D::Data: CommonPtrs,
{
/// Returns a non null host pointer
#[inline]
@@ -568,14 +525,12 @@ unsafe impl Sync for Buffer<'a, T> {}*/
impl<'a, T, D: Device, S: Shape> Default for Buffer<'a, T, D, S>
where
- D::Ptr: Default,
+ D::Data: Default,
{
fn default() -> Self {
Self {
- ptr: D::Ptr::::default(),
+ data: D::Data::::default(),
device: None,
- #[cfg(not(feature = "no-std"))]
- ident: None,
}
}
}
@@ -594,7 +549,7 @@ impl AsMut<[T]> for Buffer<'_, T, D> {
}
}
-/// A `Buffer` dereferences into a slice.
+/// A main memory `Buffer` dereferences into a slice.
///
/// # Examples
///
@@ -623,11 +578,11 @@ impl core::ops::Deref for Buffer<'_, T, D, S> {
#[inline]
fn deref(&self) -> &Self::Target {
- unsafe { core::slice::from_raw_parts(D::as_ptr(&self.ptr), self.len()) }
+ unsafe { core::slice::from_raw_parts(D::as_ptr(&self.data), self.len()) }
}
}
-/// A `Buffer` dereferences into a mutable slice.
+/// A main memory `Buffer` dereferences into a mutable slice.
///
/// # Examples
///
@@ -652,7 +607,7 @@ impl core::ops::Deref for Buffer<'_, T, D, S> {
impl core::ops::DerefMut for Buffer<'_, T, D, S> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
- unsafe { core::slice::from_raw_parts_mut(D::as_ptr_mut(&mut self.ptr), self.len()) }
+ unsafe { core::slice::from_raw_parts_mut(D::as_ptr_mut(&mut self.data), self.len()) }
}
}
@@ -665,7 +620,7 @@ where
T: Debug + Default + Clone + 'a,
D: Read + Device + 'a,
for<'b> >::Read<'b>: Debug,
- D::Ptr: CommonPtrs,
+ D::Data: CommonPtrs,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Buffer")
@@ -726,7 +681,9 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_deref() {
- let device = crate::CPU::new();
+ use crate::Base;
+
+ let device = crate::CPU::::new();
let buf: Buffer = Buffer::from((&device, [1, 2, 3, 4]));
let slice = &*buf;
assert_eq!(slice, &[1, 2, 3, 4]);
@@ -762,7 +719,9 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_debug_print() {
- let device = crate::CPU::new();
+ use crate::Base;
+
+ let device = crate::CPU::::new();
let buf = Buffer::from((&device, [1, 2, 3, 4, 5, 6]));
println!("{buf:?}",);
@@ -771,9 +730,9 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_to_dims() {
- use crate::Dim2;
+ use crate::{Base, Dim2};
- let device = crate::CPU::new();
+ let device = crate::CPU::::new();
let buf = Buffer::from((&device, [1, 2, 3, 4, 5, 6]));
let buf_dim2 = buf.to_dims::>();
@@ -783,12 +742,12 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_id_cpu() {
- use crate::{Ident, CPU};
+ use crate::{Base, HasId, CPU};
- let device = CPU::new();
+ let device = CPU::::new();
let buf = Buffer::from((&device, [1, 2, 3, 4]));
- assert_eq!(buf.id(), Ident { idx: 0, len: 4 })
+ assert_eq!(buf.id(), buf.data.id())
}
#[cfg(feature = "stack")]
diff --git a/src/buffer/impl_from.rs b/src/buffer/impl_from.rs
index 0005ac48..80e319d6 100644
--- a/src/buffer/impl_from.rs
+++ b/src/buffer/impl_from.rs
@@ -1,6 +1,6 @@
use core::ops::Range;
-use crate::{number::Number, shape::Shape, Alloc, Buffer};
+use crate::{number::Number, shape::Shape, Alloc, Buffer, Device, OnNewBuffer, Retriever};
#[cfg(feature = "cpu")]
use crate::{WriteBuf, CPU};
@@ -9,7 +9,7 @@ impl<'a, T, D, const N: usize> From<(&'a D, [T; N])> for Buffer<'a, T, D>
where
T: Clone,
// TODO: IsShapeIndep ... find way to include Stack
- D: Alloc<'a, T>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, array): (&'a D, [T; N])) -> Self {
@@ -19,7 +19,7 @@ where
impl<'a, T, D> From<(&'a D, usize)> for Buffer<'a, T, D>
where
- D: Alloc<'a, T>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, len): (&'a D, usize)) -> Self {
@@ -30,7 +30,7 @@ where
/*impl<'a, T, D> Buffer<'a, T, D>
where
T: Clone,
- D: Alloc<'a, T>
+ D: Alloc+ OnNewBuffer
{
#[inline]
pub fn from_iter>(device: &'a D, iter: I) -> Self {
@@ -42,7 +42,7 @@ where
impl<'a, T, D> From<(&'a D, Range)> for Buffer<'a, T, D>
where
T: Number,
- D: Alloc<'a, T>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, range): (&'a D, Range)) -> Self {
@@ -55,7 +55,7 @@ where
impl<'a, T, D, I> From<(&'a D, I)> for Buffer<'a, T, D>
where
T: Number,
- D: Alloc<'a, T>,
+ D: Alloc+ OnNewBuffer,
I: IntoIterator- ,
{
#[inline]
@@ -67,7 +67,7 @@ where
/*impl<'a, T, D, const N: usize> From<(&'a D, [T; N])> for Buffer<'a, T, D>
where
T: Clone,
- D: Alloc<'a, T> + IsShapeIndep,
+ D: Alloc+ OnNewBuffer + IsShapeIndep,
{
fn from((device, array): (&'a D, [T; N])) -> Self {
Buffer {
@@ -84,7 +84,7 @@ impl<'a, T, D, const N: usize> From<(&'a D, &[T; N])> for Buffer<'a, T, D>
where
T: Clone,
// TODO: IsShapeIndep ... find way to include Stack
- D: Alloc<'a, T>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, array): (&'a D, &[T; N])) -> Self {
@@ -95,7 +95,7 @@ where
/*impl<'a, T, D, const N: usize> From<(&'a D, &[T; N])> for Buffer<'a, T, D>
where
T: Clone,
- D: Alloc<'a, T> + IsShapeIndep,
+ D: Alloc+ OnNewBuffer + IsShapeIndep,
{
fn from((device, array): (&'a D, &[T; N])) -> Self {
Buffer {
@@ -112,7 +112,7 @@ impl<'a, T, D, S: Shape> From<(&'a D, &[T])> for Buffer<'a, T, D, S>
where
T: Clone,
// TODO: IsShapeIndep ... find way to include Stack
- D: Alloc<'a, T, S>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, slice): (&'a D, &[T])) -> Self {
@@ -123,7 +123,7 @@ where
/*impl<'a, T, D, S: Shape> From<(&'a D, &[T])> for Buffer<'a, T, D, S>
where
T: Clone,
- D: Alloc<'a, T, S> + IsShapeIndep,
+ D: Alloc+ OnNewBuffer + IsShapeIndep,
{
fn from((device, slice): (&'a D, &[T])) -> Self {
Buffer {
@@ -140,7 +140,7 @@ impl<'a, T, D, S: Shape> From<(&'a D, Vec)> for Buffer<'a, T, D, S>
where
T: Clone,
// TODO: IsShapeIndep ... find way to include Stack
- D: Alloc<'a, T, S>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, vec): (&'a D, Vec)) -> Self {
@@ -153,7 +153,7 @@ impl<'a, T, D, S: Shape> From<(&'a D, &Vec)> for Buffer<'a, T, D, S>
where
T: Clone,
// TODO: IsShapeIndep ... find way to include Stack
- D: Alloc<'a, T, S>,
+ D: Alloc + OnNewBuffer,
{
#[inline]
fn from((device, vec): (&'a D, &Vec)) -> Self {
@@ -164,8 +164,9 @@ where
#[cfg(feature = "cpu")]
impl<'a, 'b, T, S, D> From<(&'a D, Buffer<'b, T, CPU, S>)> for Buffer<'a, T, D, S>
where
+ T: 'static,
S: Shape,
- D: WriteBuf + for<'c> Alloc<'c, T, S>,
+ D: WriteBuf + Device + Retriever,
{
fn from((device, buf): (&'a D, Buffer<'b, T, CPU, S>)) -> Self {
let mut out = device.retrieve(buf.len(), &buf);
@@ -180,11 +181,11 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_buf_device_conversion_cpu() {
- use crate::{Buffer, Read, CPU};
+ use crate::{Base, Buffer, Read, CPU};
- let device = CPU::new();
+ let device = CPU::::new();
- let cpu = CPU::new();
+ let cpu = CPU::::new();
let cpu_buf = Buffer::from((&cpu, [1, 2, 4, 5]));
let out = Buffer::from((&device, cpu_buf));
diff --git a/src/buffer/impl_from_const.rs b/src/buffer/impl_from_const.rs
index 4ba6724d..4fd0bca2 100644
--- a/src/buffer/impl_from_const.rs
+++ b/src/buffer/impl_from_const.rs
@@ -1,4 +1,4 @@
-use crate::{prelude::Number, shape::Shape, Alloc, Buffer, Dim1, Dim2};
+use crate::{prelude::Number, shape::Shape, Alloc, Buffer, Device, Dim1, Dim2, OnNewBuffer};
/// Trait for creating [`Buffer`]s with a [`Shape`]. The [`Shape`] is inferred from the array.
pub trait WithShape {
@@ -20,7 +20,7 @@ pub trait WithShape {
impl<'a, T, D, const N: usize> WithShape<&'a D, [T; N]> for Buffer<'a, T, D, Dim1>
where
T: Number, // using Number here, because T could be an array type
- D: Alloc<'a, T, Dim1>,
+ D: Alloc + OnNewBuffer>,
{
#[inline]
fn with(device: &'a D, array: [T; N]) -> Self {
@@ -31,7 +31,7 @@ where
impl<'a, T, D, const N: usize> WithShape<&'a D, &[T; N]> for Buffer<'a, T, D, Dim1>
where
T: Number,
- D: Alloc<'a, T, Dim1>,
+ D: Alloc + OnNewBuffer>,
{
#[inline]
fn with(device: &'a D, array: &[T; N]) -> Self {
@@ -43,7 +43,7 @@ impl<'a, T, D, const B: usize, const A: usize> WithShape<&'a D, [[T; A]; B]>
for Buffer<'a, T, D, Dim2>
where
T: Number,
- D: Alloc<'a, T, Dim2>,
+ D: Alloc + OnNewBuffer>,
{
#[inline]
fn with(device: &'a D, array: [[T; A]; B]) -> Self {
@@ -55,7 +55,7 @@ impl<'a, T, D, const B: usize, const A: usize> WithShape<&'a D, &[[T; A]; B]>
for Buffer<'a, T, D, Dim2>
where
T: Number,
- D: Alloc<'a, T, Dim2>,
+ D: Alloc + OnNewBuffer>,
{
#[inline]
fn with(device: &'a D, array: &[[T; A]; B]) -> Self {
@@ -65,7 +65,7 @@ where
impl<'a, T, D, S: Shape> WithShape<&'a D, ()> for Buffer<'a, T, D, S>
where
- D: Alloc<'a, T, S>,
+ D: Alloc + OnNewBuffer,
{
fn with(device: &'a D, _: ()) -> Self {
Buffer::new(device, S::LEN)
@@ -77,9 +77,9 @@ mod tests {
#[cfg(feature = "cpu")]
#[test]
fn test_with_const_dim2_cpu() {
- use crate::{Buffer, WithShape, CPU};
+ use crate::{Base, Buffer, WithShape, CPU};
- let device = CPU::new();
+ let device = CPU::::new();
let buf = Buffer::with(&device, [[1.0, 2.0], [3.0, 4.0]]);
diff --git a/src/buffer/num.rs b/src/buffer/num.rs
index f64f9e60..3b9afda3 100644
--- a/src/buffer/num.rs
+++ b/src/buffer/num.rs
@@ -1,11 +1,15 @@
use core::{
+ convert::Infallible,
ffi::c_void,
ops::{Deref, DerefMut},
ptr::null_mut,
};
-use crate::{shape::Shape, Buffer, CloneBuf, CommonPtrs, Device, PtrType};
+use crate::{
+ flag::AllocFlag, Alloc, Buffer, CloneBuf, CommonPtrs, Device, HasId, OnDropBuffer, PtrType,
+};
+#[derive(Debug, Default)]
/// Makes it possible to use a single number in a [`Buffer`].
pub struct Num {
/// The stored number.
@@ -36,25 +40,53 @@ impl CommonPtrs for Num {
}
}
+impl HasId for Num {
+ fn id(&self) -> crate::Id {
+ todo!()
+ }
+}
+
+impl From for Num {
+ #[inline]
+ fn from(num: T) -> Self {
+ Num { num }
+ }
+}
+
impl Device for () {
- type Ptr = Num;
- type Cache = ();
+ type Data = Num;
+ type Error = Infallible;
- fn new() -> crate::Result {
+ fn new() -> Result {
Ok(())
}
}
+impl Alloc for () {
+ #[inline]
+ fn alloc(&self, _len: usize, _flag: AllocFlag) -> Self::Data {
+ Num::default()
+ }
+
+ #[inline]
+ fn alloc_from_slice(&self, data: &[T]) -> Self::Data
+ where
+ T: Clone,
+ {
+ data[0].clone().into()
+ }
+}
+
+impl OnDropBuffer for () {}
+
impl<'a, T: Clone> CloneBuf<'a, T> for () {
#[inline]
fn clone_buf(&self, buf: &Buffer<'a, T, Self>) -> Buffer<'a, T, Self> {
Buffer {
- ptr: Num {
- num: buf.ptr.num.clone(),
+ data: Num {
+ num: buf.data.num.clone(),
},
device: buf.device,
- #[cfg(not(feature = "no-std"))]
- ident: buf.ident,
}
}
}
@@ -63,10 +95,8 @@ impl From for Buffer<'_, T, ()> {
#[inline]
fn from(ptr: T) -> Self {
Buffer {
- ptr: Num { num: ptr },
+ data: Num { num: ptr },
device: None,
- #[cfg(not(feature = "no-std"))]
- ident: None,
}
}
}
@@ -80,10 +110,8 @@ impl<'a, T> Buffer<'a, T, ()> {
T: Copy,
{
Buffer {
- ptr: Num { num: self.ptr.num },
+ data: Num { num: self.data.num },
device: self.device,
- #[cfg(not(feature = "no-std"))]
- ident: self.ident,
}
}
@@ -105,7 +133,7 @@ impl<'a, T> Buffer<'a, T, ()> {
where
T: Copy,
{
- self.ptr.num
+ self.data.num
}
}
@@ -114,14 +142,14 @@ impl<'a, T> Deref for Buffer<'a, T, ()> {
#[inline]
fn deref(&self) -> &Self::Target {
- &self.ptr.num
+ &self.data.num
}
}
impl<'a, T> DerefMut for Buffer<'a, T, ()> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.ptr.num
+ &mut self.data.num
}
}
@@ -144,4 +172,11 @@ mod tests {
*a += 10;
assert_eq!(*a, 15);
}
+
+ #[test]
+ fn test_num_device() {
+ use crate::Device;
+
+ let _device = <()>::new().unwrap();
+ }
}
diff --git a/src/cache.rs b/src/cache.rs
new file mode 100644
index 00000000..16b0b37c
--- /dev/null
+++ b/src/cache.rs
@@ -0,0 +1,134 @@
+use core::{hash::BuildHasherDefault, panic::Location};
+use std::collections::HashMap;
+use std::rc::Rc;
+
+use crate::{flag::AllocFlag, Device, Shape};
+
+use super::{Alloc, PtrConv};
+
+mod location_hasher;
+pub use location_hasher::*;
+
+mod nohasher;
+pub use nohasher::*;
+
+mod borrow_cache;
+pub use borrow_cache::*;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Cache {
+ pub nodes:
+ HashMap, Rc>, BuildHasherDefault>,
+}
+
+impl Default for Cache {
+ #[inline]
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl Cache {
+ #[inline]
+ pub fn new() -> Self {
+ Self {
+ nodes: Default::default(),
+ }
+ }
+
+ #[track_caller]
+ #[inline]
+ pub fn get>(
+ &mut self,
+ device: &D,
+ len: usize,
+ callback: fn(),
+ ) -> D::Data
+ where
+ SD: PtrConv,
+ D: PtrConv,
+ {
+ let maybe_allocated = self.nodes.get(&Location::caller().into());
+ match maybe_allocated {
+ Some(data) => unsafe { SD::convert(&data, AllocFlag::Wrapper) },
+ None => self.add_node(device, len, callback),
+ }
+ }
+
+ #[track_caller]
+ pub fn add_node>(
+ &mut self,
+ device: &D,
+ len: usize,
+ callback: fn(),
+ ) -> D::Data
+ where
+ D: PtrConv,
+ {
+ let data = device.alloc::
(len, AllocFlag::Wrapper);
+
+ let untyped_ptr = unsafe { D::convert(&data, AllocFlag::None) };
+ self.nodes
+ .insert(Location::caller().into(), Rc::new(untyped_ptr));
+
+ callback();
+
+ data
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::Cache;
+ use crate::{Base, CPU};
+
+ #[test]
+ fn test_cache_add_node() {
+ let mut cache = Cache::>::default();
+ let device = CPU::::new();
+
+ assert_eq!(cache.nodes.len(), 0);
+
+ let out = cache.add_node::(&device, 10, || ());
+
+ assert_eq!(cache.nodes.len(), 1);
+ assert_eq!(out.len, 10);
+
+ let out1 = cache.get::(&device, 10, || ());
+ assert_ne!(out.ptr, out1.ptr);
+ }
+
+ #[test]
+ fn test_cache_get_at_different_locations() {
+ let mut cache = Cache::>::default();
+ let device = CPU::::new();
+
+ assert_eq!(cache.nodes.len(), 0);
+
+ let out1 = cache.get::(&device, 10, || ());
+ assert_eq!(cache.nodes.len(), 1);
+
+ let out2 = cache.get::(&device, 10, || ());
+
+ assert_ne!(out1.ptr, out2.ptr);
+ assert_eq!(cache.nodes.len(), 2);
+ }
+
+ #[test]
+ fn test_cache_get_reuse_based_on_location() {
+ let mut cache = Cache::>::default();
+ let device = CPU::::new();
+
+ let mut prev = None;
+ for _ in 0..1000 {
+ let out3 = cache.get::(&device, 10, || ());
+ if prev.is_none() {
+ prev = Some(out3.ptr);
+ }
+ assert_eq!(prev.unwrap(), out3.ptr);
+ assert_eq!(cache.nodes.len(), 1);
+ prev = Some(out3.ptr);
+ }
+ assert_eq!(cache.nodes.len(), 1);
+ }
+}
diff --git a/src/cache/borrow_cache.rs b/src/cache/borrow_cache.rs
new file mode 100644
index 00000000..e19a8659
--- /dev/null
+++ b/src/cache/borrow_cache.rs
@@ -0,0 +1,190 @@
+use core::{
+ any::Any,
+ fmt::{Debug, Display},
+ hash::BuildHasherDefault,
+ mem::transmute,
+};
+use std::collections::HashMap;
+
+use crate::{flag::AllocFlag, Alloc, Buffer, Device, Id, Shape};
+
+use super::NoHasher;
+
+pub type UniqueId = u64;
+
+#[derive(Clone, Copy)]
+pub enum CachingError {
+ InvalidId,
+ InvalidTypeInfo,
+}
+
+impl CachingError {
+ pub fn as_str(&self) -> &'static str {
+ match self {
+ CachingError::InvalidId => "InvalidId: Invalid Buffer identifier.",
+ CachingError::InvalidTypeInfo => "InvalidTypeInfo: Invalid type information provided for allocated Buffer. Does your specific operation use mixed types?",
+ }
+ }
+}
+
+impl Debug for CachingError {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ Display::fmt(&self, f)
+ }
+}
+
+impl Display for CachingError {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ write!(f, "{}", self.as_str())
+ }
+}
+
+impl std::error::Error for CachingError {}
+
+#[derive(Debug, Default)]
+pub struct BorrowCache {
+ pub cache: HashMap, BuildHasherDefault>,
+}
+
+// TODO: make BorrowCache unusable without a device (=> Static get methods with D: CacheReturn)
+impl BorrowCache {
+ pub fn add_or_get<'a, T, D, S>(&mut self, device: &'a D, id: Id) -> &Buffer<'a, T, D, S>
+ where
+ T: 'static,
+ D: Alloc + 'static,
+ S: Shape,
+ {
+ self.add_buf_once::(device, id);
+
+ let buf_any = self.cache.get(&id).unwrap();
+ buf_any.downcast_ref().unwrap()
+ }
+
+ pub fn add_or_get_mut<'a, T, D, S>(&mut self, device: &D, id: Id) -> &mut Buffer<'a, T, D, S>
+ where
+ T: 'static,
+ D: Alloc + 'static,
+ S: Shape,
+ {
+ self.add_buf_once::(device, id);
+ self.get_buf_mut(id).unwrap()
+ }
+
+ pub fn add_buf_once<'a, T, D, S>(&mut self, device: &'a D, id: Id)
+ where
+ T: 'static,
+ D: Alloc + 'static,
+ S: Shape,
+ {
+ if self.cache.get(&id).is_some() {
+ return;
+ }
+
+ self.add_buf::(device, id)
+ }
+
+ pub fn add_buf<'a, T, D, S>(&mut self, device: &'a D, id: Id)
+ where
+ T: 'static,
+ D: Alloc + 'static,
+ S: Shape,
+ {
+ // not using ::new, because this buf would get added to the cache of the device.
+ // not anymore ?
+ let buf = Buffer {
+ data: device.alloc::(id.len, AllocFlag::BorrowedCache),
+ device: Some(device),
+ };
+
+ let buf = unsafe { transmute::<_, Buffer<'static, T, D, S>>(buf) };
+ self.cache.insert(*id, Box::new(buf));
+ }
+
+ #[inline]
+ pub fn get_buf_with_dev<'a, 'b, T, D, S>(
+ &'b self,
+ id: Id,
+ _dev: &'a D,
+ ) -> Option<&'b Buffer<'a, T, D, S>>
+ where
+ T: 'static,
+ D: Alloc + 'static,
+ S: Shape,
+ {
+ self.cache.get(&id)?.downcast_ref()
+ }
+
+ #[inline]
+ pub fn get_buf<'a, T, D, S>(&self, id: Id) -> Result<&Buffer<'a, T, D, S>, CachingError>
+ where
+ T: 'static,
+ D: Device + 'static,
+ S: Shape,
+ {
+ self.cache
+ .get(&id)
+ .ok_or(CachingError::InvalidId)?
+ .downcast_ref()
+ .ok_or(CachingError::InvalidTypeInfo)
+ }
+
+ #[inline]
+ pub fn get_buf_mut<'a, T, D, S>(
+ &mut self,
+ id: Id,
+ ) -> Result<&mut Buffer<'a, T, D, S>, CachingError>
+ where
+ T: 'static,
+ D: Device + 'static,
+ S: Shape,
+ {
+ unsafe {
+ transmute(
+ self.cache
+ .get_mut(&id)
+ .ok_or(CachingError::InvalidId)?
+ .downcast_mut::>()
+ .ok_or(CachingError::InvalidTypeInfo),
+ )
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{Base, CPU};
+
+ use super::BorrowCache;
+
+ /*#[test]
+ fn test_comp_error() {
+ let device = CPU::new();
+
+
+ let a = {
+ let mut cache = BorrowingCache::default();
+ cache.add_or_get::(&device, Id::new(10))
+ };
+ }*/
+
+ /*#[test]
+ fn test_get_borrowed() {
+ let device = CPU::::default();
+ let mut cache = BorrowCache::default();
+
+ let (fid, sid, tid) = (
+ Id::new_bumped(10),
+ Id::new_bumped(10),
+ Id::new_bumped(10),
+ );
+
+ cache.add_buf_once::(&device, fid);
+ cache.add_buf_once::(&device, sid);
+ cache.add_buf_once::(&device, tid);
+
+ let a = cache.get_buf::(fid).unwrap();
+ let b = cache.get_buf::(fid).unwrap();
+
+ assert_eq!(a.ptr, b.ptr);
+ }*/
+}
diff --git a/src/cache/location_hasher.rs b/src/cache/location_hasher.rs
new file mode 100644
index 00000000..5ea29811
--- /dev/null
+++ b/src/cache/location_hasher.rs
@@ -0,0 +1,73 @@
+use core::{ops::BitXor, panic::Location};
+
+#[derive(Default)]
+pub struct LocationHasher {
+ hash: u64,
+}
+
+const K: u64 = 0x517cc1b727220a95;
+
+impl std::hash::Hasher for LocationHasher {
+ #[inline]
+ fn finish(&self) -> u64 {
+ self.hash
+ }
+
+ #[inline]
+ fn write(&mut self, _bytes: &[u8]) {
+ unimplemented!("LocationHasher only hashes u64, (u32 and usize as u64 cast).")
+ }
+
+ #[inline]
+ fn write_u64(&mut self, i: u64) {
+ self.hash = self.hash.rotate_left(5).bitxor(i).wrapping_mul(K);
+ }
+
+ #[inline]
+ fn write_u32(&mut self, i: u32) {
+ self.write_u64(i as u64);
+ }
+
+ #[inline]
+ fn write_usize(&mut self, i: usize) {
+ self.write_u64(i as u64);
+ }
+}
+
+#[derive(Debug, Clone, Copy, Eq)]
+pub struct HashLocation<'a> {
+ pub file: &'a str,
+ pub line: u32,
+ pub col: u32,
+}
+
+impl PartialEq for HashLocation<'_> {
+    #[inline]
+    fn eq(&self, other: &Self) -> bool {
+        // Files are compared by address, not content; NOTE(review): assumes each file name is stored once - confirm via the unit tests.
+        if self.file.as_ptr() != other.file.as_ptr() {
+            return false;
+        }
+        self.line == other.line && self.col == other.col
+    }
+}
+
+impl<'a> std::hash::Hash for HashLocation<'a> {
+    #[inline]
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.file.as_ptr().hash(state); // address only, mirroring the pointer comparison in `eq`
+        let line_col = (self.line as u64) << 32 | self.col as u64; // both are u32: packed without overlap
+        line_col.hash(state);
+    }
+}
+
+impl<'a> From<&'a Location<'a>> for HashLocation<'a> {
+ #[inline]
+ fn from(loc: &'a Location<'a>) -> Self {
+ Self {
+ file: loc.file(),
+ line: loc.line(),
+ col: loc.column(),
+ }
+ }
+}
diff --git a/src/cache/nohasher.rs b/src/cache/nohasher.rs
new file mode 100644
index 00000000..f6facf62
--- /dev/null
+++ b/src/cache/nohasher.rs
@@ -0,0 +1,31 @@
+#[derive(Default)]
+pub struct NoHasher {
+ hash: u64,
+}
+
+impl std::hash::Hasher for NoHasher {
+ #[inline]
+ fn finish(&self) -> u64 {
+ self.hash
+ }
+
+ #[inline]
+ fn write(&mut self, _bytes: &[u8]) {
+ unimplemented!("NoHasher only hashes u64, (u32 and usize as u64 cast).")
+ }
+
+ #[inline]
+ fn write_u64(&mut self, i: u64) {
+ self.hash = i;
+ }
+
+ #[inline]
+ fn write_u32(&mut self, i: u32) {
+ self.write_u64(i as u64);
+ }
+
+ #[inline]
+ fn write_usize(&mut self, i: usize) {
+ self.write_u64(i as u64);
+ }
+}
diff --git a/src/count.rs b/src/count.rs
deleted file mode 100644
index 46b0a93d..00000000
--- a/src/count.rs
+++ /dev/null
@@ -1,173 +0,0 @@
-use core::ops::{Range, RangeInclusive};
-
-/// Converts ranges into a start and end index.
-pub trait AsRangeArg {
- /// Returns the start index of the range.
- fn start(&self) -> usize;
- /// Returns the end index of the range.
- fn end(&self) -> usize;
-}
-
-impl AsRangeArg for Range {
- #[inline]
- fn start(&self) -> usize {
- self.start
- }
-
- #[inline]
- fn end(&self) -> usize {
- self.end
- }
-}
-
-impl AsRangeArg for RangeInclusive {
- #[inline]
- fn start(&self) -> usize {
- *self.start()
- }
-
- #[inline]
- fn end(&self) -> usize {
- *self.end() + 1
- }
-}
-
-impl AsRangeArg for usize {
- #[inline]
- fn start(&self) -> usize {
- 0
- }
-
- #[inline]
- fn end(&self) -> usize {
- *self
- }
-}
-
-impl AsRangeArg for (usize, usize) {
- #[inline]
- fn start(&self) -> usize {
- self.0
- }
-
- #[inline]
- fn end(&self) -> usize {
- self.1
- }
-}
-
-/// `range` resets the cache count in every iteration.
-/// The cache count is used to retrieve the same allocation in each iteration.
-/// Not adding `range` results in allocating new memory in each iteration,
-/// which is only freed when the device is dropped.
-/// To disable this caching behaviour, enable the `realloc` feature.
-///
-/// # Example
-#[cfg_attr(not(feature = "no-std"), doc = "```")]
-#[cfg_attr(feature = "no-std", doc = "```ignore")]
-/// use custos::{get_count, range, Ident, bump_count};
-///
-/// for _ in range(100) { // using only one usize: exclusive range
-/// Ident::new(10); // an 'Ident' is created if a Buffer is retrieved from cache.
-/// bump_count();
-/// assert!(get_count() == 1);
-/// }
-/// assert!(get_count() == 0);
-/// ```
-#[inline]
-pub fn range(range: R) -> Count {
- Count(range.start(), range.end())
-}
-
-/// used to reset the cache count
-#[derive(Debug, Clone, Copy)]
-pub struct Count(pub(super) usize, pub(super) usize);
-
-/// The iterator used for setting the cache count.
-#[derive(Debug)]
-pub struct CountIntoIter {
- epoch: usize,
- #[cfg(not(feature = "no-std"))]
- idx: usize,
- end: usize,
-}
-
-impl Iterator for CountIntoIter {
- type Item = usize;
-
- fn next(&mut self) -> Option {
- #[cfg(not(feature = "no-std"))]
- unsafe {
- crate::set_count(self.idx)
- };
- if self.epoch >= self.end {
- return None;
- }
- let epoch = Some(self.epoch);
- self.epoch += 1;
- epoch
- }
-}
-
-impl IntoIterator for Count {
- type Item = usize;
-
- type IntoIter = CountIntoIter;
-
- #[inline]
- fn into_iter(self) -> Self::IntoIter {
- CountIntoIter {
- epoch: self.0,
- #[cfg(not(feature = "no-std"))]
- idx: crate::get_count(),
- end: self.1,
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use crate::{range, Count, CountIntoIter};
-
- fn count_iter(iter: &mut CountIntoIter) {
- iter.next();
- assert_eq!(iter.epoch, 1);
- #[cfg(not(feature = "no-std"))]
- assert_eq!(iter.idx, 0);
- assert_eq!(iter.end, 10);
-
- iter.next();
- assert_eq!(iter.epoch, 2);
- #[cfg(not(feature = "no-std"))]
- assert_eq!(iter.idx, 0);
- assert_eq!(iter.end, 10);
- }
-
- #[test]
- fn test_count_into_iter() {
- let mut iter = CountIntoIter {
- epoch: 0,
- #[cfg(not(feature = "no-std"))]
- idx: 0,
- end: 10,
- };
-
- count_iter(&mut iter);
- }
-
- #[test]
- fn test_count() {
- let count: Count = Count(0, 10);
- count_iter(&mut count.into_iter());
- }
-
- #[test]
- fn test_range_inclusive() {
- let count: Count = range(0..=9);
- count_iter(&mut count.into_iter());
-
- for (idx, other) in count.into_iter().zip(0..=9) {
- assert_eq!(idx, other)
- }
- }
-}
diff --git a/src/device_traits.rs b/src/device_traits.rs
new file mode 100644
index 00000000..41543c30
--- /dev/null
+++ b/src/device_traits.rs
@@ -0,0 +1,88 @@
+// TODO: consider moving this file into the `devices` folder.
+
+use crate::{flag::AllocFlag, prelude::Device, Buffer, HasId, Parents, PtrType, Shape, StackArray};
+
+pub trait Alloc: Device + Sized {
+ /// Allocate memory on the implemented device.
+ /// # Example
+ #[cfg_attr(feature = "cpu", doc = "```")]
+ #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
+ /// use custos::{CPU, Alloc, Buffer, Read, flag::AllocFlag, GraphReturn, cpu::CPUPtr};
+ ///
+ /// let device = CPU::new();
+ /// let ptr = Alloc::::alloc(&device, 12, AllocFlag::None);
+ ///
+ /// // a `Buffer` is constructed from the allocation (`data`) and the device it belongs to
+ /// let buf: Buffer = Buffer {
+ /// data: ptr,
+ /// device: Some(&device),
+ /// };
+ /// assert_eq!(vec![0.; 12], device.read(&buf));
+ /// ```
+ fn alloc(&self, len: usize, flag: AllocFlag) -> Self::Data;
+
+ /// Allocates new memory initialized from the slice `data`.
+ /// # Example
+ #[cfg_attr(feature = "cpu", doc = "```")]
+ #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
+ /// use custos::{CPU, Alloc, Buffer, Read, GraphReturn, cpu::CPUPtr};
+ ///
+ /// let device = CPU::new();
+ /// let ptr = Alloc::::alloc_from_slice(&device, &[1, 5, 4, 3, 6, 9, 0, 4]);
+ ///
+ /// // a `Buffer` is constructed from the allocation (`data`) and the device it belongs to
+ /// let buf: Buffer = Buffer {
+ /// data: ptr,
+ /// device: Some(&device),
+ /// };
+ /// assert_eq!(vec![1, 5, 4, 3, 6, 9, 0, 4], device.read(&buf));
+ /// ```
+ fn alloc_from_slice(&self, data: &[T]) -> Self::Data
+ where
+ T: Clone;
+
+ /// If the vector `vec` was allocated previously, this function can be used in order to reduce the amount of allocations, which may be faster than using a slice of `vec`. The default implementation simply forwards to [`Alloc::alloc_from_slice`].
+ #[inline]
+ #[cfg(not(feature = "no-std"))]
+ fn alloc_from_vec(&self, vec: Vec) -> Self::Data
+ where
+ T: Clone,
+ {
+ self.alloc_from_slice(&vec)
+ }
+
+ /// Allocates a pointer with the array provided by the `S:`[`Shape`] generic.
+ /// By default, the array is flattened and then passed to [`Alloc::alloc_from_slice`].
+ #[inline]
+ fn alloc_from_array(&self, array: S::ARR) -> Self::Data
+ where
+ T: Clone,
+ {
+ let stack_array = StackArray::::from_array(array);
+ self.alloc_from_slice(stack_array.flatten())
+ }
+}
+
+pub trait Module {
+ type Module;
+
+ fn new() -> Self::Module;
+}
+
+/// Used for modules that should affect the device.
+pub trait Setup {
+ #[inline]
+ fn setup(_device: &mut D) {}
+}
+
+pub trait Retriever: Device {
+ #[track_caller]
+ fn retrieve(
+ &self,
+ len: usize,
+ parents: impl Parents,
+ ) -> Buffer
+ where
+ T: 'static,
+ S: Shape;
+}
diff --git a/src/devices.rs b/src/devices.rs
new file mode 100644
index 00000000..2996ee41
--- /dev/null
+++ b/src/devices.rs
@@ -0,0 +1,112 @@
+//! This module defines all available compute devices
+
+mod generic_blas;
+pub use generic_blas::*;
+
+#[cfg(feature = "cpu")]
+pub mod cpu;
+
+#[cfg(feature = "cuda")]
+pub mod cuda;
+
+#[cfg(feature = "opencl")]
+pub mod opencl;
+
+#[cfg(feature = "stack")]
+pub mod stack;
+
+#[cfg(feature = "wgpu")]
+pub mod wgpu;
+
+#[cfg(feature = "network")]
+pub mod network;
+
+mod stack_array;
+pub use stack_array::*;
+
+mod cdatatype;
+pub use cdatatype::*;
+
+#[cfg(all(any(feature = "cpu", feature = "stack"), feature = "macro"))]
+mod cpu_stack_ops;
+
+use crate::{Alloc, Buffer, HasId, OnDropBuffer, PtrType, Shape};
+
+pub trait Device: OnDropBuffer + Sized {
+ type Data: HasId + PtrType;
+
+ type Error;
+
+ #[inline]
+ fn new() -> Result {
+ todo!()
+ }
+
+ /// Creates a new [`Buffer`] using `A`.
+ ///
+ /// # Example
+ #[cfg_attr(feature = "cpu", doc = "```")]
+ #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
+ /// use custos::{CPU, Device};
+ ///
+ /// let device = CPU::new();
+ /// let buf = device.buffer([5, 4, 3]);
+ ///
+ /// assert_eq!(buf.read(), [5, 4, 3]);
+ /// ```
+ fn buffer<'a, T, S: Shape, A>(&'a self, arr: A) -> Buffer<'a, T, Self, S>
+ where
+ Buffer<'a, T, Self, S>: From<(&'a Self, A)>,
+ {
+ Buffer::from((self, arr))
+ }
+}
+
+#[macro_export]
+macro_rules! impl_buffer_hook_traits {
+ ($device:ident) => {
+ impl> OnNewBuffer
+ for $device
+ {
+ #[inline]
+ fn on_new_buffer(&self, device: &D, new_buf: &Buffer) {
+ self.modules.on_new_buffer(device, new_buf)
+ }
+ }
+
+ impl OnDropBuffer for $device {
+ #[inline]
+ fn on_drop_buffer<'a, T, D: Device, S: Shape>(
+ &self,
+ device: &'a D,
+ buf: &Buffer,
+ ) {
+ self.modules.on_drop_buffer(device, buf)
+ }
+ }
+ };
+}
+
+#[macro_export]
+macro_rules! impl_retriever {
+ ($device:ident) => {
+ impl> Retriever for $device {
+ #[inline]
+ fn retrieve(
+ &self,
+ len: usize,
+ parents: impl crate::Parents,
+ ) -> Buffer {
+ let data = self
+ .modules
+ .retrieve::(self, len, parents);
+ let buf = Buffer {
+ data,
+ device: Some(self),
+ };
+ self.modules.on_retrieve_finish(&buf);
+ buf
+ }
+ }
+ };
+}
diff --git a/src/devices/addons.rs b/src/devices/addons.rs
deleted file mode 100644
index b8450bda..00000000
--- a/src/devices/addons.rs
+++ /dev/null
@@ -1,124 +0,0 @@
-use core::{cell::RefCell, fmt::Debug};
-
-use crate::{Cache, CacheReturn, Device, GlobalCount, Graph, GraphReturn, NodeIdx, PtrConv};
-
-use super::caller_cache::{CallerCacheReturn, TrackCallerCache};
-
-/// Provides several addons for a device.
-/// - `graph`: An optimizeable graph.
-/// - `cache`: A cache for allocations.
-/// - `tape`: A (gradient) tape.
-pub struct Addons {
- /// An optimizeable graph.
- pub graph: RefCell>,
- /// A cache for allocations.
- pub cache: RefCell>,
- /// A (gradient) tape.
- #[cfg(feature = "autograd")]
- pub tape: RefCell>,
- pub caller_cache: RefCell>,
-}
-
-impl Debug for Addons
-where
- D::Ptr: Debug,
-{
- fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
- #[cfg(feature = "autograd")]
- {
- f.debug_struct("Addons")
- .field("graph", &self.graph)
- .field("cache", &self.cache)
- .field("tape", &self.tape)
- .finish()
- }
-
- #[cfg(not(feature = "autograd"))]
- f.debug_struct("Addons")
- .field("graph", &self.graph)
- .field("cache", &self.cache)
- .finish()
- }
-}
-
-impl Default for Addons
-where
- D::Ptr: Default,
-{
- fn default() -> Self {
- Self {
- graph: Default::default(),
- cache: Default::default(),
- #[cfg(feature = "autograd")]
- tape: Default::default(),
- caller_cache: Default::default(),
- }
- }
-}
-
-/// `AddonsReturn` is probably implemented for all devices that have an [`Addons`] field.
-pub trait AddonsReturn: Device {
- /// Returns a reference to [`Addons`].
- fn addons(&self) -> &Addons;
-}
-
-impl GraphReturn for D {
- #[inline]
- fn graph(&self) -> std::cell::Ref> {
- self.addons().graph.borrow()
- }
-
- #[inline]
- fn graph_mut(&self) -> std::cell::RefMut> {
- self.addons().graph.borrow_mut()
- }
-}
-
-impl CacheReturn for D {
- #[inline]
- fn cache(&self) -> core::cell::Ref>
- where
- Self: PtrConv,
- {
- self.addons().cache.borrow()
- }
-
- #[inline]
- fn cache_mut(&self) -> core::cell::RefMut>
- where
- Self: PtrConv,
- {
- self.addons().cache.borrow_mut()
- }
-}
-
-impl CallerCacheReturn for D {
- #[inline]
- fn cache(&self) -> core::cell::Ref>
- where
- Self: PtrConv,
- {
- self.addons().caller_cache.borrow()
- }
-
- #[inline]
- fn cache_mut(&self) -> core::cell::RefMut>
- where
- Self: PtrConv,
- {
- self.addons().caller_cache.borrow_mut()
- }
-}
-
-#[cfg(feature = "autograd")]
-impl crate::TapeReturn for D {
- #[inline]
- fn tape(&self) -> core::cell::Ref> {
- self.addons().tape.borrow()
- }
-
- #[inline]
- fn tape_mut(&self) -> core::cell::RefMut> {
- self.addons().tape.borrow_mut()
- }
-}
diff --git a/src/devices/borrowing_cache.rs b/src/devices/borrowing_cache.rs
deleted file mode 100644
index 0ed76a20..00000000
--- a/src/devices/borrowing_cache.rs
+++ /dev/null
@@ -1,123 +0,0 @@
-use core::{any::Any, hash::BuildHasherDefault, mem::transmute};
-use std::collections::HashMap;
-
-use crate::{flag::AllocFlag, Alloc, Buffer, Device, Ident, IdentHasher, Shape};
-
-#[derive(Debug, Default)]
-pub struct BorrowingCache {
- pub cache: HashMap, BuildHasherDefault>,
-}
-
-// TODO: make BorrowedCache unuseable without device (=> Static get methods with D: CacheReturn)
-impl BorrowingCache {
- pub fn add_or_get<'a, T, D, S>(&mut self, device: &'a D, id: Ident) -> &Buffer<'a, T, D, S>
- where
- T: 'static,
- D: Alloc<'a, T, S> + 'static,
- S: Shape,
- {
- self.add_buf_once(device, id);
-
- let buf_any = self.cache.get(&id).unwrap();
- buf_any.downcast_ref().unwrap()
- }
-
- pub fn add_or_get_mut<'a, T, D, S>(&mut self, device: &D, id: Ident) -> &mut Buffer<'a, T, D, S>
- where
- T: 'static,
- D: for<'b> Alloc<'b, T, S> + 'static,
- S: Shape,
- {
- self.add_buf_once(device, id);
- self.get_buf_mut(id).unwrap()
- }
-
- pub fn add_buf_once<'a, T, D, S>(&mut self, device: &'a D, ident: Ident)
- where
- T: 'static,
- D: Alloc<'a, T, S> + 'static,
- S: Shape,
- {
- if self.cache.get(&ident).is_some() {
- return;
- }
-
- self.add_buf(device, ident)
- }
-
- pub fn add_buf<'a, T, D, S>(&mut self, device: &'a D, ident: Ident)
- where
- T: 'static,
- D: Alloc<'a, T, S> + 'static,
- S: Shape,
- {
- // not using ::new, because this buf would get added to the cache of the device.
- let buf = Buffer {
- ptr: device.alloc(ident.len, AllocFlag::BorrowedCache),
- device: Some(device),
- ident: Some(ident),
- };
-
- let buf = unsafe { transmute::<_, Buffer<'static, T, D, S>>(buf) };
- self.cache.insert(ident, Box::new(buf));
- }
-
- #[inline]
- pub fn get_buf<'a, T, D, S>(&self, id: Ident) -> Option<&Buffer<'a, T, D, S>>
- where
- T: 'static,
- D: Device + 'static,
- S: Shape,
- {
- self.cache.get(&id)?.downcast_ref()
- }
-
- #[inline]
- pub fn get_buf_mut<'a, T, D, S>(&mut self, id: Ident) -> Option<&mut Buffer<'a, T, D, S>>
- where
- T: 'static,
- D: Device + 'static,
- S: Shape,
- {
- unsafe { transmute(self.cache.get_mut(&id)?.downcast_mut::>()) }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use crate::{Ident, CPU};
-
- use super::BorrowingCache;
-
- /*#[test]
- fn test_comp_error() {
- let device = CPU::new();
-
-
- let a = {
- let mut cache = BorrowingCache::default();
- cache.add_or_get::(&device, Ident::new(10))
- };
- }*/
-
- #[test]
- fn test_get_borrowed() {
- let device = CPU::new();
- let mut cache = BorrowingCache::default();
-
- let (fid, sid, tid) = (
- Ident::new_bumped(10),
- Ident::new_bumped(10),
- Ident::new_bumped(10),
- );
-
- cache.add_buf_once::(&device, fid);
- cache.add_buf_once::(&device, sid);
- cache.add_buf_once::(&device, tid);
-
- let a = cache.get_buf::(fid).unwrap();
- let b = cache.get_buf::(fid).unwrap();
-
- assert_eq!(a.ptr, b.ptr);
- }
-}
diff --git a/src/devices/cache.rs b/src/devices/cache.rs
deleted file mode 100644
index ecf91c9a..00000000
--- a/src/devices/cache.rs
+++ /dev/null
@@ -1,321 +0,0 @@
-//! Contains the [`Cache`]ing logic.
-
-use core::{cell::RefMut, fmt::Debug, hash::BuildHasherDefault, ops::BitXor};
-use std::collections::HashMap;
-
-use std::rc::Rc;
-
-use crate::{
- flag::AllocFlag, shape::Shape, Alloc, Buffer, CacheAble, Device, GlobalCount, GraphReturn,
- Ident, PtrConv, PtrType,
-};
-
-/// This trait makes a device's [`Cache`] accessible and is implemented for all compute devices.
-pub trait CacheReturn: GraphReturn {
- /// Returns a reference to a device's [`Cache`].
- fn cache(&self) -> core::cell::Ref>
- where
- Self: PtrConv;
-
- /// Returns a mutable reference to a device's [`Cache`].
- fn cache_mut(&self) -> RefMut>
- where
- Self: PtrConv;
-}
-
-const K: usize = 0x517cc1b727220a95;
-
-/// A low-overhead [`Ident`] hasher using "FxHasher".
-#[derive(Default)]
-pub struct IdentHasher {
- hash: usize,
-}
-
-impl std::hash::Hasher for IdentHasher {
- #[inline]
- fn finish(&self) -> u64 {
- self.hash as u64
- }
-
- #[inline]
- fn write(&mut self, _bytes: &[u8]) {
- unimplemented!("IdentHasher only hashes usize.")
- }
-
- #[inline]
- fn write_usize(&mut self, i: usize) {
- self.hash = self.hash.rotate_left(5).bitxor(i).wrapping_mul(K);
- }
-}
-
-impl CacheAble for Cache
-where
- D: PtrConv + CacheReturn,
-{
- #[cfg(not(feature = "realloc"))]
- #[inline]
- fn retrieve(
- device: &D,
- len: usize,
- add_node: impl crate::AddGraph,
- ) -> Buffer
- where
- for<'b> D: Alloc<'b, T, S>,
- {
- device
- .cache_mut()
- .get(device, Ident::new(len), add_node, crate::bump_count)
- }
-
- #[cfg(feature = "realloc")]
- #[inline]
- fn retrieve