Skip to content

Commit

Permalink
no-std
Browse files Browse the repository at this point in the history
  • Loading branch information
favilo committed Dec 3, 2023
1 parent 21c9f62 commit 5439341
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions dfdx-core/src/tensor/webgpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,13 @@ use crate::{
};

#[cfg(feature = "no-std")]
use spin::Mutex;
use spin::{Mutex, RwLock};

use core::any::TypeId;
#[cfg(not(feature = "no-std"))]
use std::sync::Mutex;
use std::sync::{Mutex, RwLock};

use std::{
collections::HashMap,
marker::PhantomData,
sync::{Arc, RwLock},
vec::Vec,
};
use std::{collections::HashMap, marker::PhantomData, sync::Arc, vec::Vec};

use super::allocate::round_to_buffer_alignment;

Expand Down Expand Up @@ -209,22 +204,37 @@ impl Webgpu {
Ok(data)
}

#[cfg(not(feature = "no-std"))]
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().unwrap().contains_key(&name)
}

#[cfg(feature = "no-std")]
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().contains_key(&name)
}

pub(crate) fn load_shader_module(&self, name: TypeId, source: &str) {
let module = Arc::new(self.dev.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(source.into()),
}));
#[cfg(not(feature = "no-std"))]
self.cs_cache.write().unwrap().insert(name, module);
#[cfg(feature = "no-std")]
self.cs_cache.write().insert(name, module);
}

#[cfg(not(feature = "no-std"))]
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().unwrap().get(&name).cloned()
}

#[cfg(feature = "no-std")]
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().get(&name).cloned()
}

// #[allow(unused)]
// pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
// let num_bytes_required = len * std::mem::size_of::<E>();
Expand Down Expand Up @@ -369,7 +379,7 @@ impl<E: Unit> Storage<E> for Webgpu {
type Vec = CachableBuffer<E>;

fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Error> {
let data = unsafe { self.alloc_empty::<E>(len) }?;
let data = self.alloc_empty::<E>(len)?;
Ok(CachableBuffer {
dev: self.dev.clone(),
queue: self.queue.clone(),
Expand Down

0 comments on commit 5439341

Please sign in to comment.