Skip to content

Commit

Permalink
Add Unit bound to T: Wrap<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Oct 22, 2024
1 parent 9c94c26 commit afdb963
Show file tree
Hide file tree
Showing 17 changed files with 71 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
# min-cl = { version = "0.3.0", optional=true }

[features]
default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "vulkan", "stack"]
default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "vulkan", "stack", "opencl", "cuda"]

# default = ["cpu"]
# default = ["no-std"]
Expand Down
8 changes: 4 additions & 4 deletions examples/custom_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ fn main() {
// Implementing pass down traits / features

impl<Mods: WrappedData> WrappedData for CustomModule<Mods> {
type Wrap<T, Base: HasId + PtrType> = Mods::Wrap<T, Base>;
type Wrap<T: Unit, Base: HasId + PtrType> = Mods::Wrap<T, Base>;

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
fn wrap_in_base<T: Unit, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
self.mods.wrap_in_base(base)
}

#[inline]
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
fn wrapped_as_base<T: Unit, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/buffer/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct Num<T> {
pub num: T,
}

impl<T> PtrType for Num<T> {
impl<T: Unit> PtrType for Num<T> {
#[inline]
fn size(&self) -> usize {
0
Expand Down Expand Up @@ -107,20 +107,20 @@ impl<T: Unit + Default> Alloc<T> for () {
}

impl WrappedData for () {
type Wrap<T, Base: crate::HasId + crate::PtrType> = Base;
type Wrap<T: Unit, Base: crate::HasId + crate::PtrType> = Base;

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
fn wrap_in_base<T: Unit, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
base
}

#[inline]
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
fn wrapped_as_base<T: Unit, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
wrap
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/devices/cpu/cpu_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{

use std::alloc::handle_alloc_error;

use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, WrappedCopy};
use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, Unit, WrappedCopy};

/// The pointer used for `CPU` [`Buffer`](crate::Buffer)s
#[derive(Debug)]
Expand Down Expand Up @@ -151,7 +151,7 @@ impl<T> CPUPtr<T> {
}
}

impl<T> HostPtr<T> for CPUPtr<T> {
impl<T: Unit> HostPtr<T> for CPUPtr<T> {
#[inline]
fn ptr(&self) -> *const T {
self.ptr
Expand Down Expand Up @@ -212,7 +212,7 @@ impl<T> Drop for CPUPtr<T> {
}
}

impl<T> PtrType for CPUPtr<T> {
impl<T: Unit> PtrType for CPUPtr<T> {
#[inline]
fn size(&self) -> usize {
self.len
Expand Down
4 changes: 2 additions & 2 deletions src/devices/cuda/cuda_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::api::{cu_read, cufree, cumalloc, CudaResult};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, Unit, WrappedCopy};
use core::marker::PhantomData;

/// The pointer used for `CUDA` [`Buffer`](crate::Buffer)s
Expand Down Expand Up @@ -97,7 +97,7 @@ impl<T> ShallowCopy for CUDAPtr<T> {
}
}

impl<T> PtrType for CUDAPtr<T> {
impl<T: Unit> PtrType for CUDAPtr<T> {
#[inline]
fn size(&self) -> usize {
self.len
Expand Down
10 changes: 5 additions & 5 deletions src/devices/opencl/cl_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::HostPtr;

use min_cl::api::release_mem_object;

use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, Unit, WrappedCopy};

/// The pointer used for `OpenCL` [`Buffer`](crate::Buffer)s
#[derive(Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -80,7 +80,7 @@ impl<T> ShallowCopy for CLPtr<T> {
}
}

impl<T> PtrType for CLPtr<T> {
impl<T: Unit> PtrType for CLPtr<T> {
#[inline]
fn size(&self) -> usize {
self.len
Expand All @@ -98,7 +98,7 @@ impl<T> PtrType for CLPtr<T> {
}

#[cfg(unified_cl)]
impl<T> HostPtr<T> for CLPtr<T> {
impl<T: Unit> HostPtr<T> for CLPtr<T> {
#[inline]
fn ptr(&self) -> *const T {
self.host_ptr
Expand All @@ -111,7 +111,7 @@ impl<T> HostPtr<T> for CLPtr<T> {
}

#[cfg(unified_cl)]
impl<T> Deref for CLPtr<T> {
impl<T: Unit> Deref for CLPtr<T> {
type Target = [T];

#[inline]
Expand All @@ -121,7 +121,7 @@ impl<T> Deref for CLPtr<T> {
}

#[cfg(unified_cl)]
impl<T> DerefMut for CLPtr<T> {
impl<T: Unit> DerefMut for CLPtr<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { self.as_mut_slice() }
Expand Down
6 changes: 3 additions & 3 deletions src/devices/stack_array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::ops::{Deref, DerefMut};

use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy};
use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, Unit, WrappedCopy};

/// A possibly multi-dimensional array allocated on the stack.
/// It uses `S:`[`Shape`] to get the type of the array.
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<S: Shape, T> DerefMut for StackArray<S, T> {
}
}

impl<S: Shape, T> PtrType for StackArray<S, T> {
impl<S: Shape, T: Unit> PtrType for StackArray<S, T> {
#[inline]
fn size(&self) -> usize {
S::LEN
Expand All @@ -125,7 +125,7 @@ impl<S: Shape, T> PtrType for StackArray<S, T> {
unsafe fn set_flag(&mut self, _flag: crate::flag::AllocFlag) {}
}

impl<S: Shape, T> HostPtr<T> for StackArray<S, T> {
impl<S: Shape, T: Unit> HostPtr<T> for StackArray<S, T> {
#[inline]
fn ptr(&self) -> *const T {
self.as_ptr()
Expand Down
12 changes: 6 additions & 6 deletions src/devices/vulkan/vk_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use core::{
};
use std::rc::Rc;

use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy};
use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy, Unit, WrappedCopy};

use super::{context::Context, submit_and_wait};

Expand All @@ -25,7 +25,7 @@ pub struct VkArray<T> {
unsafe impl<T: Sync> Sync for VkArray<T> {}
unsafe impl<T: Send> Send for VkArray<T> {}

impl<T> PtrType for VkArray<T> {
impl<T: Unit> PtrType for VkArray<T> {
#[inline]
fn size(&self) -> usize {
self.len
Expand All @@ -51,7 +51,7 @@ impl<T> HasId for VkArray<T> {
}
}

impl<T> VkArray<T> {
impl<T: Unit> VkArray<T> {
pub fn new(
context: Rc<Context>,
len: usize,
Expand Down Expand Up @@ -267,7 +267,7 @@ impl<T> Drop for VkArray<T> {
}
}

impl<T> HostPtr<T> for VkArray<T> {
impl<T: Unit> HostPtr<T> for VkArray<T> {
#[inline]
fn ptr(&self) -> *const T {
self.mapped_ptr
Expand All @@ -280,7 +280,7 @@ impl<T> HostPtr<T> for VkArray<T> {
}

// TODO: impl deref only when using unified memory
impl<T> Deref for VkArray<T> {
impl<T: Unit> Deref for VkArray<T> {
type Target = [T];

#[inline]
Expand All @@ -290,7 +290,7 @@ impl<T> Deref for VkArray<T> {
}
}

impl<T> DerefMut for VkArray<T> {
impl<T: Unit> DerefMut for VkArray<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
assert!(!self.ptr_mut().is_null());
Expand Down
8 changes: 4 additions & 4 deletions src/devices/wgsl/wgsl_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,25 +66,25 @@ impl<D: Device, Mods: OnDropBuffer> Device for Wgsl<D, Mods> {
}

impl<D: Device, Mods: WrappedData> WrappedData for Wgsl<D, Mods> {
type Wrap<T, Base: HasId + PtrType> = Mods::Wrap<T, Base>;
type Wrap<T: Unit, Base: HasId + PtrType> = Mods::Wrap<T, Base>;

#[inline]
fn wrap_in_base<T, Base: crate::HasId + crate::PtrType>(
fn wrap_in_base<T: Unit, Base: crate::HasId + crate::PtrType>(
&self,
base: Base,
) -> Self::Wrap<T, Base> {
self.modules.wrap_in_base(base)
}

#[inline]
fn wrapped_as_base<T, Base: crate::HasId + crate::PtrType>(
fn wrapped_as_base<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &Self::Wrap<T, Base>,
) -> &Base {
Mods::wrapped_as_base(wrap)
}

#[inline]
fn wrapped_as_base_mut<T, Base: crate::HasId + crate::PtrType>(
fn wrapped_as_base_mut<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &mut Self::Wrap<T, Base>,
) -> &mut Base {
Mods::wrapped_as_base_mut(wrap)
Expand Down
4 changes: 2 additions & 2 deletions src/exec_on_cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ mod tests {
Ok(())
}

pub trait AddEw<T, D: crate::Device = Self>: crate::Device {
pub trait AddEw<T: 'static, D: crate::Device = Self>: crate::Device {
#[allow(dead_code)]
fn add(&self, lhs: &crate::Buffer<T, D>, rhs: &crate::Buffer<T, D>) -> crate::Buffer<T, D>;
}
Expand All @@ -443,7 +443,7 @@ mod tests {
where
Mods: crate::hooks::OnDropBuffer + crate::Retrieve<Self, T> + 'static,
Self::Base<T, ()>: core::ops::Deref<Target = [T]>,
T: core::ops::Add<Output = T> + Copy,
T: 'static + core::ops::Add<Output = T> + Copy,
{
fn add(
&self,
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ pub(crate) type OperationFn<B> =
Box<dyn Fn(&[Id], &mut Buffers<B>, &dyn core::any::Any) -> crate::Result<()> + 'static>;

/// This trait is implemented for every pointer type.
pub trait PtrType {
pub trait PtrType: 'static {
/// Returns the element count.
fn size(&self) -> usize;
/// Returns the [`AllocFlag`].
Expand All @@ -178,9 +178,9 @@ pub trait HostPtr<T>: PtrType {
}

/// Minimum requirements for an element inside a Buffer.
pub trait Unit {} // useful for Sync and Send or 'static
pub trait Unit: 'static {} // useful for Sync and Send or 'static

impl<T> Unit for T {}
impl<T: 'static> Unit for T {}

pub trait WrappedCopy {
type Base;
Expand Down
12 changes: 6 additions & 6 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::marker::PhantomData;

use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedCopy, WrappedData};
use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, Unit, WrappedCopy, WrappedData};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ReqGradWrapper<Data, T> {
Expand All @@ -10,10 +10,10 @@ pub struct ReqGradWrapper<Data, T> {
}

impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<T, Base: crate::HasId + crate::PtrType> = ReqGradWrapper<Mods::Wrap<T, Base>, T>;
type Wrap<T: Unit, Base: crate::HasId + crate::PtrType> = ReqGradWrapper<Mods::Wrap<T, Base>, T>;

#[inline]
fn wrap_in_base<T, Base: crate::HasId + crate::PtrType>(
fn wrap_in_base<T: Unit, Base: crate::HasId + crate::PtrType>(
&self,
base: Base,
) -> Self::Wrap<T, Base> {
Expand All @@ -26,14 +26,14 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
}

#[inline]
fn wrapped_as_base<T, Base: crate::HasId + crate::PtrType>(
fn wrapped_as_base<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &Self::Wrap<T, Base>,
) -> &Base {
Mods::wrapped_as_base(&wrap.data)
}

#[inline]
fn wrapped_as_base_mut<T, Base: crate::HasId + crate::PtrType>(
fn wrapped_as_base_mut<T: Unit, Base: crate::HasId + crate::PtrType>(
wrap: &mut Self::Wrap<T, Base>,
) -> &mut Base {
Mods::wrapped_as_base_mut(&mut wrap.data)
Expand All @@ -57,7 +57,7 @@ impl<Data: HasId, T> HasId for ReqGradWrapper<Data, T> {
}
}

impl<Data: PtrType, T> PtrType for ReqGradWrapper<Data, T> {
impl<Data: PtrType, T: Unit> PtrType for ReqGradWrapper<Data, T> {
#[inline]
fn size(&self) -> usize {
self.data.size()
Expand Down
8 changes: 4 additions & 4 deletions src/modules/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ use crate::{
pub struct Base;

impl WrappedData for Base {
type Wrap<T, Base: HasId + PtrType> = Base;
type Wrap<T: Unit, Base: HasId + PtrType> = Base;

#[inline]
fn wrap_in_base<T, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
fn wrap_in_base<T: Unit, Base: HasId + PtrType>(&self, base: Base) -> Self::Wrap<T, Base> {
base
}

#[inline]
fn wrapped_as_base<T, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
fn wrapped_as_base<T: Unit, Base: HasId + PtrType>(wrap: &Self::Wrap<T, Base>) -> &Base {
wrap
}

#[inline]
fn wrapped_as_base_mut<T, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
fn wrapped_as_base_mut<T: Unit, Base: HasId + PtrType>(wrap: &mut Self::Wrap<T, Base>) -> &mut Base {
wrap
}
}
Expand Down
Loading

0 comments on commit afdb963

Please sign in to comment.