Skip to content

Commit

Permalink
Add ToBase, ToDim, remove unsafe from retrieve, ..
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 17, 2024
1 parent a472d3c commit 2bb7672
Show file tree
Hide file tree
Showing 22 changed files with 292 additions and 256 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
[features]
# default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]

default = ["cpu", "cached"]
default = ["cpu", "cached", "autograd"]
# default = ["no-std"]
# default = ["opencl"]
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ This operation is only affected by the `Cached` module (and partially `Autograd`
use custos::prelude::*;
use std::ops::{Deref, Mul};

pub trait MulBuf<T: Unit, S: Shape = (), D: Device = Self>: Sized + Device {
fn mul(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Self, S>;
pub trait MulBuf<'a, T: Unit, S: Shape = (), D: Device = Self>: Sized + Device {
fn mul(&'a self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<'a, T, Self, S>;
}

impl<Mods, T, S, D> MulBuf<T, S, D> for CPU<Mods>
impl<'a, Mods, T, S, D> MulBuf<'a, T, S, D> for CPU<Mods>
where
Mods: Retrieve<Self, T, S>,
Mods: Retrieve<'a, Self, T, S>,
T: Unit + Mul<Output = T> + Copy + 'static,
S: Shape,
D: Device,
D::Base<T, S>: Deref<Target = [T]>,
{
fn mul(&self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<T, Self, S> {
fn mul(&'a self, lhs: &Buffer<T, D, S>, rhs: &Buffer<T, D, S>) -> Buffer<'a, T, Self, S> {
let mut out = self.retrieve(lhs.len(), (lhs, rhs)).unwrap(); // unwrap or return error (update trait)

for ((lhs, rhs), out) in lhs.iter().zip(rhs.iter()).zip(&mut out) {
Expand Down
19 changes: 11 additions & 8 deletions examples/custom_module.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use custos::{
Alloc, Base, Device, HasId, IsBasePtr, Module, OnDropBuffer, Parents, PtrType, Retrieve, Setup, Shape, Unit, WrappedData, CPU
Alloc, Base, Device, HasId, IsBasePtr, Module, OnDropBuffer, Parents, PtrType, Retrieve, Setup,
Shape, Unit, WrappedData, CPU,
};

pub struct CustomModule<Mods> {
Expand Down Expand Up @@ -43,7 +44,9 @@ impl<Mods: WrappedData> WrappedData for CustomModule<Mods> {
}

#[inline]
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(wrap: &'b Self::Wrap<'a, T, Base>) -> &'b Base {
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(
wrap: &'b Self::Wrap<'a, T, Base>,
) -> &'b Base {
Mods::wrapped_as_base(wrap)
}

Expand Down Expand Up @@ -88,19 +91,19 @@ where
self.mods.retrieve_entry(device, len, parents)
}

fn on_retrieve_finish<const NUM_PARENTS: usize>(&self,
fn on_retrieve_finish<const NUM_PARENTS: usize>(
&self,
len: usize,
parents: impl Parents<NUM_PARENTS>,
retrieved_buf: &custos::prelude::Buffer<T, D, S>
)
where
retrieved_buf: &custos::prelude::Buffer<T, D, S>,
) where
D: Alloc<T>,
{
// inject custom behaviour in this body

self.mods.on_retrieve_finish(len, parents, retrieved_buf)
}

unsafe fn retrieve<const NUM_PARENTS: usize>(
&self,
device: &D,
Expand All @@ -109,7 +112,7 @@ where
) -> custos::Result<Self::Wrap<'a, T, <D>::Base<T, S>>>
where
S: Shape,
D: Alloc<T>
D: Alloc<T>,
{
self.mods.retrieve(device, len, parents)
}
Expand Down
4 changes: 2 additions & 2 deletions examples/modules_usage.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::ops::{Add, AddAssign, Deref, DerefMut, Mul};

use custos::{
AddGradFn, AddOperation, Alloc, Buffer, Device, MayGradActions,
Retrieve, Retriever, Shape, Unit, ZeroGrad, CPU,
AddGradFn, AddOperation, Alloc, Buffer, Device, MayGradActions, Retrieve, Retriever, Shape,
Unit, ZeroGrad, CPU,
};

pub trait ElementWise<'a, T: Unit, D: Device, S: Shape>: Device {
Expand Down
45 changes: 19 additions & 26 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use crate::CPU;

use crate::{
flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId,
IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit,
WrappedCopy, WrappedData, WriteBuf, ZeroGrad,
IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, ToDim, Unit,
WrappedData, WriteBuf, ZeroGrad,
};

pub use self::num::Num;
Expand Down Expand Up @@ -136,7 +136,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
// #[inline]
// fn id(&self) -> super::Id {
// self.data.id()
// }
// }
// }

impl<'a, T: Unit, D: Device, S: Shape> HasId for Buffer<'a, T, D, S> {
Expand Down Expand Up @@ -231,7 +231,7 @@ impl<'a, T: Unit, D: Device + OnNewBuffer<'a, T, D, S>, S: Shape> Buffer<'a, T,
/// Creates a new `Buffer` from an nd-array.
/// The dimension is defined by the [`Shape`].
#[inline]
pub fn from_array(device: &'a D, array: S::ARR<T>) -> Buffer<T, D, S>
pub fn from_array(device: &'a D, array: S::ARR<T>) -> Buffer<'a, T, D, S>
where
T: Clone,
D: Alloc<T>,
Expand Down Expand Up @@ -271,21 +271,24 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
}

#[inline]
pub fn to_deviceless<'b>(self) -> Buffer<'b, T, D, S>
pub fn to_deviceless<'b>(mut self) -> Buffer<'b, T, D, S>
where
D::Data<'b, T, S>: Default,
D::Base<T, S>: ShallowCopy,
{
if let Some(device) = self.device {
if self.data.flag() != AllocFlag::None {
device.on_drop_buffer(device, &self)
}
}
todo!()
// let mut val = ManuallyDrop::new(self);

// let data = core::mem::take(&mut val.data);
unsafe { self.set_flag(AllocFlag::Wrapper) };
let mut base = unsafe { self.base().shallow() };
unsafe { base.set_flag(AllocFlag::None) };

let data: <D as Device>::Data<'b, T, S> = self.device().base_to_data::<T, S>(base);

// Buffer { data, device: None }
Buffer { data, device: None }
}

/// Returns the device of the `Buffer`.
Expand Down Expand Up @@ -460,7 +463,6 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
}
}

// TODO better solution for the to_dims stack problem?
impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
/// Converts a non stack allocated `Buffer` with shape `S` to a `Buffer` with shape `O`.
/// # Example
Expand All @@ -474,24 +476,15 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
///
/// ```
#[inline]
pub fn to_dims<O: Shape>(self) -> Buffer<'a, T, D, O>
pub fn to_dims<O: Shape>(mut self) -> Buffer<'a, T, D, O>
where
// D: crate::ToDim<T, S, O>,
D::Data<'a, T, S>: WrappedCopy<Base = D::Base<T, S>>,
D::Base<T, S>: ShallowCopy,
D::Data<'a, T, S>: Default + ToDim<Out = D::Data<'a, T, O>>,
{
let base = unsafe { (*self).shallow() };
let data = self.data.wrapped_copy(base);
let buf = ManuallyDrop::new(self);

// let mut data = buf.device().to_dim(data);
// unsafe { data.set_flag(AllocFlag::None) };
todo!()

// Buffer {
// data,
// device: buf.device,
// }
let data = std::mem::take(&mut self.data).to_dim();
Buffer {
data,
device: self.device,
}
}
}

Expand Down
17 changes: 8 additions & 9 deletions src/cache/borrow_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,21 @@ impl BorrowCache {
}

#[inline]
pub fn get_buf<T, D, S>(
&self,
pub fn get_buf<'a, 'b, T, D, S>(
&'a self,
_device: &D,
id: Id,
) -> Result<&Buffer<'_, T, D, S>, CachingError>
) -> Result<&'a Buffer<'b, T, D, S>, CachingError>
where
T: Unit + 'static,
D: Device + 'static,
S: Shape,
{
todo!()
// self.cache
// .get(&id)
// .ok_or(CachingError::InvalidId)?
// .downcast_ref()
// .ok_or(CachingError::InvalidTypeInfo)
let out = self.cache.get(&id).ok_or(CachingError::InvalidId)?;
if !out.is::<Buffer<T, D, S>>() {
return Err(CachingError::InvalidTypeInfo);
}
Ok(unsafe { out.downcast_ref_unchecked::<Buffer<'_, T, D, S>>() })
}

#[inline]
Expand Down
23 changes: 17 additions & 6 deletions src/cache/locking/guard.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use core::{
mem::ManuallyDrop,
ops::{Deref, DerefMut},
};
use core::ops::{Deref, DerefMut};

use crate::{CowMutCell, HasId, HostPtr, PtrType, ShallowCopy};
use crate::{CowMutCell, HasId, HostPtr, PtrType, ShallowCopy, ToDim};

#[derive(Debug)]
pub struct Guard<'a, T> {
Expand All @@ -24,7 +21,7 @@ impl<'a, T> Guard<'a, T> {
Guard { data: f(data) }
}

#[inline]
#[inline]
pub fn make_static(self) -> Option<Guard<'static, T>> {
match self.data {
CowMutCell::Borrowed(_) => None,
Expand Down Expand Up @@ -91,3 +88,17 @@ impl<'a, T, P: PtrType + HostPtr<T>> HostPtr<T> for Guard<'a, P> {
self.data.get_mut().unwrap().ptr_mut()
}
}

/// Shape-conversion passthrough for [`Guard`]: the guard is handed back as-is.
impl<'a, P> ToDim for Guard<'a, P> {
    /// Converting a guard yields the very same guard type.
    type Out = Guard<'a, P>;

    /// Consumes the guard and returns it unchanged.
    #[inline]
    fn to_dim(self) -> Guard<'a, P> {
        self
    }

    /// Reborrows the guard as its own output type; nothing is copied.
    #[inline]
    fn as_dim(&self) -> &Guard<'a, P> {
        self
    }
}
8 changes: 6 additions & 2 deletions src/cache/owned_cache/fast_cache2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use core::{any::Any, cell::{Ref, RefMut}, hash::BuildHasherDefault};
use crate::{LockedMap, NoHasher, State, UniqueId};
use core::{
any::Any,
cell::{Ref, RefMut},
hash::BuildHasherDefault,
};

use super::Cache;

Expand All @@ -18,7 +22,7 @@ impl Cache<Box<dyn Any>> for FastCache2 {
/// Stores `data` in `self.nodes` under the key `id`.
///
/// The `_len` argument is not used by this implementation.
fn insert(&self, id: UniqueId, _len: usize, data: Box<dyn Any>) {
self.nodes.insert(id, data);
}

#[inline]
fn get(&self, id: UniqueId, _len: usize) -> State<Ref<Box<dyn Any>>> {
self.nodes.get(&id)
Expand Down
34 changes: 24 additions & 10 deletions src/devices/cpu/cpu_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use core::{

use std::alloc::handle_alloc_error;

use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, Unit, WrappedCopy};
use crate::{
flag::AllocFlag, Device, HasId, HostPtr, Id, PtrType, ShallowCopy, Shape, ToBase, ToDim, Unit,
};

/// The pointer used for `CPU` [`Buffer`](crate::Buffer)s
#[derive(Debug)]
Expand Down Expand Up @@ -229,15 +231,6 @@ impl<T: Unit> PtrType for CPUPtr<T> {
}
}

/// `CPUPtr` is its own base pointer, so wrapping a base is the identity.
impl<T> WrappedCopy for CPUPtr<T> {
    /// The base pointer type is the `CPUPtr` itself.
    type Base = CPUPtr<T>;

    /// Returns `to_wrap` unchanged; `self` is not read.
    #[inline]
    fn wrapped_copy(&self, to_wrap: CPUPtr<T>) -> CPUPtr<T> {
        to_wrap
    }
}

impl<T> ShallowCopy for CPUPtr<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down Expand Up @@ -303,6 +296,27 @@ impl Drop for DeallocWithLayout {
}
}

/// Shape-conversion passthrough for [`CPUPtr`]: the pointer is returned as-is.
impl<T> ToDim for CPUPtr<T> {
    /// Converting a `CPUPtr` yields the same pointer type.
    type Out = CPUPtr<T>;

    /// Consumes the pointer and hands it back unchanged.
    #[inline]
    fn to_dim(self) -> CPUPtr<T> {
        self
    }

    /// Reborrows the pointer as its own output type; nothing is copied.
    #[inline]
    fn as_dim(&self) -> &CPUPtr<T> {
        self
    }
}

/// For every device whose base pointer type is `CPUPtr<T>`, a `CPUPtr<T>`
/// already is the base, so the conversion returns `self` unchanged.
impl<T, D, S> ToBase<T, D, S> for CPUPtr<T>
where
    T: Unit,
    D: Device<Base<T, S> = CPUPtr<T>>,
    S: Shape,
{
    /// Consumes the pointer and returns it as the device's base pointer.
    #[inline]
    fn to_base(self) -> D::Base<T, S> {
        self
    }
}

#[cfg(feature = "serde")]
pub mod serde {
use core::{fmt, marker::PhantomData};
Expand Down
11 changes: 1 addition & 10 deletions src/devices/stack_array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::ops::{Deref, DerefMut};

use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, Unit, WrappedCopy};
use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, Unit};

/// A possibly multi-dimensional array allocated on the stack.
/// It uses `S:`[`Shape`] to get the type of the array.
Expand Down Expand Up @@ -137,15 +137,6 @@ impl<S: Shape, T: Unit> HostPtr<T> for StackArray<S, T> {
}
}

/// `StackArray` is its own base pointer, so wrapping a base is the identity.
impl<S: Shape, T> WrappedCopy for StackArray<S, T> {
    /// The base pointer type is the `StackArray` itself.
    type Base = StackArray<S, T>;

    /// Returns `to_wrap` unchanged; `self` is not read.
    #[inline]
    fn wrapped_copy(&self, to_wrap: StackArray<S, T>) -> StackArray<S, T> {
        to_wrap
    }
}

impl<S: Shape, T> ShallowCopy for StackArray<S, T>
where
S::ARR<T>: Copy,
Expand Down
9 changes: 7 additions & 2 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub trait Feature: OnDropBuffer {}
pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer {
// "generator"
#[track_caller]
unsafe fn retrieve_entry<const NUM_PARENTS: usize>(
fn retrieve_entry<const NUM_PARENTS: usize>(
&'a self,
device: &D,
len: usize,
Expand All @@ -38,7 +38,7 @@ pub trait Retrieve<'a, D, T: Unit, S: Shape = ()>: OnDropBuffer {
D: Alloc<T>;

#[track_caller]
unsafe fn retrieve<const NUM_PARENTS: usize>(
fn retrieve<const NUM_PARENTS: usize>(
&self,
device: &D,
len: usize,
Expand Down Expand Up @@ -684,6 +684,11 @@ pub trait CachedBuffers {
) -> Option<RefMut<crate::Buffers<Box<dyn crate::BoxedShallowCopy>>>> {
None
}

/// Default implementation: always returns `false`.
///
/// NOTE(review): presumably implementors override this to report that the
/// cached buffers are provided by a module further down the module stack —
/// confirm against the implementors and callers.
#[inline]
fn is_supplied_from_below_module(&self) -> bool {
false
}
}

#[macro_export]
Expand Down
Loading

0 comments on commit 2bb7672

Please sign in to comment.