Implement AbsKernelOp for WebGPU backend #896

Merged: 16 commits, Jan 3, 2024
4 changes: 2 additions & 2 deletions .github/workflows/cargo-check-features.yml
@@ -11,9 +11,9 @@ jobs:
matrix:
config:
- toolchain: stable
- command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cuda,cudnn
+ command: cargo hack check --feature-powerset --no-dev-deps --depth 2 --skip default,nightly,cuda,cudnn,webgpu
- toolchain: nightly
- command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cuda,cudnn
+ command: cargo hack check --each-feature --no-dev-deps --features nightly --skip default,cuda,cudnn,webgpu

steps:
- uses: actions/checkout@v2
12 changes: 10 additions & 2 deletions dfdx-core/Cargo.toml
@@ -38,7 +38,8 @@ half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_dis
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
- wgpu = { version = "0.18.0", optional = true }
+ wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
+ naga = { version = "0.14.1", optional = true }
futures-lite = { version = "2.0.1", optional = true }
thingbuf = { version = "0.1.4", optional = true }

@@ -62,7 +63,14 @@ fast-alloc = ["std"]

cuda = ["dep:cudarc", "dep:glob"]
cudnn = ["cuda", "cudarc?/cudnn"]
webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"]
webgpu = [
"dep:wgpu",
"dep:futures-lite",
"dep:thingbuf",
"dep:naga",
"dep:glob",
"wgpu/expose-ids",
]

f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"]

55 changes: 55 additions & 0 deletions dfdx-core/build.rs
@@ -9,6 +9,9 @@ fn main() {

#[cfg(feature = "cuda")]
cuda::build_ptx();

#[cfg(feature = "webgpu")]
webgpu::build_spv();
}

fn maybe_enable_nightly() {
@@ -210,3 +213,55 @@ mod cuda {
}
}
}

#[cfg(feature = "webgpu")]
mod webgpu {
    pub fn build_spv() {
        let out_dir = std::env::var("OUT_DIR").unwrap();
        let kernel_paths: Vec<std::path::PathBuf> = glob::glob("src/**/*.glsl")
            .unwrap()
            .map(|p| p.unwrap())
            .collect();
        for path in &kernel_paths {
            println!("cargo:rerun-if-changed={}", path.display());
        }

        let children = kernel_paths
            .iter()
            .map(|p| {
                // compile each kernel once per supported element type
                ["float", "double"].iter().map(|ty| {
                    let out_path: std::path::PathBuf = out_dir.clone().into();
                    let base = p.file_stem().unwrap();
                    let new_name = format!("{}.{ty}.spv", base.to_str().unwrap());
                    let out_file = &out_path.join(new_name);
                    std::process::Command::new("glslc")
                        .args(["-std=460core"])
                        .args(["-fshader-stage=compute"])
                        .args([format!("-DTYPENAME={ty}")])
                        .args(["-o", &out_file.as_os_str().to_str().unwrap()])
                        .arg(p)
                        .stdout(std::process::Stdio::piped())
                        .stderr(std::process::Stdio::piped())
                        .spawn()
                        .expect("glslc failed to start. Ensure that you have shaderc installed and that `glslc` is in your PATH.")
                }).collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        for (kernel_path, childs) in kernel_paths.iter().zip(children.into_iter()) {
            for child in childs {
                let output = child.wait_with_output().expect("glslc failed to run. Ensure that you have shaderc installed and that `glslc` is in your PATH.");
                assert!(
                    output.status.success(),
                    "glslc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                    String::from_utf8_lossy(&output.stdout),
                    String::from_utf8_lossy(&output.stderr)
                );
            }
        }
    }
}
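For context, a kernel compiled by this script can be embedded into the crate at compile time. A minimal sketch, assuming a hypothetical `ABS_FWD_F32` constant; `OUT_DIR` is set by Cargo because the crate has a build script, and the file name follows the `{stem}.{ty}.spv` convention used in `build_spv` above:

```rust
// Hypothetical constant for illustration, not part of the PR: embeds the
// SPIR-V that build_spv() produced for the float forward kernel.
const ABS_FWD_F32: &[u8] =
    include_bytes!(concat!(env!("OUT_DIR"), "/abs.fwd.float.spv"));
```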
3 changes: 3 additions & 0 deletions dfdx-core/src/tensor/error.rs
@@ -22,6 +22,9 @@ pub enum Error {

#[cfg(feature = "webgpu")]
WebgpuRequestDeviceError(wgpu::RequestDeviceError),

#[cfg(feature = "webgpu")]
WebgpuSourceLoadError,
}

impl std::fmt::Display for Error {
4 changes: 2 additions & 2 deletions dfdx-core/src/tensor/webgpu/allocate.rs
@@ -22,7 +22,7 @@ impl Webgpu {
shape: S,
buf: Vec<E>,
) -> Result<Tensor<S, E, Self>, Error> {
- let buffer = unsafe { self.alloc_empty::<E>(buf.len()) }?;
+ let buffer = self.alloc_empty::<E>(buf.len())?;
buffer.copy_to_device::<E>(&self.dev, &self.queue, &buf);

Ok(self.build_tensor(shape, shape.strides(), buffer))
@@ -56,7 +56,7 @@ impl<E: Unit + SafeZeros> ZerosTensor<E> for Webgpu {
fn try_zeros_like<S: HasShape>(&self, src: &S) -> Result<Tensor<S::Shape, E, Self>, Error> {
let shape = *src.shape();
let strides = shape.strides();
- let data = unsafe { self.alloc_empty::<E>(shape.num_elements()) }?;
+ let data = self.alloc_empty::<E>(shape.num_elements())?;
data.copy_to_device(&self.dev, &self.queue, &vec![0u8; data.size()]);

Ok(self.build_tensor(shape, strides, data))
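These allocation paths back the public tensor constructors. A hedged usage sketch, assuming `Webgpu` is exported from the prelude and can be constructed with `Default` like the other dfdx devices:

```rust
use dfdx_core::prelude::*;

fn main() {
    // Sketch: allocate a zeroed 2x3 tensor on the WebGPU device; this routes
    // through try_zeros_like() above. `Default` construction is an assumption
    // mirroring how `Cpu` and `Cuda` are built.
    let dev: Webgpu = Default::default();
    let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
    let _ = t;
}
```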
100 changes: 90 additions & 10 deletions dfdx-core/src/tensor/webgpu/device.rs
@@ -1,9 +1,12 @@
use wgpu::{
- Adapter, BufferDescriptor, BufferUsages, Device, Instance, InstanceDescriptor, Maintain, Queue,
- RequestDeviceError,
+ util::{make_spirv, make_spirv_raw, BufferInitDescriptor, DeviceExt},
+ Adapter, BufferDescriptor, BufferUsages, Device, DeviceDescriptor, Features, Instance,
+ InstanceDescriptor, Maintain, Queue, RequestDeviceError, ShaderModule, ShaderModuleDescriptor,
+ ShaderModuleDescriptorSpirV,
};

use crate::{
prelude::webgpu_kernels::HasGlslType,
shapes::{Shape, Unit},
tensor::{
cache::TensorCache, cpu::Cpu, Cache, Error, NoneTape, RandomU64, Storage, Synchronize,
@@ -12,12 +15,13 @@ use crate::{
};

#[cfg(feature = "no-std")]
- use spin::Mutex;
+ use spin::{Mutex, RwLock};

use core::any::TypeId;
#[cfg(not(feature = "no-std"))]
- use std::sync::Mutex;
+ use std::sync::{Mutex, RwLock};

- use std::{marker::PhantomData, sync::Arc, vec::Vec};
+ use std::{collections::HashMap, marker::PhantomData, sync::Arc, vec::Vec};

use super::allocate::round_to_buffer_alignment;

@@ -40,12 +44,16 @@ impl Buffer {
self.size
}

pub(crate) fn len<E: Unit>(&self) -> usize {
self.size / std::mem::size_of::<E>()
}

#[allow(unused)]
pub(crate) fn capacity(&self) -> usize {
self.data.size() as usize
}

- pub(crate) fn copy_to_device<E: Unit>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
+ pub(crate) fn copy_to_device<E>(&self, dev: &Device, queue: &Queue, slice: &[E]) {
let slice = unsafe {
std::slice::from_raw_parts(
slice.as_ptr() as *const u8,
@@ -102,6 +110,7 @@ pub struct Webgpu {
pub(crate) queue: Arc<Queue>,

pub(crate) cache: Arc<TensorCache<Buffer>>,
pub(crate) cs_cache: Arc<RwLock<HashMap<TypeId, Arc<ShaderModule>>>>,
}

impl From<RequestDeviceError> for Error {
@@ -134,8 +143,13 @@ impl Webgpu {
let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
.ok_or(Error::WebgpuAdapterNotFound)?;
let adapter = Arc::new(adapter);
let descriptor = DeviceDescriptor {
label: None,
features: Features::default() | Features::SPIRV_SHADER_PASSTHROUGH,
limits: Default::default(),
};
let (dev, queue) =
- futures_lite::future::block_on(adapter.request_device(&Default::default(), None))?;
+ futures_lite::future::block_on(adapter.request_device(&descriptor, None))?;
let dev = Arc::new(dev);
let queue = Arc::new(queue);

@@ -147,18 +161,19 @@
queue,

cache: Default::default(),
cs_cache: Default::default(),
})
}
}

impl Webgpu {
- pub(crate) unsafe fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
+ pub(crate) fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
let data = self.cache.try_pop::<E>(len).map_or_else(
|| Buffer {
data: self.dev.create_buffer(&BufferDescriptor {
label: None,
size: round_to_buffer_alignment((len * std::mem::size_of::<E>()) as u64),
- usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
+ usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
mapped_at_creation: false,
}),
size: len * std::mem::size_of::<E>(),
@@ -168,6 +183,71 @@
Ok(data)
}

pub(crate) fn alloc_init<E>(&self, init: &[E]) -> Result<Buffer, Error> {
let data = self.cache.try_pop::<E>(init.len()).map_or_else(
|| {
let contents = unsafe {
std::slice::from_raw_parts(
init.as_ptr() as *const u8,
init.len() * std::mem::size_of::<E>(),
)
};
Buffer {
data: self.dev.create_buffer_init(&BufferInitDescriptor {
label: None,
usage: BufferUsages::STORAGE
| BufferUsages::COPY_SRC
| BufferUsages::COPY_DST,
contents,
}),
size: init.len() * std::mem::size_of::<E>(),
}
},
|bfr| {
bfr.copy_to_device::<E>(&self.dev, &self.queue, init);
bfr
},
);
Ok(data)
}

#[cfg(not(feature = "no-std"))]
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().unwrap().contains_key(&name)
}

#[cfg(feature = "no-std")]
pub(crate) fn shader_module_loaded(&self, name: TypeId) -> bool {
self.cs_cache.read().contains_key(&name)
}

pub(crate) fn load_shader_module<E>(&self, name: TypeId, source: &[u8])
where
E: HasGlslType,
{
let module = Arc::new(unsafe {
self.dev
.create_shader_module_spirv(&ShaderModuleDescriptorSpirV {
label: None,
source: make_spirv_raw(source),
})
});
#[cfg(not(feature = "no-std"))]
self.cs_cache.write().unwrap().insert(name, module);
#[cfg(feature = "no-std")]
self.cs_cache.write().insert(name, module);
}

#[cfg(not(feature = "no-std"))]
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().unwrap().get(&name).cloned()
}

#[cfg(feature = "no-std")]
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().get(&name).cloned()
}

// #[allow(unused)]
// pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
// let num_bytes_required = len * std::mem::size_of::<E>();
@@ -312,7 +392,7 @@ impl<E: Unit> Storage<E> for Webgpu {
type Vec = CachableBuffer<E>;

fn try_alloc_len(&self, len: usize) -> Result<Self::Vec, Error> {
- let data = unsafe { self.alloc_empty::<E>(len) }?;
+ let data = self.alloc_empty::<E>(len)?;
Ok(CachableBuffer {
dev: self.dev.clone(),
queue: self.queue.clone(),
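The `cs_cache` added above gives each kernel one compiled `ShaderModule`, keyed by `TypeId`, with a check/load/get sequence. A minimal sketch of the intended flow using the three methods from this diff; the generic key type, the SPIR-V bytes, and `f32: HasGlslType` are assumptions for illustration:

```rust
use core::any::TypeId;
use std::sync::Arc;

// Sketch: fetch-or-load a shader module from the device cache. `spv` would be
// SPIR-V bytes embedded at build time; `K` is the kernel marker type used as
// the cache key (e.g. a hypothetical call site: module_for::<AbsKernelOp>).
fn module_for<K: 'static>(dev: &Webgpu, spv: &[u8]) -> Arc<wgpu::ShaderModule> {
    let key = TypeId::of::<K>();
    if !dev.shader_module_loaded(key) {
        // compiles the bytes via create_shader_module_spirv, as shown above
        dev.load_shader_module::<f32>(key, spv);
    }
    dev.get_shader_module(key).expect("module was just loaded")
}
```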
28 changes: 28 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.bwd.glsl
@@ -0,0 +1,28 @@
#version 460 core

#extension GL_ARB_compute_shader: enable
#extension GL_ARB_shader_storage_buffer_object: enable

layout(local_size_x = 128) in;

layout(std430, binding = 1) buffer inpBlock {
    TYPENAME inp[];
};

layout(std430, binding = 2) buffer outpBlock {
    TYPENAME outp[];
};

layout(std430, binding = 3) buffer input_gradBlock {
    TYPENAME input_grad[];
};

layout(std430, binding = 4) buffer output_gradBlock {
    TYPENAME output_grad[];
};

void main() {
    // d|x|/dx = sign(x); the chain rule multiplies by the incoming gradient.
    TYPENAME dx = sign(inp[gl_GlobalInvocationID.x]);

    input_grad[gl_GlobalInvocationID.x] = dx * output_grad[gl_GlobalInvocationID.x];
}
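The backward rule here is the chain rule for the absolute value: d|x|/dx = sign(x), so each element of `input_grad` is `sign(inp) * output_grad`. A CPU reference sketch in Rust, for intuition only (note that GLSL `sign(0.0)` is 0.0, while Rust's `f32::signum(0.0)` is 1.0, hence the explicit helper):

```rust
// CPU reference for the backward shader above; not part of the PR.
fn glsl_sign(x: f32) -> f32 {
    // Matches GLSL sign(): -1.0, 0.0, or 1.0.
    if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
}

fn abs_bwd_ref(inp: &[f32], output_grad: &[f32], input_grad: &mut [f32]) {
    for i in 0..inp.len() {
        input_grad[i] = glsl_sign(inp[i]) * output_grad[i];
    }
}
```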
22 changes: 22 additions & 0 deletions dfdx-core/src/tensor_ops/abs/abs.fwd.glsl
@@ -0,0 +1,22 @@
#version 460 core

#extension GL_ARB_compute_shader: enable
#extension GL_ARB_shader_storage_buffer_object: enable

layout(local_size_x = 128) in;

layout(std430, binding = 1) buffer inpBlock {
    TYPENAME inp[];
};

layout(std430, binding = 2) buffer outpBlock {
    TYPENAME outp[];
};

void main() {
    // An empty `inp` buffer signals an in-place op: `outp` already holds the
    // input values, so take abs() of them in place.
    if (inp.length() == 0) {
        outp[gl_GlobalInvocationID.x] = abs(outp[gl_GlobalInvocationID.x]);
    } else {
        outp[gl_GlobalInvocationID.x] = abs(inp[gl_GlobalInvocationID.x]);
    }
}
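Both shaders declare `local_size_x = 128`, so a launch must round the element count up to a whole number of workgroups. A sketch of that arithmetic; `dispatch_workgroups` is the wgpu 0.18 API, while the surrounding pipeline and bind-group setup is omitted:

```rust
// Sketch: dispatch an elementwise kernel whose shader declares
// layout(local_size_x = 128). `pass` must already have its compute
// pipeline and bind groups set.
fn dispatch_elementwise(pass: &mut wgpu::ComputePass<'_>, numel: usize) {
    const LOCAL_SIZE_X: u32 = 128; // must match the shader's local_size_x
    let workgroups = (numel as u32 + LOCAL_SIZE_X - 1) / LOCAL_SIZE_X;
    pass.dispatch_workgroups(workgroups, 1, 1);
}
```

Note that rounding up means the last workgroup can run invocations indexing past `numel`; the shaders as written rely on buffer sizing or bounds handling elsewhere to keep those accesses safe.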