diff --git a/src/cache/locking/locked_map.rs b/src/cache/locking/locked_map.rs index 395e80ab..ea81f07e 100644 --- a/src/cache/locking/locked_map.rs +++ b/src/cache/locking/locked_map.rs @@ -1,6 +1,5 @@ -use core::ops::Deref; use std::{ - cell::{Ref, RefCell, RefMut, UnsafeCell}, + cell::{Ref, RefCell, RefMut}, collections::HashMap, hash::{BuildHasher, Hash, RandomState}, }; @@ -32,7 +31,7 @@ impl LockedMap { pub fn len(&self) -> usize { self.data.borrow().len() } - + #[inline] pub fn is_empty(&self) -> bool { self.data.borrow().is_empty() diff --git a/src/cache/owned_cache.rs b/src/cache/owned_cache.rs index ead8c03e..520fb8cc 100644 --- a/src/cache/owned_cache.rs +++ b/src/cache/owned_cache.rs @@ -1,11 +1,22 @@ -mod fast_cache; -pub use fast_cache::*; +// mod fast_cache; +mod fast_cache2; +use core::cell::RefMut; -mod length_cache; -pub use length_cache::*; +// pub use fast_cache::*; +pub use fast_cache2::*; + +// mod length_cache; +// pub use length_cache::*; use crate::{Alloc, ShallowCopy, Shape, UniqueId, Unit}; +use super::State; + +pub trait Cache2 { + fn get_mut(&self, id: UniqueId, len: usize) -> State>; + fn insert(&self, id: UniqueId, len: usize, data: T); +} + pub trait Cache { unsafe fn get( &mut self, diff --git a/src/modules/cached.rs b/src/modules/cached.rs index 7cb94095..4b956459 100644 --- a/src/modules/cached.rs +++ b/src/modules/cached.rs @@ -1,13 +1,10 @@ use core::{ - cell::{Cell, RefCell, RefMut}, - marker::PhantomData, + any::Any, cell::{Cell, RefCell, RefMut}, marker::PhantomData }; +use std::sync::Arc; use crate::{ - AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device, - ExecNow, FastCache, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, - OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, - SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData, + AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, Cache2, CachedBuffers, CowMut, Cursor, Device, ExecNow, FastCache2, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData }; #[cfg(feature = "graph")] @@ -16,7 +13,7 @@ use crate::{DeviceError, Optimize}; // creator struct, however => // TODO: could remove D generic and therefore CachedModule #[derive(Debug, PartialEq, Eq, Default)] -pub struct Cached { +pub struct Cached { pd: PhantomData, cache_type: PhantomData, } @@ -53,8 +50,7 @@ where fn new() -> Self::Module { CachedModule { modules: Mods::new(), - cache: RefCell::new(CacheType::default()), - cache3: Default::default(), + cache: CacheType::default(), pd: PhantomData, cursor: Default::default(), } @@ -64,10 +60,9 @@ where // impl OnDropBuffer for Cached {} // TODO: could remove D generic and therefore CachedModule -pub struct CachedModule { +pub struct CachedModule { pub modules: Mods, - pub cache: RefCell, - pub cache3: crate::LockedMap>, + pub cache: CacheType, pub(crate) pd: PhantomData, cursor: Cell, // would move this to `Cache`, however -> RefCell; TODO: maybe add a Cursor Module } @@ -155,18 +150,20 @@ impl OnDropBuffer for CachedModule CachedModule where Mods: WrappedData, + CacheType: Cache2>, SimpleDevice: Device, { - pub fn get( + pub fn get_mut( &'a self, id: u64, + len: usize, ) -> State>>> where D: Device, T: 'static, S: Shape, { - let entry = self.cache3.get_mut(&id)?; + let entry = self.cache.get_mut(id, len)?; let entry = RefMut::map(entry, |x| { x.downcast_mut::>>() .unwrap() @@ -184,7 +181,7 @@ where D: Device + IsShapeIndep + Cursor, D::Base: 'static, SimpleDevice: Device, - CacheType: Cache, + CacheType: Cache2>, { #[inline] unsafe fn retrieve_entry( @@ -196,15 +193,17 @@ where where D: Alloc, { + dbg!("retrieve entry"); let id = device.cursor() as UniqueId; - match self.get::(id) { + match self.get_mut::(id, len) { Ok(out) => Ok(out), Err(state) => match state { LockInfo::Locked => panic!("Locked!!"), LockInfo::None => { - self.cache3 - .insert(id, Box::new(self.modules.retrieve(device, len, _parents))); - Ok(self.get::(id).unwrap()) + dbg!("insert entry"); + self.cache + .insert(id, len, Box::new(self.modules.retrieve(device, len, _parents)?)); + Ok(self.get_mut::(id, len).unwrap()) } }, } @@ -326,7 +325,6 @@ impl AddLayer for Cached<() crate::CachedModule { modules: inner_mods, cache: Default::default(), - cache3: Default::default(), pd: core::marker::PhantomData, cursor: Default::default(), } @@ -580,7 +578,7 @@ mod tests { assert_eq!(buf.len(), buf_base.len()); } - let buf_base = buf_base.to_device_type(&device); + let buf_base: Buffer = buf_base.to_device_type(&device); for _ in 0..10 { let buf: Buffer = device.retrieve(10, &buf_base).unwrap();