From 7106dec57901c3e580ed130589fbe0965b5db3da Mon Sep 17 00:00:00 2001
From: Dylan Lim
Date: Tue, 15 Oct 2024 20:24:27 -0700
Subject: [PATCH] #1409 issue, change datatype for linear kernels away from void *

---
 lib/kernels/include/kernels/linear_kernels.h | 22 +++---
 lib/kernels/src/cuda/ops/linear_kernels.cu   | 76 +++++++++++---------
 lib/local-execution/src/ops/linear.cc        | 14 ++--
 3 files changed, 59 insertions(+), 53 deletions(-)

diff --git a/lib/kernels/include/kernels/linear_kernels.h b/lib/kernels/include/kernels/linear_kernels.h
index 99549adece..cff6563629 100644
--- a/lib/kernels/include/kernels/linear_kernels.h
+++ b/lib/kernels/include/kernels/linear_kernels.h
@@ -50,23 +50,23 @@ bool use_activation(Activation activation);
 
 void forward_kernel(ffStream_t stream,
                     LinearPerDeviceState const &m,
-                    void const *input_ptr,
-                    void *output_ptr,
-                    void const *filter_ptr,
-                    void const *bias_ptr,
+                    float const *input_ptr,
+                    float *output_ptr,
+                    float const *filter_ptr,
+                    float const *bias_ptr,
                     int in_dim,
                     int out_dim,
                     int batch_size);
 
 void backward_kernel(ffStream_t stream,
                      LinearPerDeviceState const &m,
-                     void const *input_ptr,
-                     void *input_grad_ptr,
-                     void const *output_ptr,
-                     void *output_grad_ptr,
-                     void const *kernel_ptr,
-                     void *kernel_grad_ptr,
-                     void *bias_ptr,
+                     float const *input_ptr,
+                     float *input_grad_ptr,
+                     float const *output_ptr,
+                     float *output_grad_ptr,
+                     float const *kernel_ptr,
+                     float *kernel_grad_ptr,
+                     float *bias_ptr,
                      int in_dim,
                      int out_dim,
                      int batch_size);
diff --git a/lib/kernels/src/cuda/ops/linear_kernels.cu b/lib/kernels/src/cuda/ops/linear_kernels.cu
index ca51f0d216..29b77fd9d9 100644
--- a/lib/kernels/src/cuda/ops/linear_kernels.cu
+++ b/lib/kernels/src/cuda/ops/linear_kernels.cu
@@ -108,10 +108,10 @@ LinearPerDeviceState init_kernel(PerDeviceFFHandle handle,
 
 void forward_kernel(cudaStream_t stream,
                     LinearPerDeviceState const &m,
-                    void const *input_ptr,
-                    void *output_ptr,
-                    void const *weight_ptr,
-                    void const *bias_ptr,
+                    float const *input_ptr,
+                    float *output_ptr,
+                    float const *weight_ptr,
+                    float const *bias_ptr,
                     int in_dim,
                     int out_dim,
                     int batch_size) {
@@ -135,14 +135,14 @@ void forward_kernel(cudaStream_t stream,
                          batch_size,
                          in_dim,
                          &alpha,
-                         weight_ptr,
+                         (void *)weight_ptr,
                          weight_type,
                          in_dim,
-                         input_ptr,
+                         (void *)input_ptr,
                          input_type,
                          in_dim,
                          &beta,
-                         output_ptr,
+                         (void *)output_ptr,
                          output_type,
                          out_dim,
                          compute_type,
@@ -156,14 +156,14 @@ void forward_kernel(cudaStream_t stream,
                            batch_size,
                            1,
                            &alpha,
-                           bias_ptr,
+                           (void *)bias_ptr,
                            weight_type,
                            1,
-                           m.one_ptr,
+                           (void *)m.one_ptr,
                            CUDA_R_32F,
                            1,
                            &alpha,
-                           output_ptr,
+                           (void *)output_ptr,
                            output_type,
                            out_dim,
                            compute_type,
@@ -174,10 +174,10 @@ void forward_kernel(cudaStream_t stream,
                                    m.actiDesc,
                                    &alpha,
                                    m.outputTensor,
-                                   output_ptr,
+                                   (void *)output_ptr,
                                    &beta,
                                    m.outputTensor,
-                                   output_ptr));
+                                   (void *)output_ptr));
   } else if (m.activation == Activation::GELU) {
     size_t elements = size_t_from_int(out_dim) * size_t_from_int(batch_size);
     constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI)
