Skip to content

Commit

Permalink
Add Drop for autograd wrapper, &mut for on new buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 17, 2024
1 parent 07b9e71 commit e72eb72
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
D: OnNewBuffer<'a, T, D, S>,
{
let data = device.base_to_data(base);
let buf = Buffer {
let mut buf = Buffer {
data,
device: Some(device),
};

// mind: on_new_buffer must be called for user buffers!
unsafe { device.on_new_buffer(device, &buf) };
unsafe { device.on_new_buffer(device, &mut buf) };
buf
}

Expand All @@ -110,7 +110,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
}
let mut buf = self;
buf.set_requires_grad(require_grad);
unsafe { buf.device().on_new_buffer(buf.device(), &buf) };
unsafe { buf.device().on_new_buffer(buf.device(), &mut buf) };
buf
}

Expand Down
2 changes: 1 addition & 1 deletion src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ macro_rules! impl_buffer_hook_traits {
Self: 'dev,
{
#[inline]
unsafe fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) {
unsafe fn on_new_buffer(&'dev self, device: &'dev D, new_buf: &mut Buffer<'dev, T, D, S>) {
unsafe { self.modules.on_new_buffer(device, new_buf) }
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ pub trait OnDropBuffer: WrappedData {

pub trait OnNewBuffer<'dev, T: Unit, D: Device, S: Shape = ()> {
#[track_caller]
unsafe fn on_new_buffer<'s>(&'s self, _device: &'dev D, _new_buf: &'s Buffer<'dev, T, D, S>) {}
unsafe fn on_new_buffer<'s>(&'dev self, _device: &'dev D, _new_buf: &'s mut Buffer<'dev, T, D, S>) {}
}
2 changes: 1 addition & 1 deletion src/modules/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
Mods: OnNewBuffer<'dev, T, D, S> + CachedBuffers,
{
#[inline]
unsafe fn on_new_buffer(&self, device: &'dev D, new_buf: &Buffer<'dev, T, D, S>) {
unsafe fn on_new_buffer(&'dev self, device: &'dev D, new_buf: &mut Buffer<'dev, T, D, S>) {
// let mut no_grads = self.no_grads_pool.borrow_mut();
// let wrapped_data = unsafe { new_buf.data.shallow() };

Expand Down
24 changes: 17 additions & 7 deletions src/modules/autograd/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use crate::{
};

// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct ReqGradWrapper<'a, Data, T> {
pub struct ReqGradWrapper<'a, Data: HasId, T> {
pub requires_grad: bool,
pub data: Data,
pub remove_id_cb: Option<Box<dyn Fn(UniqueId) + 'a>>,
pub _pd: PhantomData<&'a T>,
}

impl<'a, Data: Debug, T> Debug for ReqGradWrapper<'a, Data, T> {
impl<'a, Data: HasId + Debug, T> Debug for ReqGradWrapper<'a, Data, T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("ReqGradWrapper")
.field("requires_grad", &self.requires_grad)
Expand All @@ -23,6 +23,15 @@ impl<'a, Data: Debug, T> Debug for ReqGradWrapper<'a, Data, T> {
}
}

impl<'a, Data: HasId, T> Drop for ReqGradWrapper<'a, Data, T> {
#[inline]
fn drop(&mut self) {
if let Some(remove_id_cb) = &self.remove_id_cb {
remove_id_cb(*self.id())
}
}
}

impl<'dev, Mods: WrappedData> WrappedData for Autograd<'dev, Mods> {
type Wrap<'a, T: Unit, Base: IsBasePtr> = ReqGradWrapper<'a, Mods::Wrap<'a, T, Base>, T>;

Expand Down Expand Up @@ -82,7 +91,7 @@ impl<'a, Data: HasId, T> HasId for ReqGradWrapper<'a, Data, T> {
}
}

impl<'a, Data: PtrType, T: Unit> PtrType for ReqGradWrapper<'a, Data, T> {
impl<'a, Data: HasId + PtrType, T: Unit> PtrType for ReqGradWrapper<'a, Data, T> {
#[inline]
fn size(&self) -> usize {
self.data.size()
Expand All @@ -101,7 +110,7 @@ impl<'a, Data: PtrType, T: Unit> PtrType for ReqGradWrapper<'a, Data, T> {

impl<'a, Data, T> ShallowCopy for ReqGradWrapper<'a, Data, T>
where
Data: ShallowCopy,
Data: ShallowCopy + HasId,
{
unsafe fn shallow(&self) -> Self {
ReqGradWrapper {
Expand All @@ -113,16 +122,17 @@ where
}
}

impl<'a, T: Unit, S: Shape, Data: ToBase<T, D, S>, T1, D: Device> ToBase<T, D, S>
impl<'a, T: Unit, S: Shape, Data: ToBase<T, D, S> + HasId, T1, D: Device> ToBase<T, D, S>
for ReqGradWrapper<'a, Data, T1>
{
#[inline]
fn to_base(self) -> <D as Device>::Base<T, S> {
self.data.to_base()
todo!()
// self.data.to_base()
}
}

impl<'a, T, Data> ToDim for ReqGradWrapper<'a, Data, T> {
impl<'a, T, Data: HasId> ToDim for ReqGradWrapper<'a, Data, T> {
type Out = Self;

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ where
S: Shape,
{
#[inline]
unsafe fn on_new_buffer(&self, device: &'a D, new_buf: &Buffer<'a, T, D, S>) {
unsafe fn on_new_buffer(&'a self, device: &'a D, new_buf: &mut Buffer<'a, T, D, S>) {
self.modules.on_new_buffer(device, new_buf)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/modules/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ where
S: Shape,
{
#[inline]
unsafe fn on_new_buffer<'s>(&'s self, device: &'a D, new_buf: &'s Buffer<'a, T, D, S>) {
unsafe fn on_new_buffer<'s>(&'a self, device: &'a D, new_buf: &'s mut Buffer<'a, T, D, S>) {
unsafe { register_buf_copyable(&mut self.buffers.borrow_mut(), new_buf) };
self.modules.on_new_buffer(device, new_buf)
}
Expand Down

0 comments on commit e72eb72

Please sign in to comment.