Skip to content

Commit

Permalink
Rename get to get_mut (Cached)
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 15, 2024
1 parent 9ebd957 commit cfade2f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
5 changes: 2 additions & 3 deletions src/cache/locking/locked_map.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use core::ops::Deref;
use std::{
cell::{Ref, RefCell, RefMut, UnsafeCell},
cell::{Ref, RefCell, RefMut},
collections::HashMap,
hash::{BuildHasher, Hash, RandomState},
};
Expand Down Expand Up @@ -32,7 +31,7 @@ impl<K, T, S: BuildHasher> LockedMap<K, T, S> {
pub fn len(&self) -> usize {
self.data.borrow().len()
}

#[inline]
pub fn is_empty(&self) -> bool {
self.data.borrow().is_empty()
Expand Down
19 changes: 15 additions & 4 deletions src/cache/owned_cache.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
mod fast_cache;
pub use fast_cache::*;
// mod fast_cache;
mod fast_cache2;
use core::cell::RefMut;

mod length_cache;
pub use length_cache::*;
// pub use fast_cache::*;
pub use fast_cache2::*;

// mod length_cache;
// pub use length_cache::*;

use crate::{Alloc, ShallowCopy, Shape, UniqueId, Unit};

use super::State;

pub trait Cache2<T> {
fn get_mut(&self, id: UniqueId, len: usize) -> State<RefMut<T>>;
fn insert(&self, id: UniqueId, len: usize, data: T);
}

pub trait Cache {
unsafe fn get<T, S, D>(
&mut self,
Expand Down
40 changes: 19 additions & 21 deletions src/modules/cached.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use core::{
cell::{Cell, RefCell, RefMut},
marker::PhantomData,
any::Any, cell::{Cell, RefCell, RefMut}, marker::PhantomData
};
use std::sync::Arc;

use crate::{
AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, CachedBuffers, CowMut, Cursor, Device,
ExecNow, FastCache, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module,
OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule,
SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData,
AddGradFn, AddLayer, AddOperation, Alloc, Buffer, Cache, Cache2, CachedBuffers, CowMut, Cursor, Device, ExecNow, FastCache2, Guard, HasId, HasModules, IsBasePtr, IsShapeIndep, LockInfo, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, RemoveLayer, ReplaceBuf, Retrieve, RunModule, SetOpHint, Setup, ShallowCopy, Shape, State, UniqueId, Unit, WrappedData
};

#[cfg(feature = "graph")]
Expand All @@ -16,7 +13,7 @@ use crate::{DeviceError, Optimize};
// creator struct, however =>
// TODO: could remove D generic and therefore CachedModule
#[derive(Debug, PartialEq, Eq, Default)]
pub struct Cached<Mods, CacheType = FastCache> {
pub struct Cached<Mods, CacheType = FastCache2> {
pd: PhantomData<Mods>,
cache_type: PhantomData<CacheType>,
}
Expand Down Expand Up @@ -53,8 +50,7 @@ where
fn new() -> Self::Module {
CachedModule {
modules: Mods::new(),
cache: RefCell::new(CacheType::default()),
cache3: Default::default(),
cache: CacheType::default(),
pd: PhantomData,
cursor: Default::default(),
}
Expand All @@ -64,10 +60,9 @@ where
// impl<Mods> OnDropBuffer for Cached<Mods> {}

// TODO: could remove D generic and therefore CachedModule
pub struct CachedModule<Mods, D: Device, CacheType = FastCache> {
pub struct CachedModule<Mods, D: Device, CacheType = FastCache2> {
pub modules: Mods,
pub cache: RefCell<CacheType>,
pub cache3: crate::LockedMap<u64, Box<dyn core::any::Any>>,
pub cache: CacheType,
pub(crate) pd: PhantomData<D>,
cursor: Cell<usize>, // would move this to `Cache`, however -> RefCell; TODO: maybe add a Cursor Module
}
Expand Down Expand Up @@ -155,18 +150,20 @@ impl<CacheType, Mods: OnDropBuffer, SD: Device> OnDropBuffer for CachedModule<Mo
impl<'a, CacheType, Mods, SimpleDevice> CachedModule<Mods, SimpleDevice, CacheType>
where
Mods: WrappedData,
CacheType: Cache2<Box<dyn Any>>,
SimpleDevice: Device,
{
pub fn get<D, T, S>(
pub fn get_mut<D, T, S>(
&'a self,
id: u64,
len: usize,
) -> State<Guard<'a, Mods::Wrap<'static, T, D::Base<T, S>>>>
where
D: Device,
T: 'static,
S: Shape,
{
let entry = self.cache3.get_mut(&id)?;
let entry = self.cache.get_mut(id, len)?;
let entry = RefMut::map(entry, |x| {
x.downcast_mut::<Mods::Wrap<'static, T, D::Base<T, S>>>()
.unwrap()
Expand All @@ -184,7 +181,7 @@ where
D: Device + IsShapeIndep + Cursor,
D::Base<T, S>: 'static,
SimpleDevice: Device,
CacheType: Cache,
CacheType: Cache2<Box<dyn Any>>,
{
#[inline]
unsafe fn retrieve_entry<const NUM_PARENTS: usize>(
Expand All @@ -196,15 +193,17 @@ where
where
D: Alloc<T>,
{
dbg!("retrieve entry");
let id = device.cursor() as UniqueId;
match self.get::<D, T, S>(id) {
match self.get_mut::<D, T, S>(id, len) {
Ok(out) => Ok(out),
Err(state) => match state {
LockInfo::Locked => panic!("Locked!!"),
LockInfo::None => {
self.cache3
.insert(id, Box::new(self.modules.retrieve(device, len, _parents)));
Ok(self.get::<D, T, S>(id).unwrap())
dbg!("insert entry");
self.cache
.insert(id, len, Box::new(self.modules.retrieve(device, len, _parents)?));
Ok(self.get_mut::<D, T, S>(id, len).unwrap())
}
},
}
Expand Down Expand Up @@ -326,7 +325,6 @@ impl<CacheType, CurrentMods, SD: Device> AddLayer<CurrentMods, SD> for Cached<()
crate::CachedModule {
modules: inner_mods,
cache: Default::default(),
cache3: Default::default(),
pd: core::marker::PhantomData,
cursor: Default::default(),
}
Expand Down Expand Up @@ -580,7 +578,7 @@ mod tests {
assert_eq!(buf.len(), buf_base.len());
}

let buf_base = buf_base.to_device_type(&device);
let buf_base: Buffer<f32, _> = buf_base.to_device_type(&device);
for _ in 0..10 {
let buf: Buffer<f32, _> = device.retrieve(10, &buf_base).unwrap();

Expand Down

0 comments on commit cfade2f

Please sign in to comment.