@@ -191,13 +191,13 @@ void forward_kernel(cudaStream_t stream,
 
 void backward_kernel(cudaStream_t stream,
                      LinearPerDeviceState const &m,
-                     void const *input_ptr,
-                     void *input_grad_ptr,
-                     void const *output_ptr,
-                     void *output_grad_ptr,
-                     void const *kernel_ptr,
-                     void *kernel_grad_ptr,
-                     void *bias_grad_ptr,
+                     float const *input_ptr,
+                     float *input_grad_ptr,
+                     float const *output_ptr,
+                     float *output_grad_ptr,
+                     float const *kernel_ptr,
+                     float *kernel_grad_ptr,
+                     float *bias_grad_ptr,
                      int in_dim,
                      int out_dim,
                      int batch_size) {
@@ -216,11 +216,17 @@ void backward_kernel(cudaStream_t stream,
   int output_size = out_dim * batch_size;
   if (m.activation.has_value()) {
     if (m.activation == Activation::RELU) {
-      relu_backward_kernel(
-          m.output_type, output_grad_ptr, output_ptr, output_size, stream);
+      relu_backward_kernel(m.output_type,
+                           (void *)output_grad_ptr,
+                           (void *)output_ptr,
+                           output_size,
+                           stream);
     } else if (m.activation == Activation::SIGMOID) {
-      sigmoid_backward_kernel(
-          m.output_type, output_grad_ptr, output_ptr, output_size, stream);
+      sigmoid_backward_kernel(m.output_type,
+                              (void *)output_grad_ptr,
+                              (void *)output_ptr,
+                              output_size,
+                              stream);
     } else {
       // TODO: only support relu and sigmoid for now
       assert(false && "Unsupported activation for Linear");
@@ -235,14 +241,14 @@ void backward_kernel(cudaStream_t stream,
                          out_dim,
                          batch_size,
                          &alpha,
-                         input_ptr,
+                         (void *)input_ptr,
                          input_type,
                          in_dim,
-                         output_grad_ptr,
+                         (void *)output_grad_ptr,
                          output_type,
                          out_dim,
                          &alpha,
-                         kernel_grad_ptr,
+                         (void *)kernel_grad_ptr,
                          weight_type,
                          in_dim,
                          compute_type,
@@ -261,12 +267,12 @@ void backward_kernel(cudaStream_t stream,
                            in_dim,
                            out_dim,
                            &alpha,
-                           (float *)kernel_grad_ptr,
+                           kernel_grad_ptr,
                            in_dim,
                            &lambda,
-                           (float *)kernel_ptr,
+                           kernel_ptr,
                            in_dim,
-                           (float *)kernel_grad_ptr,
+                           kernel_grad_ptr,
                            in_dim));
     } else {
       assert(false && "Only L2 regularization is supported");
@@ -284,14 +290,14 @@ void backward_kernel(cudaStream_t stream,
                          out_dim,
                          batch_size,
                          &alpha,
-                         m.one_ptr,
+                         (void *)m.one_ptr,
                          CUDA_R_32F,
                          1,
-                         output_grad_ptr,
+                         (void *)output_grad_ptr,
                          output_type,
                          out_dim,
                          &alpha,
-                         bias_grad_ptr,
+                         (void *)bias_grad_ptr,
                          weight_type,
                          1,
                          compute_type,
@@ -307,14 +313,14 @@ void backward_kernel(cudaStream_t stream,
                          batch_size,
                          out_dim,
                          &alpha,
-                         kernel_ptr,
+                         (void *)kernel_ptr,
                          weight_type,
                          in_dim,
-                         output_grad_ptr,
+                         (void *)output_grad_ptr,
                          output_type,
                          out_dim,
                          &alpha,
-                         input_grad_ptr,
+                         (void *)input_grad_ptr,
                          input_type,
                          in_dim,
                          compute_type,
diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc
index 9934e2a45c..860eedaa1c 100644
--- a/lib/local-execution/src/ops/linear.cc
+++ b/lib/local-execution/src/ops/linear.cc
@@ -148,13 +148,13 @@ static std::optional
                  profiling,
                  "[Linear] backward_time = {:.2lf}ms\n",
                  per_device_state,
-                 (void *)input.get_float_ptr(),
-                 (void *)input_grad.get_float_ptr(),
-                 (void *)output.get_float_ptr(),
-                 (void *)output_grad.get_float_ptr(),
-                 (void *)weight.get_float_ptr(),
-                 (void *)weight_grad.get_float_ptr(),
-                 (void *)bias_ptr,
+                 input.get_float_ptr(),
+                 (float *)input_grad.get_float_ptr(),
+                 output.get_float_ptr(),
+                 (float *)output_grad.get_float_ptr(),
+                 weight.get_float_ptr(),
+                 (float *)weight_grad.get_float_ptr(),
+                 (float *)bias_ptr,
                  in_dim,
                  out_dim,
                  batch_size);