Skip to content

Commit

Permalink
Add testing Lazy functionality, feature flag for cuda device
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Aug 10, 2023
1 parent 7704eb3 commit 6a0a01c
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ min-cl = { version = "0.2.0", optional=true }

[features]
#default = ["no-std"]
default = ["blas", "cpu", "stack", "static-api", "macro", "opencl", "autograd", "cuda"]
default = ["blas", "cpu", "stack", "static-api", "macro", "opencl", "autograd"]
#default = ["stack", "macro", "cpu", "blas", "opencl", "static-api", "autograd"]
#default = ["stack", "cpu", "blas", "static-api", "opencl", "macro"]
cpu = []
Expand Down
1 change: 1 addition & 0 deletions src/module_comb/apply_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub trait ApplyFunction<T, S: Shape = (), D: Device = Self>: Device {
/// let out = device.apply_fn(&a, |x| x.mul(2.));
/// assert_eq!(&*out, &[2., 4., 6., 6., 4., 2.,]);
/// ```
#[track_caller]
fn apply_fn<F>(&self, buf: &Buffer<T, D, S>, f: impl Fn(Resolve<T>) -> F) -> Buffer<T, Self, S>
where
F: Eval<T> + MayToCLSource;
Expand Down
4 changes: 3 additions & 1 deletion src/module_comb/devices/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
impl_buffer_hook_traits, impl_retriever,
module_comb::{
Alloc, Base, Buffer, Cached, CachedModule, HasId, HasModules, MainMemory, Module,
OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, TapeActions,
OnDropBuffer, OnNewBuffer, Retrieve, Retriever, Setup, TapeActions, LazySetup,
},
Shape,
};
Expand Down Expand Up @@ -129,3 +129,5 @@ impl<Mods: TapeActions> TapeActions for CPU<Mods> {
self.modules.tape_mut()
}
}

impl<Mods> LazySetup for CPU<Mods> {}
46 changes: 37 additions & 9 deletions src/module_comb/devices/cpu/ops.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::{
module_comb::{
ApplyFunction, Buffer, Device, HasId, MainMemory, OnDropBuffer, Retrieve, Retriever,
TapeActions, WriteBuf,
AddOperation, ApplyFunction, Buffer, Device, HasId, MainMemory, OnDropBuffer, Retrieve,
Retriever, TapeActions, WriteBuf,
},
Shape,
Shape, ToVal,
};

use super::CPU;
Expand All @@ -20,8 +20,25 @@ impl<Mods: OnDropBuffer, T: Copy, D: MainMemory, S: Shape> WriteBuf<T, S, D> for
}
}

impl<Mods: Retrieve<Self> + TapeActions + 'static, T: 'static, S: Shape, D: MainMemory>
ApplyFunction<T, S, D> for CPU<Mods>
impl<Mods: AddOperation> AddOperation for CPU<Mods> {
#[inline]
fn add_operation(&self, operation: impl FnOnce()) {
self.modules.add_operation(operation)
}

#[inline]
fn call_lazily(&self) {
self.modules.call_lazily()
}

}

