Skip to content

Commit

Permalink
Add Cursor::cached fn
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Sep 25, 2024
1 parent 340fc1b commit 0d7bc56
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
21 changes: 21 additions & 0 deletions src/cache/owned_cache/fast_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,25 @@ mod tests {
}
assert_eq!(dev.cursor(), 2);
}

#[cfg(feature = "cpu")]
#[test]
fn test_cache_with_cached_call() {
use crate::{Base, Buffer, Cursor, Retriever};

let dev = CPU::<Cached<Base>>::new();

let mut _buf: Option<Buffer<f32, _>> = None;
_buf = dev.retrieve(10, ()).ok();

for _ in 0..10 {
dev.cached(|| {
_buf = dev.retrieve(10, ()).ok();
_buf = dev.retrieve(10, ()).ok();
_buf = dev.retrieve(10, ()).ok();
let nodes = &dev.modules.cache.borrow().nodes;
assert_eq!(nodes.len(), 4);
});
}
}
}
10 changes: 10 additions & 0 deletions src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ pub trait Cursor {
device: self,
}
}

#[inline]
fn cached(&self, mut cb: impl FnMut())
where
Self: Sized,
{
let mut range = self.range(1).into_iter();
cb();
range.next();
}
}

#[macro_export]
Expand Down
2 changes: 1 addition & 1 deletion src/modules/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, Cursor, Device,
ExecNow, FastCache, HasId, HasModules, IsShapeIndep, Module, OnDropBuffer, OnNewBuffer,
Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy,
Shape, Unit, WrappedData, UniqueId
Shape, UniqueId, Unit, WrappedData,
};

#[cfg(feature = "graph")]
Expand Down
3 changes: 2 additions & 1 deletion src/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ impl<'a, D: Cursor> Iterator for CursorRangeIter<'a, D> {
if self.range.start >= self.range.end {
return None;
}
let epoch = self.range.start;

unsafe {
self.range.device.set_cursor(self.previous_cursor);
}
let epoch = self.range.start;
self.range.start += 1;
Some(epoch)
}
Expand Down

0 comments on commit 0d7bc56

Please sign in to comment.