From 07b9e71572f792c882c381ae815d426476599795 Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Mon, 18 Nov 2024 00:50:19 +0100 Subject: [PATCH] Add remove in cb --- Cargo.toml | 2 +- src/modules/autograd.rs | 2 +- src/modules/autograd/wrapper.rs | 24 +++++++++++++++++------- src/modules/lazy/wrapper.rs | 10 +++++++++- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 08373c4f..1a947d98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } [features] # default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"] -default = ["cpu", "cached"] +default = ["cpu", "cached", "autograd"] # default = ["no-std"] # default = ["opencl"] # default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"] diff --git a/src/modules/autograd.rs b/src/modules/autograd.rs index 22f96b73..2ec65c5d 100644 --- a/src/modules/autograd.rs +++ b/src/modules/autograd.rs @@ -154,7 +154,7 @@ where Ok(ReqGradWrapper { requires_grad, data, - remove_id_cb: &|id| {}, + remove_id_cb: None, _pd: core::marker::PhantomData, }) } diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index 1a20f41a..a84ba226 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -8,7 +8,7 @@ use crate::{ pub struct ReqGradWrapper<'a, Data, T> { pub requires_grad: bool, pub data: Data, - pub remove_id_cb: Option<&'a dyn Fn(UniqueId)>, + pub remove_id_cb: Option>, pub _pd: PhantomData<&'a T>, } @@ -27,15 +27,25 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> { type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<'a, Mods::Wrap<'a, T, Base>, T>; #[inline] - fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> { ReqGradWrapper { // by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change requires_grad: true, data: self.modules.wrap_in_base(base), - remove_id_cb: &|id| { - - // unsafe { &mut (*self.grads.get()).no_grads_pool }.remove(&id); - }, + remove_id_cb: Some(Box::new(|id| { + unsafe { &mut (*self.grads.get()).no_grads_pool }.remove(&id); + })), + _pd: PhantomData, + } + } + + #[inline] + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + ReqGradWrapper { + // by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change + requires_grad: true, + data: self.modules.wrap_in_base_unbound(base), + remove_id_cb: None, _pd: PhantomData, } } @@ -97,7 +107,7 @@ where ReqGradWrapper { requires_grad: self.requires_grad, data: self.data.shallow(), - remove_id_cb: self.remove_id_cb, + remove_id_cb: None, _pd: PhantomData, } } diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 5bbd81f5..969b6ded 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -21,12 +21,20 @@ impl WrappedData for Lazy<'_, Mods, T2> { type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper, T>; #[inline] - fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> { LazyWrapper { maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)), _pd: PhantomData, } } + + #[inline] + fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> { + LazyWrapper { + maybe_data: MaybeData::Data(self.modules.wrap_in_base_unbound(base)), + _pd: PhantomData, + } + } #[inline] fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(