Skip to content

Commit

Permalink
Add remove in cb
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 17, 2024
1 parent 2dca3c2 commit 07b9e71
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
[features]
# default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]

default = ["cpu", "cached"]
default = ["cpu", "cached", "autograd"]
# default = ["no-std"]
# default = ["opencl"]
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]
Expand Down
2 changes: 1 addition & 1 deletion src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ where
Ok(ReqGradWrapper {
requires_grad,
data,
remove_id_cb: &|id| {},
remove_id_cb: None,
_pd: core::marker::PhantomData,
})
}
Expand Down
24 changes: 17 additions & 7 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
pub struct ReqGradWrapper<'a, Data, T> {
pub requires_grad: bool,
pub data: Data,
pub remove_id_cb: Option<&'a dyn Fn(UniqueId)>,
pub remove_id_cb: Option<Box<dyn Fn(UniqueId) + 'a>>,
pub _pd: PhantomData<&'a T>,
}

Expand All @@ -27,15 +27,25 @@ impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<'a, Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
ReqGradWrapper {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
requires_grad: true,
data: self.modules.wrap_in_base(base),
remove_id_cb: &|id| {

// unsafe { &mut (*self.grads.get()).no_grads_pool }.remove(&id);
},
remove_id_cb: Some(Box::new(|id| {
unsafe { &mut (*self.grads.get()).no_grads_pool }.remove(&id);
})),
_pd: PhantomData,
}
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
ReqGradWrapper {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
requires_grad: true,
data: self.modules.wrap_in_base_unbound(base),
remove_id_cb: None,
_pd: PhantomData,
}
}
Expand Down Expand Up @@ -97,7 +107,7 @@ where
ReqGradWrapper {
requires_grad: self.requires_grad,
data: self.data.shallow(),
remove_id_cb: self.remove_id_cb,
remove_id_cb: None,
_pd: PhantomData,
}
}
Expand Down
10 changes: 9 additions & 1 deletion src/modules/lazy/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@ impl<T2, Mods: WrappedData> WrappedData for Lazy<'_, Mods, T2> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = LazyWrapper<Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
LazyWrapper {
maybe_data: MaybeData::Data(self.modules.wrap_in_base(base)),
_pd: PhantomData,
}
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(&self, base: Base) -> Self::Wrap<'a, T, Base> {
LazyWrapper {
maybe_data: MaybeData::Data(self.modules.wrap_in_base_unbound(base)),
_pd: PhantomData,
}
}

#[inline]
fn wrapped_as_base<'a, 'b, T: Unit, Base: IsBasePtr>(
Expand Down

0 comments on commit 07b9e71

Please sign in to comment.