diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml index 0eeac1f4..554ec068 100644 --- a/dfdx-core/Cargo.toml +++ b/dfdx-core/Cargo.toml @@ -25,22 +25,24 @@ keywords = [ features = ["nightly", "numpy", "safetensors", "cuda", "ci-check"] [dependencies] +bytemuck = { version = "1.7.0", optional = true } +cudarc = { version = "0.9.15", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] } +futures-lite = { version = "2.0.1", optional = true } +gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] } +half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] } +libm = { workspace = true } +memmap2 = { workspace = true, optional = true } no-std-compat = { version = "0.4.1", default-features = false, features = [ "alloc", "compat_hash" ], optional = true } -spin = { version = "0.9.8", default-features = false, features = ["spin_mutex", "rwlock", "portable_atomic"], optional = true } +num-traits = { workspace = true } rand = { workspace = true } rand_distr = { workspace = true } -zip = { version = "0.6.6", default-features = false, optional = true } -cudarc = { version = "0.9.15", default-features = false, optional = true, features = ["driver", "cublas", "nvrtc"] } -num-traits = { workspace = true } -safetensors = { workspace = true, optional = true } -memmap2 = { workspace = true, optional = true } -half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] } -gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] } rayon = { version = "1.7.0", optional = true } -libm = { workspace = true } -wgpu = { version = "0.18.0", optional = true } -futures-lite = { version = "2.0.1", optional = true } +safetensors = { workspace = true, optional = true } +spin = { version = "0.9.8", default-features = false, features = ["spin_mutex", "rwlock", "portable_atomic"], optional = true } +static_assertions = { version = "1.1.0", optional = true } thingbuf = { version = "0.1.4", optional = true } +wgpu = { version = "0.18.0", optional = true } +zip = { version = "0.6.6", default-features = false, optional = true } [dev-dependencies] tempfile = "3.3.0" @@ -62,7 +64,7 @@ fast-alloc = ["std"] cuda = ["dep:cudarc", "dep:glob"] cudnn = ["cuda", "cudarc?/cudnn"] -webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "wgpu/expose-ids"] +webgpu = ["dep:wgpu", "dep:futures-lite", "dep:thingbuf", "dep:bytemuck", "dep:static_assertions", "wgpu/expose-ids"] f16 = ["dep:half", "cudarc?/f16", "gemm?/f16"] diff --git a/dfdx-core/src/tensor/cache.rs b/dfdx-core/src/tensor/cache.rs index e785cb64..0690a721 100644 --- a/dfdx-core/src/tensor/cache.rs +++ b/dfdx-core/src/tensor/cache.rs @@ -74,7 +74,7 @@ impl TensorCache { } } - /// Disables the cache. + /// Enables the cache. 
    pub(crate) fn enable(&self) {
        #[cfg(not(feature = "no-std"))]
        {
diff --git a/dfdx-core/src/tensor/mod.rs b/dfdx-core/src/tensor/mod.rs
index acc4074a..d766c4c2 100644
--- a/dfdx-core/src/tensor/mod.rs
+++ b/dfdx-core/src/tensor/mod.rs
@@ -146,7 +146,7 @@ mod masks;
 #[cfg(feature = "numpy")]
 pub(crate) mod numpy;
 #[cfg(feature = "webgpu")]
-pub(crate) mod webgpu;
+pub mod webgpu;
 #[cfg(feature = "numpy")]
 pub use numpy::NumpyDtype;
 mod error;
@@ -177,7 +177,7 @@ pub type AutoDevice = Cuda;
 #[cfg(feature = "webgpu")]
 pub use webgpu::Webgpu;
 #[cfg(feature = "webgpu")]
-pub type AutoDevice = Webgpu;
+pub type AutoDevice = Cpu; // todo
 
 pub use storage_traits::{AsArray, CopySlice, TensorFrom, TensorFromVec, TensorToArray};
 pub use storage_traits::{Cache, RandomU64, Storage, Synchronize};
diff --git a/dfdx-core/src/tensor/webgpu/allocate.rs b/dfdx-core/src/tensor/webgpu/allocate.rs
index 49162381..c85a96bf 100644
--- a/dfdx-core/src/tensor/webgpu/allocate.rs
+++ b/dfdx-core/src/tensor/webgpu/allocate.rs
@@ -112,7 +112,7 @@ where
 impl<E: Unit> OneFillStorage<E> for Webgpu {
     fn try_fill_with_ones(&self, storage: &mut Self::Vec) -> Result<(), Error> {
-        let len = storage.size() as usize / std::mem::size_of::<E>();
+        let len = storage.size() / std::mem::size_of::<E>();
         let buf = vec![E::ONE; len];
         storage
             .data
@@ -171,7 +171,7 @@ where
 impl<E: Unit> CopySlice<E> for Webgpu {
     fn copy_from<S: Shape, T>(dst: &mut Tensor<S, E, Self, T>, src: &[E]) {
         assert_eq!(
-            dst.data.size() as usize,
+            dst.data.size(),
             src.len() * std::mem::size_of::<E>(),
             "Slices must have same number of elements as *physical* Storage of tensors."
         );
@@ -182,7 +182,7 @@ impl<E: Unit> CopySlice<E> for Webgpu {
     fn copy_into<S: Shape, T>(src: &Tensor<S, E, Self, T>, dst: &mut [E]) {
         assert_eq!(
-            src.data.size() as usize,
+            src.data.size(),
             dst.len() * std::mem::size_of::<E>(),
             "Slices must have same number of elements as *physical* Storage of tensors."
         );
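For context, `CopySlice` is the trait behind the `copy_from`/`copy_into` convenience methods on `Tensor`. A minimal usage sketch, assuming those existing wrappers (the slice must have exactly as many elements as the tensor's physical storage, or the `assert_eq!` above fires):

    use dfdx_core::prelude::*;

    fn roundtrip(dev: &Webgpu) {
        // host -> device
        let mut t: Tensor<Rank1<3>, f32, Webgpu> = dev.zeros();
        t.copy_from(&[1.0, 2.0, 3.0]);
        // device -> host
        let mut host = [0.0f32; 3];
        t.copy_into(&mut host);
        assert_eq!(host, [1.0, 2.0, 3.0]);
    }
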
diff --git a/dfdx-core/src/tensor/webgpu/device.rs b/dfdx-core/src/tensor/webgpu/device.rs
index 3cba06c7..0a4c8375 100644
--- a/dfdx-core/src/tensor/webgpu/device.rs
+++ b/dfdx-core/src/tensor/webgpu/device.rs
@@ -3,6 +3,7 @@ use wgpu::{
     RequestDeviceError,
 };
 
+use super::resources::{binary_op_layout_desc, unary_op_layout_desc};
 use crate::{
     shapes::{Shape, Unit},
     tensor::{
@@ -19,6 +20,8 @@ use std::sync::Mutex;
 
 use std::{marker::PhantomData, sync::Arc, vec::Vec};
 
+use futures_lite::future::block_on;
+
 use super::allocate::round_to_buffer_alignment;
 
 #[derive(Debug)]
@@ -102,6 +105,12 @@ pub struct Webgpu {
     pub(crate) queue: Arc<Queue>,
 
     pub(crate) cache: Arc<TensorCache<Buffer>>,
+
+    // pipeline resources
+    /// `[unary, binary]` pipeline layouts
+    ///
+    /// storing them for re-use reduces resource allocation pressure on the GPU
+    pub(super) layouts: [Arc<wgpu::BindGroupLayout>; 2],
 }
 
 impl From<RequestDeviceError> for Error {
@@ -129,16 +138,42 @@ impl Webgpu {
         #[cfg(not(feature = "no-std"))]
        let _lock = { CONSTRUCTOR_MUTEX.lock().unwrap() };
 
-        let cpu = Cpu::seed_from_u64(seed);
+        #[cfg(not(feature = "f16"))]
+        let features: wgpu::Features = Default::default() | wgpu::Features::PUSH_CONSTANTS;
+        #[cfg(feature = "f16")]
+        let features: wgpu::Features =
+            wgpu::Features::default() | wgpu::Features::PUSH_CONSTANTS | wgpu::Features::SHADER_F16;
+
+        let limits: wgpu::Limits = Default::default();
+        let device_desc = wgpu::DeviceDescriptor {
+            label: Some("dfdx"),
+            features,
+            limits,
+        };
+        let adapter_desc = wgpu::RequestAdapterOptions {
+            power_preference: wgpu::PowerPreference::HighPerformance,
+            ..Default::default()
+        };
+
+        // request adapter
         let instance = Arc::new(Instance::new(InstanceDescriptor::default()));
-        let adapter = futures_lite::future::block_on(instance.request_adapter(&Default::default()))
+        // note: this may also fail because the adapter doesn't support the requested features/limits
+        let adapter = block_on(instance.request_adapter(&adapter_desc))
             .ok_or(Error::WebgpuAdapterNotFound)?;
         let adapter = Arc::new(adapter);
-        let (dev, queue) =
-            futures_lite::future::block_on(adapter.request_device(&Default::default(), None))?;
+
+        // request device from adapter
+        let (dev, queue) = block_on(adapter.request_device(&device_desc, None))?;
         let dev = Arc::new(dev);
         let queue = Arc::new(queue);
 
+        let cpu = Cpu::seed_from_u64(seed);
+
+        let layouts = [
+            Arc::new(dev.create_bind_group_layout(&unary_op_layout_desc())),
+            Arc::new(dev.create_bind_group_layout(&binary_op_layout_desc())),
+        ];
+
         Ok(Self {
             cpu,
             instance,
@@ -147,18 +182,68 @@ impl Webgpu {
             queue,
 
             cache: Default::default(),
+
+            layouts,
         })
     }
+
+    /// Submit a command buffer to the GPU.
+    ///
+    /// Note: does not block until completion. If you need that, use
+    /// `self.dev.poll(Maintain::WaitForSubmissionIndex(idx))` with the
+    /// returned [`wgpu::SubmissionIndex`].
+    pub(crate) fn submit_commands<F>(&self, command_builder: F) -> wgpu::SubmissionIndex
+    where
+        F: FnOnce(&mut wgpu::CommandEncoder),
+    {
+        let mut encoder = self
+            .dev
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+                label: Some("submit_commands"),
+            });
+        command_builder(&mut encoder);
+        let cmd = [encoder.finish()];
+        self.queue.submit(cmd)
+    }
+
+    /// Convenience function for submitting single-stage compute operations.
+    ///
+    /// See [`Self::submit_commands`].
+    pub(crate) fn submit_basic_op(
+        &self,
+        pipeline: &wgpu::ComputePipeline,
+        params: &wgpu::BindGroup,
+        label: Option<&str>,
+        work_groups: &(u32, u32, u32),
+    ) -> wgpu::SubmissionIndex {
+        self.submit_commands(|encoder| {
+            let (x, y, z) = *work_groups;
+            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+                label,
+                ..Default::default()
+            });
+            if let Some(label) = label {
+                pass.push_debug_group(label);
+            }
+            pass.set_pipeline(pipeline);
+            pass.set_bind_group(0, params, &[]);
+            pass.dispatch_workgroups(x, y, z);
+            if label.is_some() {
+                pass.pop_debug_group();
+            }
+        })
+    }
 }
 
 impl Webgpu {
+    // todo: support configuration of usage flags
     pub(crate) unsafe fn alloc_empty<E>(&self, len: usize) -> Result<Buffer, Error> {
         let data = self.cache.try_pop::<E>(len).map_or_else(
             || Buffer {
                 data: self.dev.create_buffer(&BufferDescriptor {
                     label: None,
                     size: round_to_buffer_alignment((len * std::mem::size_of::<E>()) as u64),
-                    usage: BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
+                    usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC | BufferUsages::COPY_DST,
                     mapped_at_creation: false,
                 }),
                 size: len * std::mem::size_of::<E>(),
@@ -198,7 +283,7 @@ pub struct CachableBuffer<E> {
 
 impl<E: Unit> Clone for CachableBuffer<E> {
     fn clone(&self) -> Self {
-        let len = self.data.size() as usize / std::mem::size_of::<E>();
+        let len = self.data.size() / std::mem::size_of::<E>();
         let (encoder, data) = self.cache.try_pop::<E>(len).map_or_else(
             || {
                 let mut encoder = self.dev.create_command_encoder(&Default::default());
@@ -213,7 +298,7 @@ impl<E: Unit> Clone for CachableBuffer<E> {
                     encoder,
                     Buffer {
                         data: bfr,
-                        size: self.data.size as usize,
+                        size: self.data.size,
                     },
                 )
             },
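To make the submit-then-block contract concrete, here is a sketch of a hypothetical caller inside a `Webgpu` method (`src` and `dst` stand in for values of the `Buffer` type above; `Maintain` comes from `wgpu`):

    // queue a buffer copy; submit_commands returns as soon as the work is queued
    let idx = self.submit_commands(|encoder| {
        encoder.copy_buffer_to_buffer(&src.data, 0, &dst.data, 0, src.size() as u64);
    });
    // block only when the result is needed immediately
    self.dev.poll(wgpu::Maintain::WaitForSubmissionIndex(idx));
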
diff --git a/dfdx-core/src/tensor/webgpu/mod.rs b/dfdx-core/src/tensor/webgpu/mod.rs
index 666ce53e..2c1a08e9 100644
--- a/dfdx-core/src/tensor/webgpu/mod.rs
+++ b/dfdx-core/src/tensor/webgpu/mod.rs
@@ -1,8 +1,11 @@
 mod allocate;
 mod device;
+mod resources;
+mod types;
 
 pub use device::Buffer;
 pub use device::Webgpu;
+pub(crate) use types::WebgpuNativeType;
 
 #[cfg(test)]
 mod tests {
diff --git a/dfdx-core/src/tensor/webgpu/resources.rs b/dfdx-core/src/tensor/webgpu/resources.rs
new file mode 100644
index 00000000..0f2feb99
--- /dev/null
+++ b/dfdx-core/src/tensor/webgpu/resources.rs
@@ -0,0 +1,153 @@
+use super::Webgpu;
+// FIXME: nostd support
+use std::ops::Range;
+use wgpu;
+
+const UNARY_OP_LAYOUT_NAME: &str = "unary";
+const BINARY_OP_LAYOUT_NAME: &str = "binary";
+
+impl Webgpu {
+    #[inline]
+    pub(crate) fn unary_op_layout(&self) -> &wgpu::BindGroupLayout {
+        self.layouts[0].as_ref()
+    }
+
+    #[inline]
+    pub(crate) fn binary_op_layout(&self) -> &wgpu::BindGroupLayout {
+        self.layouts[1].as_ref()
+    }
+
+    /// Creates a [`wgpu::ComputePipeline`] for a binary operation.
+    ///
+    /// todo: implement pipeline caching
+    ///
+    /// - `shader_name`: the name of the shader module
+    /// - `shader_source`: the module's WGSL source code
+    /// - `fn_name`: the name of the entry point function
+    pub(crate) fn load_binary_pipeline(
+        &self,
+        shader_name: &str,
+        shader_source: &str,
+        fn_name: &str,
+        push_constant_ranges: &[Range<u32>],
+    ) -> wgpu::ComputePipeline {
+        self.load_pipeline(
+            shader_name,
+            shader_source,
+            fn_name,
+            &[self.binary_op_layout()],
+            push_constant_ranges,
+        )
+    }
+
+    pub(crate) fn load_unary_pipeline(
+        &self,
+        shader_name: &str,
+        shader_source: &str,
+        fn_name: &str,
+        push_constant_ranges: &[Range<u32>],
+    ) -> wgpu::ComputePipeline {
+        self.load_pipeline(
+            shader_name,
+            shader_source,
+            fn_name,
+            &[self.unary_op_layout()],
+            push_constant_ranges,
+        )
+    }
+
+    /// Creates a [`wgpu::ComputePipeline`] for some operation.
+    ///
+    /// - `shader_name`: the name of the shader module
+    /// - `shader_source`: the module's WGSL source code
+    /// - `fn_name`: the name of the entry point function
+    /// - `layouts`: bind group layouts
+    /// - `push_constant_ranges`: push constant ranges. Leave empty if unused.
+    pub(crate) fn load_pipeline(
+        &self,
+        shader_name: &str,
+        shader_source: &str,
+        fn_name: &str,
+        layouts: &[&wgpu::BindGroupLayout],
+        push_constant_ranges: &[Range<u32>],
+    ) -> wgpu::ComputePipeline {
+        // todo: cache these
+        let source = wgpu::ShaderSource::Wgsl(shader_source.into());
+        let shader_module = self.dev.create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: Some(shader_name),
+            source,
+        });
+
+        let push_constant_ranges = push_constant_ranges
+            .iter()
+            .map(|range| wgpu::PushConstantRange {
+                stages: wgpu::ShaderStages::COMPUTE,
+                range: range.clone(),
+            })
+            .collect::<Vec<_>>();
+
+        let pipeline_layout = self
+            .dev
+            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+                label: Some(fn_name),
+                bind_group_layouts: layouts,
+                // todo: these are useful and we should use them if the adapter supports them
+                push_constant_ranges: &push_constant_ranges,
+            });
+
+        self.dev
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label: Some(fn_name),
+                layout: Some(&pipeline_layout),
+                module: &shader_module,
+                entry_point: fn_name,
+            })
+    }
+}
+
+pub(super) const fn unary_op_layout_desc() -> wgpu::BindGroupLayoutDescriptor<'static> {
+    const ENTRIES: [wgpu::BindGroupLayoutEntry; 2] = [
+        // input tensor buffer
+        storage_entry(0, true),
+        // TODO: metadata buffer (also try getting uniforms to work, since we can only have 8 storage buffers at once)
+        // storage_entry(1, true),
+        // output tensor buffer
+        storage_entry(1, false),
+    ];
+    wgpu::BindGroupLayoutDescriptor {
+        label: Some(UNARY_OP_LAYOUT_NAME),
+        entries: &ENTRIES,
+    }
+}
+
+pub(super) const fn binary_op_layout_desc() -> wgpu::BindGroupLayoutDescriptor<'static> {
+    const ENTRIES: [wgpu::BindGroupLayoutEntry; 3] = [
+        // lhs tensor buffer
+        storage_entry(0, true),
+        // rhs tensor buffer
+        storage_entry(1, true),
+        // TODO: metadata buffer (also try getting uniforms to work, since we can only have 8 storage buffers at once)
+        // storage_entry(2, true),
+        // output tensor buffer
+        storage_entry(2, false),
+    ];
+    wgpu::BindGroupLayoutDescriptor {
+        label: Some(BINARY_OP_LAYOUT_NAME),
+        entries: &ENTRIES,
+    }
+}
+
+/// Creates a [`wgpu::BindGroupLayoutEntry`] for a storage buffer. Useful for
+/// composing a [`wgpu::BindGroupLayout`].
+const fn storage_entry(index: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
+    wgpu::BindGroupLayoutEntry {
+        binding: index,
+        visibility: wgpu::ShaderStages::COMPUTE,
+        ty: wgpu::BindingType::Buffer {
+            ty: wgpu::BufferBindingType::Storage { read_only },
+            has_dynamic_offset: false,
+            min_binding_size: None,
+        },
+        count: None,
+    }
+}
diff --git a/dfdx-core/src/tensor/webgpu/types.rs b/dfdx-core/src/tensor/webgpu/types.rs
new file mode 100644
index 00000000..c1900f4e
--- /dev/null
+++ b/dfdx-core/src/tensor/webgpu/types.rs
@@ -0,0 +1,48 @@
+use crate::shapes::Unit;
+
+/// A primitive data type natively supported by WebGPU.
+pub trait WebgpuNativeType: Unit {
+    /// Name of the data type in WGSL.
+    const NAME: &'static str;
+}
+
+macro_rules! webgpu_type {
+    ($RustTy:ty) => {
+        impl WebgpuNativeType for $RustTy {
+            const NAME: &'static str = stringify!($RustTy);
+        }
+    };
+    ($RustTy:ty, $WgpuTy:expr) => {
+        impl WebgpuNativeType for $RustTy {
+            const NAME: &'static str = $WgpuTy;
+        }
+    };
+}
+
+/*
+see:
+- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16
+- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64
+- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16
+ */
+#[cfg(feature = "f16")]
+webgpu_type!(half::f16, "f16");
+webgpu_type!(f32);
+#[cfg(feature = "f64")]
+webgpu_type!(f64);
+
+#[cfg(feature = "i16")]
+webgpu_type!(i16);
+webgpu_type!(i32);
+
+webgpu_type!(u32);
+webgpu_type!(bool);
+
+// pub trait WgpuPackedType: Unit {
+//     const NAME: &'static str;
+//     /// Number of elements packed into a single `E` value.
+//     ///
+//     /// For example, `i8` is packed 4 times into a single `i32` value.
+//     const PACK_WIDTH: usize;
+// }
diff --git a/dfdx-core/src/tensor_ops/abs/mod.rs b/dfdx-core/src/tensor_ops/abs/mod.rs
index 45c7794d..51c86700 100644
--- a/dfdx-core/src/tensor_ops/abs/mod.rs
+++ b/dfdx-core/src/tensor_ops/abs/mod.rs
@@ -3,8 +3,8 @@ mod cpu_kernel;
 #[cfg(feature = "cuda")]
 mod cuda_kernel;
 
-#[cfg(feature = "webgpu")]
-mod webgpu_kernel;
+// #[cfg(feature = "webgpu")]
+// mod webgpu_kernel;
 
 use super::ops::{try_unary_op, UnaryKernel};
 use crate::{shapes::*, tensor::*};
diff --git a/dfdx-core/src/tensor_ops/accurate_gelu/mod.rs b/dfdx-core/src/tensor_ops/accurate_gelu/mod.rs
index 396c7fa2..940b8253 100644
--- a/dfdx-core/src/tensor_ops/accurate_gelu/mod.rs
+++ b/dfdx-core/src/tensor_ops/accurate_gelu/mod.rs
@@ -3,8 +3,8 @@ mod cpu_kernel;
 #[cfg(feature = "cuda")]
 mod cuda_kernel;
 
-#[cfg(feature = "webgpu")]
-mod webgpu_kernel;
+// #[cfg(feature = "webgpu")]
+// mod webgpu_kernel;
 
 use super::ops::{try_unary_op, UnaryKernel};
 use crate::{shapes::*, tensor::*};
diff --git a/dfdx-core/src/tensor_ops/adam/mod.rs b/dfdx-core/src/tensor_ops/adam/mod.rs
index 9b2372e4..ee4cdb52 100644
--- a/dfdx-core/src/tensor_ops/adam/mod.rs
+++ b/dfdx-core/src/tensor_ops/adam/mod.rs
@@ -3,8 +3,8 @@ mod cpu_kernel;
 #[cfg(feature = "cuda")]
 mod cuda_kernel;
 
-#[cfg(feature = "webgpu")]
-mod webgpu_kernel;
+// #[cfg(feature = "webgpu")]
+// mod webgpu_kernel;
 
 use crate::{
     shapes::{Dtype, Shape},
diff --git a/dfdx-core/src/tensor_ops/add/binary_add.wgsl b/dfdx-core/src/tensor_ops/add/binary_add.wgsl
new file mode 100644
index 00000000..e53d40e6
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/add/binary_add.wgsl
@@ -0,0 +1,27 @@
+alias usize = u32;
+
+struct BinaryKernelMeta {
+    numel: usize,
+    num_dims: usize,
+    info: array<usize>
+}
+
+@group(0) @binding(0)
+var<storage, read> lhs: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read> rhs: array<f32>;
+
+// @group(0) @binding(2)
+// var<storage, read> kernel_meta: BinaryKernelMeta;
+
+@group(0) @binding(2)
+var<storage, read_write> output: array<f32>;
+
+@compute @workgroup_size(1, 1, 1)
+fn badd_fwd_f32(
+    @builtin(global_invocation_id) global_id: vec3<u32>
+) {
+    let i = global_id.x;
+    output[i] = lhs[i] + rhs[i];
+}
diff --git a/dfdx-core/src/tensor_ops/add/mod.rs b/dfdx-core/src/tensor_ops/add/mod.rs
index 33c27184..d5395a5d 100644
--- a/dfdx-core/src/tensor_ops/add/mod.rs
+++ b/dfdx-core/src/tensor_ops/add/mod.rs
@@ -98,12 +98,13 @@ where
 #[cfg(test)]
 mod tests {
     use crate::{shapes::*, tensor::*, tensor_ops::*, tests::*};
+    // use crate::tensor::webgpu::Webgpu;
 
     #[test]
     fn test_add_0d() {
-        let dev: TestDevice = Default::default();
-        let a = dev.tensor(1.0f64).to_dtype::<TestDtype>();
-        let b = dev.tensor(1.0f64).to_dtype::<TestDtype>();
+        let dev: Webgpu = Default::default();
+        let a = dev.tensor(1.0f32).to_dtype::<f32>();
+        let b = dev.tensor(1.0f32).to_dtype::<f32>();
 
         let r = a.leaky_trace() + b.clone();
         assert_close_to_literal!(r, 2.0);
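For reference, the `wgpu_binary!` invocation in the next file expands to roughly the following impl (per the `const_df()` arm of the macro defined in webgpu_kernels.rs further down; a sketch of the expansion, not its literal output):

    impl crate::tensor_ops::webgpu_kernels::BinaryOpWebgpuKernel<f32>
        for super::BinaryAddKernelOp
    {
        const HAS_CONST_DF: bool = true; // the const_df() arm sets this
        const WGSL_SRC: &'static str = BADD_SRC;
        const MODULE_NAME: &'static str = "badd";
        const FWD_FN_NAME: &'static str = "badd_fwd_f32";
        const BWD_LHS_FN_NAME: &'static str = "badd_bwd_lhs_f32";
        const BWD_RHS_FN_NAME: &'static str = "badd_bwd_rhs_f32";
    }
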
diff --git a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs
index 91becc55..d7b61414 100644
--- a/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs
+++ b/dfdx-core/src/tensor_ops/add/webgpu_kernel.rs
@@ -4,53 +4,94 @@ use crate::prelude::{
     ops::{BinaryKernel, UnaryKernel},
     Dtype, Webgpu,
 };
+use crate::tensor_ops::webgpu_kernels::wgpu_binary;
+use super::BinaryAddKernelOp as Binary;
 
-impl<E: Dtype> UnaryKernel<super::ScalarAddKernelOp<E>, E> for Webgpu {
-    const BACKWARD_WITHOUT_INP: bool = false;
+const BADD_SRC: &'static str = include_str!("./binary_add.wgsl");
+wgpu_binary!(
+    const_df() Binary,
+    f32,
+    BADD_SRC,
+    "badd",
+    "badd_fwd_f32",
+    "badd_bwd_lhs_f32",
+    "badd_bwd_rhs_f32"
+);
 
-    const BACKWARD_WITHOUT_DATA: bool = true;
+// putting these tests here b/c I haven't added support for other data types
+// and/or unary adds yet, so using the mod's tests breaks the build
+#[cfg(test)]
+mod tests {
+    use crate::{shapes::*, tensor::*, tensor_ops::*, tests::*};
 
+    #[test]
+    fn test_add_zeroes() {
+        let dev: Webgpu = Default::default();
 
-    fn forward<S: Shape>(
-        &self,
-        op: super::ScalarAddKernelOp<E>,
-        inp: Cow<Tensor<S, E, Self>>,
-    ) -> Result<Tensor<S, E, Self>, crate::prelude::Error> {
-        todo!()
+        let a: Tensor<Rank1<3>, f32, _> = dev.zeros();
+        let b: Tensor<Rank1<3>, f32, _> = dev.ones();
+        let actual = a + b.clone();
+        let actual_vec = actual.as_vec();
+        let expected = b.as_vec();
+
+        assert_eq!(actual_vec, expected);
     }
 
-    fn backward<S: Shape>(
-        &self,
-        op: super::ScalarAddKernelOp<E>,
-        inp: &impl crate::prelude::Tensorlike<S, E, Self>,
-        grad_inp: &mut Self::Vec,
-        out: &impl crate::prelude::Tensorlike<S, E, Self>,
-        grad_out: &Self::Vec,
-    ) -> Result<(), crate::prelude::Error> {
-        todo!()
+    #[test]
+    fn test_add_increasing() {
+        let dev: Webgpu = Default::default();
+
+        let a = dev.tensor(&[1.0f32, 2.0, 3.0]);
+        let result = a.clone() + a.clone();
+        let result = a + result.clone();
+        assert_eq!(result.as_vec(), &[3.0, 6.0, 9.0]);
     }
 }
 
-impl<E: Dtype> BinaryKernel<super::BinaryAddKernelOp, E> for Webgpu {
-    const BACKWARD_WITHOUT_DATA: bool = true;
+// impl<E: Dtype> UnaryKernel<super::ScalarAddKernelOp<E>, E> for Webgpu {
+//     const BACKWARD_WITHOUT_INP: bool = false;
 
-    fn forward<S: Shape>(
-        &self,
-        op: super::BinaryAddKernelOp,
-        lhs: Cow<Tensor<S, E, Self>>,
-        rhs: Cow<Tensor<S, E, Self>>,
-    ) -> Result<Tensor<S, E, Self>, crate::prelude::Error> {
-        todo!()
-    }
+//     const BACKWARD_WITHOUT_DATA: bool = true;
 
-    fn backward<S: Shape>(
-        &self,
-        op: super::BinaryAddKernelOp,
-        lhs: &impl crate::prelude::Tensorlike<S, E, Self>,
-        grad_lhs: &mut Self::Vec,
-        rhs: &impl crate::prelude::Tensorlike<S, E, Self>,
-        grad_rhs: &mut Self::Vec,
-        grad_out: &Self::Vec,
-    ) -> Result<(), crate::prelude::Error> {
-        todo!()
-    }
-}
+//     fn forward<S: Shape>(
+//         &self,
+//         op: super::ScalarAddKernelOp<E>,
+//         inp: Cow<Tensor<S, E, Self>>,
+//     ) -> Result<Tensor<S, E, Self>, crate::prelude::Error> {
+//         todo!()
+//     }
+
+//     fn backward<S: Shape>(
+//         &self,
+//         op: super::ScalarAddKernelOp<E>,
+//         inp: &impl crate::prelude::Tensorlike<S, E, Self>,
+//         grad_inp: &mut Self::Vec,
+//         out: &impl crate::prelude::Tensorlike<S, E, Self>,
+//         grad_out: &Self::Vec,
+//     ) -> Result<(), crate::prelude::Error> {
+//         todo!()
+//     }
+// }
+
+// impl<E: Dtype> BinaryKernel<super::BinaryAddKernelOp, E> for Webgpu {
+//     const BACKWARD_WITHOUT_DATA: bool = true;
+
+//     fn forward<S: Shape>(
+//         &self,
+//         op: super::BinaryAddKernelOp,
+//         lhs: Cow<Tensor<S, E, Self>>,
+//         rhs: Cow<Tensor<S, E, Self>>,
+//     ) -> Result<Tensor<S, E, Self>, crate::prelude::Error> {
+//         todo!()
+//     }
+
+//     fn backward<S: Shape>(
+//         &self,
+//         op: super::BinaryAddKernelOp,
+//         lhs: &impl crate::prelude::Tensorlike<S, E, Self>,
+//         grad_lhs: &mut Self::Vec,
+//         rhs: &impl crate::prelude::Tensorlike<S, E, Self>,
+//         grad_rhs: &mut Self::Vec,
+//         grad_out: &Self::Vec,
+//     ) -> Result<(), crate::prelude::Error> {
+//         todo!()
+//     }
+// }
diff --git a/dfdx-core/src/tensor_ops/axpy/mod.rs b/dfdx-core/src/tensor_ops/axpy/mod.rs
index e0b1e66a..2743749a 100644
--- a/dfdx-core/src/tensor_ops/axpy/mod.rs
+++ b/dfdx-core/src/tensor_ops/axpy/mod.rs
@@ -7,8 +7,8 @@ mod cpu_kernel;
 #[cfg(feature = "cuda")]
 mod cuda_kernel;
 
-#[cfg(feature = "webgpu")]
-mod webgpu_kernel;
+// #[cfg(feature = "webgpu")]
+// mod webgpu_kernel;
 
 /// Elementwise `a * alpha + b * beta`.
/// diff --git a/dfdx-core/src/tensor_ops/bce/mod.rs b/dfdx-core/src/tensor_ops/bce/mod.rs index 48735e68..370e757e 100644 --- a/dfdx-core/src/tensor_ops/bce/mod.rs +++ b/dfdx-core/src/tensor_ops/bce/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_binary_op, BinaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/bce/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/bce/webgpu_kernel.rs index 02b7f3cf..65907bfb 100644 --- a/dfdx-core/src/tensor_ops/bce/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/bce/webgpu_kernel.rs @@ -1,27 +1,27 @@ use crate::prelude::{ops::BinaryKernel, Dtype, Webgpu}; use std::borrow::Cow; -impl BinaryKernel for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = false; +// impl BinaryKernel for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = false; - fn forward( - &self, - op: super::BCEKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } +// fn forward( +// &self, +// op: super::BCEKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } - fn backward( - &self, - op: super::BCEKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// fn backward( +// &self, +// op: super::BCEKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/boolean/mod.rs b/dfdx-core/src/tensor_ops/boolean/mod.rs index e86c4d16..3a4886ca 100644 --- a/dfdx-core/src/tensor_ops/boolean/mod.rs +++ b/dfdx-core/src/tensor_ops/boolean/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernels; #[cfg(feature = "cuda")] mod cuda_kernels; -#[cfg(feature = "webgpu")] -mod webgpu_kernels; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernels; use crate::{ prelude::{OnesTensor, Tensor, ZerosTensor}, diff --git a/dfdx-core/src/tensor_ops/choose/mod.rs b/dfdx-core/src/tensor_ops/choose/mod.rs index f391bd75..78a045b9 100644 --- a/dfdx-core/src/tensor_ops/choose/mod.rs +++ b/dfdx-core/src/tensor_ops/choose/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{ shapes::{Dtype, HasShape, Shape}, diff --git a/dfdx-core/src/tensor_ops/clamp/mod.rs b/dfdx-core/src/tensor_ops/clamp/mod.rs index 88d246bc..714e96ac 100644 --- a/dfdx-core/src/tensor_ops/clamp/mod.rs +++ b/dfdx-core/src/tensor_ops/clamp/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/concat/mod.rs b/dfdx-core/src/tensor_ops/concat/mod.rs index 1e719fc5..be030039 100644 --- a/dfdx-core/src/tensor_ops/concat/mod.rs +++ b/dfdx-core/src/tensor_ops/concat/mod.rs @@ -4,8 +4,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod 
webgpu_kernel; /// Concatenate two tensors along the first dimension. /// diff --git a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs index 7462fd2b..9bd26dc4 100644 --- a/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs +++ b/dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs @@ -4,8 +4,8 @@ use crate::{shapes::*, tensor::*}; pub(crate) mod cpu_kernel; #[cfg(feature = "cuda")] pub(crate) mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; /// Concatenate two tensors along a given axis. /// diff --git a/dfdx-core/src/tensor_ops/cos/mod.rs b/dfdx-core/src/tensor_ops/cos/mod.rs index 434b1db8..1092e99a 100644 --- a/dfdx-core/src/tensor_ops/cos/mod.rs +++ b/dfdx-core/src/tensor_ops/cos/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/div/mod.rs b/dfdx-core/src/tensor_ops/div/mod.rs index 7aa56063..e61f833b 100644 --- a/dfdx-core/src/tensor_ops/div/mod.rs +++ b/dfdx-core/src/tensor_ops/div/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::*; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs index 3a15ef7e..c9d8453b 100644 --- a/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/div/webgpu_kernel.rs @@ -30,27 +30,27 @@ impl UnaryKernel, E> for Webgpu { } } -impl BinaryKernel for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::BinaryDivKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::BinaryDivKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// impl BinaryKernel for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = true; + +// fn forward( +// &self, +// op: super::BinaryDivKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } + +// fn backward( +// &self, +// op: super::BinaryDivKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/dropout/mod.rs b/dfdx-core/src/tensor_ops/dropout/mod.rs index 9277669a..ebdba135 100644 --- a/dfdx-core/src/tensor_ops/dropout/mod.rs +++ b/dfdx-core/src/tensor_ops/dropout/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/exp/mod.rs b/dfdx-core/src/tensor_ops/exp/mod.rs index 5d1066f3..f0c0ebe2 100644 --- a/dfdx-core/src/tensor_ops/exp/mod.rs +++ b/dfdx-core/src/tensor_ops/exp/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod 
cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/fast_gelu/mod.rs b/dfdx-core/src/tensor_ops/fast_gelu/mod.rs index 45c7dad6..85b8e1f8 100644 --- a/dfdx-core/src/tensor_ops/fast_gelu/mod.rs +++ b/dfdx-core/src/tensor_ops/fast_gelu/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/huber_error/mod.rs b/dfdx-core/src/tensor_ops/huber_error/mod.rs index fb7df26e..73b320c8 100644 --- a/dfdx-core/src/tensor_ops/huber_error/mod.rs +++ b/dfdx-core/src/tensor_ops/huber_error/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::{ops::try_binary_op, Device}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/huber_error/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/huber_error/webgpu_kernel.rs index b66b7d1e..d863905b 100644 --- a/dfdx-core/src/tensor_ops/huber_error/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/huber_error/webgpu_kernel.rs @@ -1,27 +1,27 @@ use crate::prelude::{ops::BinaryKernel, Dtype, Webgpu}; use std::borrow::Cow; -impl BinaryKernel, E> for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = false; +// impl BinaryKernel, E> for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = false; - fn forward( - &self, - op: super::HuberErrorKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } +// fn forward( +// &self, +// op: super::HuberErrorKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } - fn backward( - &self, - op: super::HuberErrorKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// fn backward( +// &self, +// op: super::HuberErrorKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/ln/mod.rs b/dfdx-core/src/tensor_ops/ln/mod.rs index 51bc001f..9bb58d39 100644 --- a/dfdx-core/src/tensor_ops/ln/mod.rs +++ b/dfdx-core/src/tensor_ops/ln/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/max_to/mod.rs b/dfdx-core/src/tensor_ops/max_to/mod.rs index e00ba600..538e96c4 100644 --- a/dfdx-core/src/tensor_ops/max_to/mod.rs +++ b/dfdx-core/src/tensor_ops/max_to/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/maximum/mod.rs b/dfdx-core/src/tensor_ops/maximum/mod.rs 
index e1d1a89a..8f5ed9dc 100644 --- a/dfdx-core/src/tensor_ops/maximum/mod.rs +++ b/dfdx-core/src/tensor_ops/maximum/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::{ops::try_binary_op, Device}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/maximum/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/maximum/webgpu_kernel.rs index 690e2471..938994ad 100644 --- a/dfdx-core/src/tensor_ops/maximum/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/maximum/webgpu_kernel.rs @@ -1,27 +1,27 @@ use crate::prelude::{ops::BinaryKernel, Dtype, Webgpu}; use std::borrow::Cow; -impl BinaryKernel for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = false; +// impl BinaryKernel for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = false; - fn forward( - &self, - op: super::MaximumKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } +// fn forward( +// &self, +// op: super::MaximumKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } - fn backward( - &self, - op: super::MaximumKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// fn backward( +// &self, +// op: super::MaximumKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/min_to/mod.rs b/dfdx-core/src/tensor_ops/min_to/mod.rs index 38166f6f..931d83ae 100644 --- a/dfdx-core/src/tensor_ops/min_to/mod.rs +++ b/dfdx-core/src/tensor_ops/min_to/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/minimum/mod.rs b/dfdx-core/src/tensor_ops/minimum/mod.rs index f6b9b6e1..2057d180 100644 --- a/dfdx-core/src/tensor_ops/minimum/mod.rs +++ b/dfdx-core/src/tensor_ops/minimum/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::{ops::try_binary_op, Device}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/minimum/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/minimum/webgpu_kernel.rs index 5ebcc561..6a4e9690 100644 --- a/dfdx-core/src/tensor_ops/minimum/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/minimum/webgpu_kernel.rs @@ -1,27 +1,27 @@ use crate::prelude::{ops::BinaryKernel, Dtype, Webgpu}; use std::borrow::Cow; -impl BinaryKernel for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = false; +// impl BinaryKernel for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = false; - fn forward( - &self, - op: super::MinimumKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } +// fn forward( +// &self, +// op: super::MinimumKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } - fn backward( - &self, - op: super::MinimumKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut 
Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// fn backward( +// &self, +// op: super::MinimumKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/mul/mod.rs b/dfdx-core/src/tensor_ops/mul/mod.rs index 0dccebd8..54bf550e 100644 --- a/dfdx-core/src/tensor_ops/mul/mod.rs +++ b/dfdx-core/src/tensor_ops/mul/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::*; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs index 240ba571..76f44498 100644 --- a/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/mul/webgpu_kernel.rs @@ -30,27 +30,27 @@ impl UnaryKernel, E> for Webgpu { } } -impl BinaryKernel for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::BinaryMulKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::BinaryMulKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// impl BinaryKernel for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = true; + +// fn forward( +// &self, +// op: super::BinaryMulKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } + +// fn backward( +// &self, +// op: super::BinaryMulKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/nans_to/mod.rs b/dfdx-core/src/tensor_ops/nans_to/mod.rs index f3ade77e..c6f0a689 100644 --- a/dfdx-core/src/tensor_ops/nans_to/mod.rs +++ b/dfdx-core/src/tensor_ops/nans_to/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/negate/mod.rs b/dfdx-core/src/tensor_ops/negate/mod.rs index f6dfa820..ce3f6c44 100644 --- a/dfdx-core/src/tensor_ops/negate/mod.rs +++ b/dfdx-core/src/tensor_ops/negate/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/pow/mod.rs b/dfdx-core/src/tensor_ops/pow/mod.rs index 83f12f4b..29fd0982 100644 --- a/dfdx-core/src/tensor_ops/pow/mod.rs +++ b/dfdx-core/src/tensor_ops/pow/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod 
webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/recip/mod.rs b/dfdx-core/src/tensor_ops/recip/mod.rs index 57b76b26..90a86f6f 100644 --- a/dfdx-core/src/tensor_ops/recip/mod.rs +++ b/dfdx-core/src/tensor_ops/recip/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/relu/mod.rs b/dfdx-core/src/tensor_ops/relu/mod.rs index 31496368..84008aef 100644 --- a/dfdx-core/src/tensor_ops/relu/mod.rs +++ b/dfdx-core/src/tensor_ops/relu/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/reshape_to/mod.rs b/dfdx-core/src/tensor_ops/reshape_to/mod.rs index 8778de14..98ea3c5e 100644 --- a/dfdx-core/src/tensor_ops/reshape_to/mod.rs +++ b/dfdx-core/src/tensor_ops/reshape_to/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/rmsprop/mod.rs b/dfdx-core/src/tensor_ops/rmsprop/mod.rs index cc031546..6a0f6ad3 100644 --- a/dfdx-core/src/tensor_ops/rmsprop/mod.rs +++ b/dfdx-core/src/tensor_ops/rmsprop/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{ shapes::{Dtype, Shape}, diff --git a/dfdx-core/src/tensor_ops/roll/mod.rs b/dfdx-core/src/tensor_ops/roll/mod.rs index b0f1237d..7d507479 100644 --- a/dfdx-core/src/tensor_ops/roll/mod.rs +++ b/dfdx-core/src/tensor_ops/roll/mod.rs @@ -7,8 +7,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; #[repr(C)] #[derive(Copy, Clone, Debug)] diff --git a/dfdx-core/src/tensor_ops/select_and_gather/mod.rs b/dfdx-core/src/tensor_ops/select_and_gather/mod.rs index b5ccebb0..09b96e1d 100644 --- a/dfdx-core/src/tensor_ops/select_and_gather/mod.rs +++ b/dfdx-core/src/tensor_ops/select_and_gather/mod.rs @@ -5,8 +5,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/sgd/mod.rs b/dfdx-core/src/tensor_ops/sgd/mod.rs index 3cc28c05..43d73c3a 100644 --- a/dfdx-core/src/tensor_ops/sgd/mod.rs +++ b/dfdx-core/src/tensor_ops/sgd/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{ shapes::{Dtype, Shape}, diff --git a/dfdx-core/src/tensor_ops/sigmoid/mod.rs b/dfdx-core/src/tensor_ops/sigmoid/mod.rs index d2fdfe5e..a1b177ad 100644 --- a/dfdx-core/src/tensor_ops/sigmoid/mod.rs +++ b/dfdx-core/src/tensor_ops/sigmoid/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// 
#[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/sin/mod.rs b/dfdx-core/src/tensor_ops/sin/mod.rs index 033e2cb3..78828bbd 100644 --- a/dfdx-core/src/tensor_ops/sin/mod.rs +++ b/dfdx-core/src/tensor_ops/sin/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/slice/mod.rs b/dfdx-core/src/tensor_ops/slice/mod.rs index c69511a5..678a08bb 100644 --- a/dfdx-core/src/tensor_ops/slice/mod.rs +++ b/dfdx-core/src/tensor_ops/slice/mod.rs @@ -4,8 +4,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; pub trait SliceKernel: Storage { fn forward, Slice>( diff --git a/dfdx-core/src/tensor_ops/sqrt/mod.rs b/dfdx-core/src/tensor_ops/sqrt/mod.rs index 4138348e..3a507a40 100644 --- a/dfdx-core/src/tensor_ops/sqrt/mod.rs +++ b/dfdx-core/src/tensor_ops/sqrt/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/square/mod.rs b/dfdx-core/src/tensor_ops/square/mod.rs index e4d26c94..88651a23 100644 --- a/dfdx-core/src/tensor_ops/square/mod.rs +++ b/dfdx-core/src/tensor_ops/square/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/stack/mod.rs b/dfdx-core/src/tensor_ops/stack/mod.rs index b30e7e98..f21fd650 100644 --- a/dfdx-core/src/tensor_ops/stack/mod.rs +++ b/dfdx-core/src/tensor_ops/stack/mod.rs @@ -6,8 +6,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; /// Stack an array or vec of tensors together along a new dimension. 
/// diff --git a/dfdx-core/src/tensor_ops/sub/mod.rs b/dfdx-core/src/tensor_ops/sub/mod.rs index 9c798cda..9ec4cb22 100644 --- a/dfdx-core/src/tensor_ops/sub/mod.rs +++ b/dfdx-core/src/tensor_ops/sub/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::*; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs index 8d5e943e..864bd0a8 100644 --- a/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/sub/webgpu_kernel.rs @@ -30,27 +30,27 @@ impl UnaryKernel, E> for Webgpu { } } -impl BinaryKernel for Webgpu { - const BACKWARD_WITHOUT_DATA: bool = true; - - fn forward( - &self, - op: super::BinarySubKernelOp, - lhs: Cow>, - rhs: Cow>, - ) -> Result, crate::prelude::Error> { - todo!() - } - - fn backward( - &self, - op: super::BinarySubKernelOp, - lhs: &impl crate::prelude::Tensorlike, - grad_lhs: &mut Self::Vec, - rhs: &impl crate::prelude::Tensorlike, - grad_rhs: &mut Self::Vec, - grad_out: &Self::Vec, - ) -> Result<(), crate::prelude::Error> { - todo!() - } -} +// impl BinaryKernel for Webgpu { +// const BACKWARD_WITHOUT_DATA: bool = true; + +// fn forward( +// &self, +// op: super::BinarySubKernelOp, +// lhs: Cow>, +// rhs: Cow>, +// ) -> Result, crate::prelude::Error> { +// todo!() +// } + +// fn backward( +// &self, +// op: super::BinarySubKernelOp, +// lhs: &impl crate::prelude::Tensorlike, +// grad_lhs: &mut Self::Vec, +// rhs: &impl crate::prelude::Tensorlike, +// grad_rhs: &mut Self::Vec, +// grad_out: &Self::Vec, +// ) -> Result<(), crate::prelude::Error> { +// todo!() +// } +// } diff --git a/dfdx-core/src/tensor_ops/sum_to/mod.rs b/dfdx-core/src/tensor_ops/sum_to/mod.rs index ad149df7..ee4d18f1 100644 --- a/dfdx-core/src/tensor_ops/sum_to/mod.rs +++ b/dfdx-core/src/tensor_ops/sum_to/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/tanh/mod.rs b/dfdx-core/src/tensor_ops/tanh/mod.rs index 65340a87..d824c146 100644 --- a/dfdx-core/src/tensor_ops/tanh/mod.rs +++ b/dfdx-core/src/tensor_ops/tanh/mod.rs @@ -3,8 +3,8 @@ mod cpu_kernel; #[cfg(feature = "cuda")] mod cuda_kernel; -#[cfg(feature = "webgpu")] -mod webgpu_kernel; +// #[cfg(feature = "webgpu")] +// mod webgpu_kernel; use super::ops::{try_unary_op, UnaryKernel}; use crate::{shapes::*, tensor::*}; diff --git a/dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl b/dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl new file mode 100644 index 00000000..67d8fa43 --- /dev/null +++ b/dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl @@ -0,0 +1,16 @@ +alias T = __SRC__; +alias U = __DST__; + +@group(0) @binding(0) +var in: array; + +@group(0) @binding(1) +var out: array; + +@compute @workgroup_size(1, 1, 1) +fn main( + @builtin(global_invocation_id) global_id: vec3 +) { + let i = global_id.x; + out[i] = U(in[i]); +} diff --git a/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs b/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs index 111b930e..96893a29 100644 --- a/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs +++ b/dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs @@ -1,9 +1,71 @@ -use crate::prelude::{Unit, Webgpu}; +use crate::{ + tensor::webgpu::{Webgpu, WebgpuNativeType}, + 
+    tensor_ops::utilities::webgpu_kernels::webgpu_params,
+    prelude::Storage,
+};
+use num_traits::AsPrimitive;
+use wgpu;
 
-impl<E1: Unit, E2: Unit> super::ToDtypeKernel<E1, E2> for Webgpu {
+/// kernel template
+const KERNEL: &'static str = include_str!("./to_dtype.wgsl");
+
+const LAYOUT_DESC: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor {
+    label: Some("to-dtype"),
+    entries: &[
+        wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: true },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+        wgpu::BindGroupLayoutEntry {
+            binding: 1,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+    ],
+};
+
+impl<E1: WebgpuNativeType + AsPrimitive<E2>, E2: WebgpuNativeType> super::ToDtypeKernel<E1, E2>
+    for Webgpu
+{
     fn forward<S: crate::prelude::Shape>(
         inp: crate::prelude::Tensor<S, E1, Self>,
     ) -> Result<crate::prelude::Tensor<S, E2, Self>, crate::prelude::Error> {
-        todo!()
+        let module_name = std::format!("convert_{}_to_{}", E1::NAME, E2::NAME);
+        let device = inp.device;
+
+        let layout = device.dev.create_bind_group_layout(&LAYOUT_DESC);
+        let shader_source: String = KERNEL
+            .replace("__SRC__", E1::NAME)
+            .replace("__DST__", E2::NAME);
+
+        let pipeline = device.load_pipeline(
+            module_name.as_str(),
+            shader_source.as_str(),
+            "main",
+            &[&layout],
+            &[],
+        );
+
+        let numel = inp.shape.num_elements();
+        let work_groups: (u32, u32, u32) = (numel as u32, 1, 1);
+        let shape = inp.shape;
+        let strides = shape.strides();
+        let output = unsafe { device.alloc_empty::<E2>(numel) }?;
+
+        let params: wgpu::BindGroup = webgpu_params!(device, pipeline; inp.data, output);
+
+        let _idx =
+            device.submit_basic_op(&pipeline, &params, Some(module_name.as_str()), &work_groups);
+
+        Ok(device.build_tensor(shape, strides, output))
     }
 }
diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 388eea7f..d7595183 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -131,11 +131,11 @@ impl Device<f32> for crate::tensor::Cuda {}
 #[cfg(feature = "cuda")]
 impl Device<f64> for crate::tensor::Cuda {}
 
-#[cfg(all(feature = "webgpu", feature = "f16"))]
-impl Device<f16> for crate::tensor::Webgpu {}
-#[cfg(all(feature = "webgpu", feature = "f16"))]
-impl Device<AMP<f16>> for crate::tensor::Webgpu {}
-#[cfg(feature = "webgpu")]
-impl Device<f32> for crate::tensor::Webgpu {}
-#[cfg(feature = "webgpu")]
-impl Device<f64> for crate::tensor::Webgpu {}
+// #[cfg(all(feature = "webgpu", feature = "f16"))]
+// impl Device<f16> for crate::tensor::Webgpu {}
+// #[cfg(all(feature = "webgpu", feature = "f16"))]
+// impl Device<AMP<f16>> for crate::tensor::Webgpu {}
+// #[cfg(feature = "webgpu")]
+// impl Device<f32> for crate::tensor::Webgpu {}
+// #[cfg(feature = "webgpu")]
+// impl Device<f64> for crate::tensor::Webgpu {}
diff --git a/dfdx-core/src/tensor_ops/utilities/mod.rs b/dfdx-core/src/tensor_ops/utilities/mod.rs
index e23565bc..02db7e2d 100644
--- a/dfdx-core/src/tensor_ops/utilities/mod.rs
+++ b/dfdx-core/src/tensor_ops/utilities/mod.rs
@@ -2,6 +2,9 @@ mod backward;
 pub(crate) mod cpu_kernels;
 #[cfg(feature = "cuda")]
 pub(crate) mod cuda_kernels;
+#[cfg(feature = "webgpu")]
+pub(crate) mod webgpu_kernels;
+
 mod device;
 pub(crate) mod ops;
 pub(crate) mod reduction_utils;
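A quick sketch of what the `ToDtypeKernel` implementation above enables on the `Webgpu` device, assuming two `WebgpuNativeType`s such as `f32` and `u32` (the WGSL `U(in[i])` conversion truncates floats):

    use dfdx_core::prelude::*;

    fn cast(dev: &Webgpu) {
        let a = dev.tensor([1.9f32, 2.1, 3.5]);
        // runs the generated "convert_f32_to_u32" pipeline
        let b = a.to_dtype::<u32>();
        assert_eq!(b.as_vec(), [1, 2, 3]);
    }
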
diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
new file mode 100644
index 00000000..ab5cf767
--- /dev/null
+++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
@@ -0,0 +1,230 @@
+extern crate alloc;
+
+use crate::{
+    shapes::{Dtype, Shape},
+    tensor::{webgpu::Webgpu, Error, Tensor},
+    tensor_ops::ops::{BinaryKernel, UnaryKernel},
+};
+// FIXME: nostd support
+use std::ops::Range;
+
+use alloc::{borrow::Cow, sync::Arc};
+
+/// Creates a [`wgpu::BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s.
+macro_rules! webgpu_params {
+    ($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => {
+        {
+            let bindings = [$($x.as_entire_binding()),+];
+            let entries: Vec<_> = bindings
+                .into_iter()
+                .enumerate()
+                .map(|(i, binding)| wgpu::BindGroupEntry {
+                    binding: i as u32,
+                    resource: binding,
+                })
+                .collect();
+            $self.dev.create_bind_group(&::wgpu::BindGroupDescriptor {
+                label: None,
+                layout: &($pipeline).get_bind_group_layout(0),
+                entries: &entries,
+            })
+        }
+    }
+}
+pub(crate) use webgpu_params;
+
+pub trait UnaryOpWebgpuKernel<E> {
+    const DF_USES_FX: bool;
+    const HAS_CONST_DF: bool;
+
+    /// WGSL source code for the kernel
+    const WGSL_SRC: &'static str;
+
+    /// Unique name for the kernel
+    const MODULE_NAME: &'static str;
+
+    /// Name of function in the .wgsl file (used as entrypoint)
+    const FWD_FN_NAME: &'static str;
+
+    /// Name of function in the .wgsl file (used as entrypoint)
+    const BWD_FN_NAME: &'static str;
+
+    const ALL_FN_NAMES: [&'static str; 2] = [Self::FWD_FN_NAME, Self::BWD_FN_NAME];
+
+    /// Extra parameters to pass to the kernel.
+    fn params(&self) -> Option<&[Range<u32>]> {
+        None
+    }
+}
+
+macro_rules! webgpu_unary {
+    ($Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => {
+        impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op {
+            const DF_USES_FX: bool = false;
+            const HAS_CONST_DF: bool = false;
+            const WGSL_SRC: &'static str = include_str!($Wgsl);
+            const MODULE_NAME: &'static str = $Fwd;
+            const FWD_FN_NAME: &'static str = $Fwd;
+            const BWD_FN_NAME: &'static str = $Bwd;
+        }
+    };
+    (df(f(x)) $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => {
+        impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op {
+            const DF_USES_FX: bool = true;
+            const HAS_CONST_DF: bool = false;
+            const WGSL_SRC: &'static str = include_str!($Wgsl);
+            const MODULE_NAME: &'static str = $Fwd;
+            const FWD_FN_NAME: &'static str = $Fwd;
+            const BWD_FN_NAME: &'static str = $Bwd;
+        }
+    };
+    (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Fwd:tt, $Bwd:tt) => {
+        impl crate::tensor_ops::webgpu_kernels::UnaryOpWebgpuKernel<$TypeName> for $Op {
+            const DF_USES_FX: bool = false;
+            const HAS_CONST_DF: bool = true;
+            const WGSL_SRC: &'static str = include_str!($Wgsl);
+            const MODULE_NAME: &'static str = $Fwd;
+            const FWD_FN_NAME: &'static str = $Fwd;
+            const BWD_FN_NAME: &'static str = $Bwd;
+        }
+    };
+}
+pub(crate) use webgpu_unary;
+
+impl<E: Dtype, K: UnaryOpWebgpuKernel<E>> UnaryKernel<K, E> for Webgpu {
+    const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;
+    const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF;
+
+    fn forward<S: Shape>(
+        &self,
+        op: K,
+        inp: Cow<Tensor<S, E, Self>>,
+    ) -> Result<Tensor<S, E, Self>, Error> {
+        let shape = inp.shape;
+        let strides = shape.strides();
+        let numel = shape.num_elements();
+        // todo: dream about memory64
+        // https://github.com/WebAssembly/memory64
+        let work_groups: (u32, u32, u32) = (numel as u32, 1, 1);
+        todo!("Webgpu unary forwards")
+    }
+
+    fn backward<S: Shape>(
+        &self,
+        op: K,
+        inp: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_inp: &mut Self::Vec,
+        out: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!("Webgpu unary backwards")
+    }
+}
+
+pub trait BinaryOpWebgpuKernel<E> {
+    const HAS_CONST_DF: bool;
+
+    /// WGSL source code for the kernel
+    const WGSL_SRC: &'static str;
+
+    /// Unique name for the kernel
+    const MODULE_NAME: &'static str;
+
+    /// Name of function in the .wgsl file
+    const FWD_FN_NAME: &'static str;
+
+    /// Name of function in the .wgsl file
+    const BWD_LHS_FN_NAME: &'static str;
+
+    /// Name of function in the .wgsl file
+    const BWD_RHS_FN_NAME: &'static str;
+
+    const ALL_FN_NAMES: [&'static str; 3] = [
+        Self::FWD_FN_NAME,
+        Self::BWD_LHS_FN_NAME,
+        Self::BWD_RHS_FN_NAME,
+    ];
+
+    /// Extra parameters to pass to the kernel.
+    fn params(&self) -> Option<&[Range<u32>]> {
+        None
+    }
+}
+
+macro_rules! wgpu_binary {
+    ($Op:path, $TypeName:ty, $Wgsl:tt, $Mod:tt, $Fwd:tt, $Bwd_Lhs:tt, $Bwd_Rhs:tt) => {
+        impl crate::tensor_ops::webgpu_kernels::BinaryOpWebgpuKernel<$TypeName> for $Op {
+            const HAS_CONST_DF: bool = false;
+            const WGSL_SRC: &'static str = $Wgsl;
+            const MODULE_NAME: &'static str = $Mod;
+            const FWD_FN_NAME: &'static str = $Fwd;
+            const BWD_LHS_FN_NAME: &'static str = $Bwd_Lhs;
+            const BWD_RHS_FN_NAME: &'static str = $Bwd_Rhs;
+        }
+    };
+    (const_df() $Op:path, $TypeName:ty, $Wgsl:tt, $Mod:tt, $Fwd:tt, $Bwd_Lhs:tt, $Bwd_Rhs:tt) => {
+        impl crate::tensor_ops::webgpu_kernels::BinaryOpWebgpuKernel<$TypeName> for $Op {
+            const HAS_CONST_DF: bool = true;
+            const WGSL_SRC: &'static str = $Wgsl;
+            const MODULE_NAME: &'static str = $Mod;
+            const FWD_FN_NAME: &'static str = $Fwd;
+            const BWD_LHS_FN_NAME: &'static str = $Bwd_Lhs;
+            const BWD_RHS_FN_NAME: &'static str = $Bwd_Rhs;
+        }
+    };
+}
+
+pub(crate) use wgpu_binary;
+
+impl<E: Dtype, K: BinaryOpWebgpuKernel<E> + Clone> BinaryKernel<K, E> for Webgpu {
+    const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF;
+
+    fn forward<S: Shape>(
+        &self,
+        op: K,
+        lhs: Cow<Tensor<S, E, Self>>,
+        rhs: Cow<Tensor<S, E, Self>>,
+    ) -> Result<Tensor<S, E, Self>, Error> {
+        let shape = lhs.shape;
+        let strides = shape.strides();
+        let numel = shape.num_elements();
+        // todo: dream about memory64
+        // https://github.com/WebAssembly/memory64
+        let work_groups: (u32, u32, u32) = (numel as u32, 1, 1);
+
+        let push_constants = op.params().unwrap_or(&[]);
+        // todo: pipeline caching
+        let fwd_pipeline =
+            self.load_binary_pipeline(K::MODULE_NAME, K::WGSL_SRC, K::FWD_FN_NAME, push_constants);
+
+        let output = unsafe { self.alloc_empty::<E>(numel) }?;
+
+        // note: storage buffers cannot be both read from and written to within
+        // the same pipeline stage, so Cow doesn't change operation behavior.
+        {
+            let lhs: &Tensor<S, E, Self> = lhs.as_ref();
+            let rhs: &Tensor<S, E, Self> = rhs.as_ref();
+            let params: wgpu::BindGroup =
+                webgpu_params!(self, fwd_pipeline; lhs.data, rhs.data, output);
+            let _idx =
+                self.submit_basic_op(&fwd_pipeline, &params, Some(K::FWD_FN_NAME), &work_groups);
+        }
+        Ok(self.build_tensor(shape, strides, output))
+    }
+
+    fn backward<S: Shape>(
+        &self,
+        op: K,
+        lhs: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &impl crate::prelude::Tensorlike<S, E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!("Webgpu binary backwards")
+    }
+}
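
Taken together, the forward path above is what a plain tensor add exercises end to end. A sketch mirroring the tests in add/webgpu_kernel.rs:

    use dfdx_core::prelude::*;

    fn add_on_gpu() {
        let dev: Webgpu = Default::default();
        let a = dev.tensor([1.0f32, 2.0, 3.0]);
        let b = dev.tensor([4.0f32, 5.0, 6.0]);
        // dispatches the badd_fwd_f32 entry point, one invocation per element
        let c = a + b;
        assert_eq!(c.as_vec(), [5.0, 7.0, 9.0]);
    }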