Skip to content

Commit

Permalink
Add new fn to ReqGradWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 18, 2024
1 parent af91dba commit 1072c95
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,39 @@ impl<'a, Data: HasId, T> Drop for ReqGradWrapper<'a, Data, T> {
}
}

impl<'a, Data: HasId, T> ReqGradWrapper<'a, Data, T> {
#[inline]
pub fn new(data: Data, remove_id_cb: Option<Box<dyn Fn(UniqueId) + 'a>>) -> Self {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
ReqGradWrapper {
requires_grad: true,
data,
remove_id_cb,
_pd: PhantomData,
}
}
}

impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<'a, Mods::Wrap<'a, T, Base>, T>;

#[inline]
fn wrap_in_base<'a, T: Unit, Base: IsBasePtr>(&'a self, base: Base) -> Self::Wrap<'a, T, Base> {
ReqGradWrapper {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
requires_grad: true,
data: self.modules.wrap_in_base(base),
remove_id_cb: Some(Box::new(|id| {
ReqGradWrapper::new(
self.modules.wrap_in_base(base),
Some(Box::new(|id| {
unsafe { (*self.grads.get()).buf_requires_grad.remove(&id) };
unsafe { (*self.grads.get()).no_grads_pool.remove(&id) };
})),
_pd: PhantomData,
}
)
}

#[inline]
fn wrap_in_base_unbound<'a, T: Unit, Base: IsBasePtr>(
&self,
base: Base,
) -> Self::Wrap<'a, T, Base> {
ReqGradWrapper {
// by default: true -> if lazy layer is (accidentally) put before autograd, all gradients will be computed instead of none.. subject to change
requires_grad: true,
data: self.modules.wrap_in_base_unbound(base),
remove_id_cb: None,
_pd: PhantomData,
}
ReqGradWrapper::new(self.modules.wrap_in_base_unbound(base), None)
}

#[inline]
Expand Down

0 comments on commit 1072c95

Please sign in to comment.