impl<Mods, T, S, D> ApplyFunction<T, S, D> for CPU<Mods>
where
Mods: Retrieve<Self> + TapeActions + AddOperation + 'static,
T: Copy + Default + ToVal + 'static,
S: Shape,
D: MainMemory + 'static,
{
// actually take &mut Buf instead of returning an owned Buf?
fn apply_fn<F>(
Expand All @@ -31,12 +48,23 @@ impl<Mods: Retrieve<Self> + TapeActions + 'static, T: 'static, S: Shape, D: Main
) -> Buffer<T, Self, S>
where
F: crate::Eval<T> + crate::MayToCLSource,
{
let out = self.retrieve(buf.len());
{
let mut out = self.retrieve(buf.len());
println!("out_ptr: {:?}", out.data.ptr);

let ids = (buf.id(), out.id());
self.add_grad_fn::<T, S, 2>(ids, move |grads| {
// let (lhs, lhs_grad, out_grad) = grads.get_double::<T, S, S>(ids);
self.add_grad_fn::<T, S>(move |grads| {
let (lhs, lhs_grad, out_grad) = grads.get_double::<T, S, S, D>(ids);
});

// self.apply_fn(buf, f)

self.add_operation(|| {
// let mut out = self.retrieve::<T, S>(buf.len());
// println!("out_ptr closure: {:?}", out.data.ptr);
for (x, out) in buf.iter().zip(out.iter_mut()) {
*out = f((*x).to_val()).eval();
}
});

out
Expand Down
2 changes: 2 additions & 0 deletions src/module_comb/devices/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod cpu;
pub use cpu::*;

#[cfg(feature = "cuda")]
mod cuda;
#[cfg(feature = "cuda")]
pub use cuda::*;

use super::{Alloc, OnDropBuffer};
Expand Down
15 changes: 10 additions & 5 deletions src/module_comb/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ pub trait TapeActions {
// use track caller to identify a specific grad function
//-> if backward is not called (.drain()), the grad fn vector will gradually fill up
#[track_caller]
fn add_grad_fn<T, S: Shape, const N: usize>(
fn add_grad_fn<T, S: Shape>(
&self,
ids: impl AllocGradsFrom<N>,
// ids: impl AllocGradsFrom<N>,
grad_fn: impl Fn(&mut Gradients) + 'static,
) where
T: 'static,
Self: Device + 'static,
{
if let Some(mut tape) = self.tape_mut() {
// the type T must match for every Id!
for id in ids.ids() {
tape.grads.grads_pool.add_buf_once::<T, Self, S>(self, id)
}
// for id in ids.ids() {
// tape.grads.grads_pool.add_buf_once::<T, Self, S>(self, id)
// }

tape.add_grad_fn(grad_fn)
}
Expand Down Expand Up @@ -100,3 +100,8 @@ impl<const N: usize> AllocGradsFrom<N> for [Id; N] {
self
}
}

pub trait AddOperation {
fn add_operation(&self, operation: impl FnOnce());
fn call_lazily(&self) {}
}
34 changes: 26 additions & 8 deletions src/module_comb/modules/autograd/gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,11 @@ impl Gradients {

/// Returns the forward [`Buffer`] x and the gradient `Buffer`s x_grad and out_grad.
/// Useful for unary operations.
///
#[inline]
pub fn get_double<'a, T, IS, OS, D>(
&mut self,
device: &'a D,
// device: &'a D,
(xid, oid): (Id, Id),
) -> (
&Buffer<'a, T, D, IS>,
Expand All @@ -154,9 +155,10 @@ impl Gradients {
OS: Shape,
D: Device + 'static,
{
self.grads_pool.add_buf_once::<T, _, IS>(device, oid);
// self.grads_pool.add_buf_once::<T, _, IS>(device, oid);

let x_grad_ptr = self.get_mut(device, xid) as *mut _;
// let x_grad_ptr = self.get_mut(device, xid) as *mut _;
let x_grad_ptr = self.may_get_mut(xid).unwrap() as *mut _;
let x_grad_mut = unsafe { &mut *x_grad_ptr };
let o_grad = self.may_get_ref(oid).unwrap();

Expand All @@ -166,14 +168,30 @@ impl Gradients {

#[cfg(test)]
mod tests {
use core::borrow::BorrowMut;
use crate::module_comb::{Autograd, Base, Buffer, HasId, Retriever, CPU};

#[test]
fn test_same_types_get_double_return() {
let device = CPU::<Autograd<Base>>::new();

// let mut gradients = Gradients::default();

let buf = Buffer::<i32, _>::new(&device, 10);
// unsafe { register_buf(&mut gradients.no_grads_pool.borrow_mut().cache, &buf) }

use crate::module_comb::{register_buf, Autograd, Base, Buffer, HasId, Retriever, CPU};
let out = device.retrieve::<i32, ()>(buf.len());
// unsafe { register_buf(&mut gradients.no_grads_pool.borrow_mut().cache, &out) }

use super::Gradients;
device
.modules
.tape
.borrow_mut()
.grads
.get_double::<i32, (), (), CPU<Autograd<crate::module_comb::CachedModule<Base, CPU<Autograd<Base>>>>>>((buf.id(), out.id()));
}

#[test]
//#[should_panic]
#[should_panic]
fn test_different_types_get_double_return() {
let device = CPU::<Autograd<Base>>::new();

Expand All @@ -190,6 +208,6 @@ mod tests {
.tape
.borrow_mut()
.grads
.get_double::<i32, (), (), _>(&device, (buf.id(), out.id()));
.get_double::<i32, (), (), CPU<Autograd<crate::module_comb::CachedModule<Base, CPU<Autograd<Base>>>>>>((buf.id(), out.id()));
}
}
9 changes: 8 additions & 1 deletion src/module_comb/modules/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
flag::AllocFlag,
module_comb::{Alloc, Device, Module, OnDropBuffer, OnNewBuffer, Retrieve, Setup, TapeActions},
module_comb::{Alloc, Device, Module, OnDropBuffer, OnNewBuffer, Retrieve, Setup, TapeActions, AddOperation},
Shape,
};

Expand All @@ -16,6 +16,13 @@ impl<D> Module<D> for Base {
}
}

impl AddOperation for Base {
#[inline]
fn add_operation(&self, mut operation: impl FnOnce()) {
operation();
}
}

impl<D> Setup<D> for Base {}

impl<T, D: Device, S: Shape> OnNewBuffer<T, D, S> for Base {}
Expand Down
79 changes: 71 additions & 8 deletions src/module_comb/modules/lazy.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
use core::marker::PhantomData;
use core::{marker::PhantomData, fmt::Debug, cell::RefCell};

use crate::{
module_comb::{Alloc, Buffer, Device, Module, OnDropBuffer, Retrieve, Setup},
module_comb::{Alloc, Buffer, Device, Module, OnDropBuffer, Retrieve, Setup, AddOperation, OnNewBuffer, TapeActions},
Shape,
};

#[derive(Debug, Default)]
#[derive(Default)]
pub struct Lazy<Mods> {
mods: Mods,
ops: RefCell<Vec<Box<dyn FnOnce() + 'static>>>
}

impl<Mods: Debug> Debug for Lazy<Mods> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Lazy").field("mods", &self.mods).field("ops_count", &self.ops.borrow().len()).finish()
}
}

pub trait LazySetup {
Expand All @@ -19,10 +26,29 @@ impl<Mods: Module<D>, D: LazySetup> Module<D> for Lazy<Mods> {

#[inline]
fn new() -> Self::Module {
Lazy { mods: Mods::new() }
Lazy { mods: Mods::new(), ops: Default::default() }
}
}

impl<Mods> AddOperation for Lazy<Mods> {
#[inline]
fn add_operation(&self, operation: impl FnOnce()) {
let operation: Box<dyn FnOnce()> = Box::new(operation);
let operation: Box<dyn FnOnce() + 'static> = unsafe {
std::mem::transmute(operation)
};
self.ops.borrow_mut().push(operation)
}

#[inline]
fn call_lazily(&self) {
for op in self.ops.borrow_mut().drain(..) {
op()
}
}

}

impl<D: LazySetup, Mods: Setup<D>> Setup<D> for Lazy<Mods> {
#[inline]
fn setup(device: &mut D) {
Expand All @@ -39,28 +65,49 @@ impl<Mods: OnDropBuffer> OnDropBuffer for Lazy<Mods> {
}
}

impl<Mods: OnDropBuffer, D> Retrieve<D> for Lazy<Mods> {
impl<T, D: Device, S: Shape, Mods: OnNewBuffer<T, D, S>> OnNewBuffer<T, D, S> for Lazy<Mods> {
#[inline]
fn on_new_buffer(&self, device: &D, new_buf: &Buffer<T, D, S>) {
self.mods.on_new_buffer(device, new_buf)
}
}

impl<Mods: TapeActions> TapeActions for Lazy<Mods> {
#[inline]
fn tape(&self) -> Option<core::cell::Ref<super::Tape>> {
self.mods.tape()
}

#[inline]
fn tape_mut(&self) -> Option<core::cell::RefMut<super::Tape>> {
self.mods.tape_mut()
}
}

impl<Mods: Retrieve<D>, D> Retrieve<D> for Lazy<Mods> {
#[inline]
fn retrieve<T, S: crate::Shape>(&self, device: &D, len: usize) -> <D>::Data<T, S>
where
T: 'static,
D: crate::module_comb::Alloc,
{
todo!()
self.mods.retrieve(device, len)
}

#[inline]
fn on_retrieve_finish<T, S: Shape>(&self, retrieved_buf: &Buffer<T, D, S>)
where
T: 'static,
D: Device,
{
// pass down
todo!()
self.mods.on_retrieve_finish(retrieved_buf)
}
}

#[cfg(test)]
mod tests {
use crate::module_comb::{Alloc, Base, CPU, CUDA};
use crate::{module_comb::{Alloc, Base, CPU, Buffer, AddOperation}, Combiner};

use super::Lazy;

Expand All @@ -75,4 +122,20 @@ mod tests {
// let device = CUDA::<Lazy<Base>>::new();
// let data = device.alloc::<f32, ()>(10, crate::flag::AllocFlag::None);
}

use crate::module_comb::ApplyFunction;

#[test]
fn test_lazy_execution() {
let device = CPU::<Base>::new();

let buf = Buffer::<f32, _>::new(&device, 10);
let out = device.apply_fn(&buf, |x| x.add(3.));

device.call_lazily();
println!("out: {:?}", &*out);

drop(out);
drop(buf);
}
}
3 changes: 2 additions & 1 deletion src/two_way_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ mod tests {

let a = f(4f32.to_val(), 3f32.to_val());

roughly_eq_slices(&[a.eval()], &[22.2]);
// roughly_eq_slices(&[a.eval()], &[22.2]);
assert_eq!(a.eval(), 22.2);

let r = f("x".to_marker(), "y".to_marker()).to_cl_source();
assert_eq!("(((x + y) * 3.6) - y)", r);
Expand Down

0 comments on commit 6a0a01c

Please sign in to comment.