diff --git a/src/cache/owned_cache/fast_cache.rs b/src/cache/owned_cache/fast_cache.rs index 798b09a4..c9eb6321 100644 --- a/src/cache/owned_cache/fast_cache.rs +++ b/src/cache/owned_cache/fast_cache.rs @@ -183,4 +183,25 @@ mod tests { } assert_eq!(dev.cursor(), 2); } + + #[cfg(feature = "cpu")] + #[test] + fn test_cache_with_cached_call() { + use crate::{Base, Buffer, Cursor, Retriever}; + + let dev = CPU::>::new(); + + let mut _buf: Option> = None; + _buf = dev.retrieve(10, ()).ok(); + + for _ in 0..10 { + dev.cached(|| { + _buf = dev.retrieve(10, ()).ok(); + _buf = dev.retrieve(10, ()).ok(); + _buf = dev.retrieve(10, ()).ok(); + let nodes = &dev.modules.cache.borrow().nodes; + assert_eq!(nodes.len(), 4); + }); + } + } } diff --git a/src/features.rs b/src/features.rs index 41b7743a..949fce32 100644 --- a/src/features.rs +++ b/src/features.rs @@ -78,6 +78,16 @@ pub trait Cursor { device: self, } } + + #[inline] + fn cached(&self, mut cb: impl FnMut()) + where + Self: Sized, + { + let mut range = self.range(1).into_iter(); + cb(); + range.next(); + } } #[macro_export] diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 836c688c..c774ae62 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -7,7 +7,7 @@ use crate::{ AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, Cursor, Device, ExecNow, FastCache, HasId, HasModules, IsShapeIndep, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, - Shape, Unit, WrappedData, UniqueId + Shape, UniqueId, Unit, WrappedData, }; #[cfg(feature = "graph")] diff --git a/src/range.rs b/src/range.rs index a2a33c7c..1f3337bf 100644 --- a/src/range.rs +++ b/src/range.rs @@ -25,10 +25,11 @@ impl<'a, D: Cursor> Iterator for CursorRangeIter<'a, D> { if self.range.start >= self.range.end { return None; } + let epoch = self.range.start; + unsafe { self.range.device.set_cursor(self.previous_cursor); } - let epoch = self.range.start; self.range.start += 1; Some(epoch) }