diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs index 4eb62c97..086eb4e4 100644 --- a/dfdx-core/src/tensor/webgpu/device.rs +++ b/dfdx-core/src/tensor/webgpu/device.rs @@ -13,18 +13,13 @@ use crate::{ }; #[cfg(feature = "no-std")] -use spin::Mutex; +use spin::{Mutex, RwLock}; use core::any::TypeId; #[cfg(not(feature = "no-std"))] -use std::sync::Mutex; +use std::sync::{Mutex, RwLock}; -use std::{ - collections::HashMap, - marker::PhantomData, - sync::{Arc, RwLock}, - vec::Vec, -}; +use std::{collections::HashMap, marker::PhantomData, sync::Arc, vec::Vec}; use super::allocate::round_to_buffer_alignment; @@ -209,22 +204,37 @@ impl Webgpu { Ok(data) } + #[cfg(not(feature = "no-std"))] pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool { self.cs_cache.read().unwrap().contains_key(&name) } + #[cfg(feature = "no-std")] + pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool { + self.cs_cache.read().contains_key(&name) + } + pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) { let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor { label: None, source: wgpu::ShaderSource::Wgsl(source.into()), })); + #[cfg(not(feature = "no-std"))] self.cs_cache.write().unwrap().insert(name, module); + #[cfg(feature = "no-std")] + self.cs_cache.write().insert(name, module); } + #[cfg(not(feature = "no-std"))] pub(crate) fn get_shader_module(&self, name: TypeId) -> Option> { self.cs_cache.read().unwrap().get(&name).cloned() } + #[cfg(feature = "no-std")] + pub(crate) fn get_shader_module(&self, name: TypeId) -> Option> { + self.cs_cache.read().get(&name).cloned() + } + // #[allow(unused)] // pub(crate) unsafe fn get_workspace(&self, len: usize) -> Result, Error> { // let num_bytes_required = len * std::mem::size_of::(); @@ -369,7 +379,7 @@ impl Storage for Webgpu { type Vec = CachableBuffer; fn try_alloc_len(&self, len: usize) -> Result { - let data = unsafe { self.alloc_empty::(len) }?; + let data = self.alloc_empty::(len)?; Ok(CachableBuffer { dev: self.dev.clone(), queue: self.queue.clone(),