diff --git a/lib/kernels/src/hip/layer_norm_kernels.cpp b/lib/kernels/src/hip/layer_norm_kernels.cpp index dc2685ef28..247e8f3785 100644 --- a/lib/kernels/src/hip/layer_norm_kernels.cpp +++ b/lib/kernels/src/hip/layer_norm_kernels.cpp @@ -14,7 +14,8 @@ */ #include "kernels/layer_norm_kernels.h" -#include "kernels/hip_helper.h" +#include "kernels/accessor.h" +#include "kernels/datatype_dispatch.h" #include namespace FlexFlow { @@ -24,57 +25,330 @@ constexpr int kCUDABlockReduceNumThreads = 512; constexpr int kCUDANumThreads = 256; constexpr int kColwiseReduceTileSize = 32; -LayerNormPerDeviceState::LayerNormPerDeviceState( - FFHandler handle, - bool elementwise_affine_, - int64_t effective_batch_size_, - int64_t effective_num_elements_, - bool profiling_, - float eps_) - : PerDeviceOpState(handle) { - elementwise_affine = elementwise_affine_; - effective_batch_size = effective_batch_size_; - effective_num_elements = effective_num_elements_; - profiling = profiling_; - eps = eps_; - checkCUDA(hipMalloc(&mean_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&rstd_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&ds_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&db_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&scale_ptr, sizeof(float) * effective_batch_size)); - checkCUDA(hipMalloc(&bias_ptr, sizeof(float) * effective_batch_size)); -} - namespace Kernels { namespace LayerNorm { +template +__device__ __forceinline__ T WARP_SHFL_DOWN(T value, + unsigned int delta, + int width = warpSize, + unsigned int mask = 0xffffffff) { +#ifndef __HIP_PLATFORM_HCC__ + return __shfl_down_sync(mask, value, delta, width); +#else + return __shfl_down(value, delta, width); +#endif +} + +template +__inline__ __device__ T WarpReduceSum(T val) { +#pragma unroll + for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN(val, offset); + } + return val; +} + +template +__inline__ __device__ T BlockReduceSum(T val, T *shared) { + int const lid = threadIdx.x % C10_WARP_SIZE; + int const wid = threadIdx.x / C10_WARP_SIZE; + val = WarpReduceSum(val); + __syncthreads(); + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : 0; + if (wid == 0) { + val = WarpReduceSum(val); + } + return val; +} + +template +__global__ void + RowwiseMomentsCUDAKernel(int64_t N, T eps, T const *X, T *mean, T *rstd) { + __shared__ T m_shared[C10_WARP_SIZE]; + __shared__ T v_shared[C10_WARP_SIZE]; + const int64_t i = blockIdx.x; + T sum1 = 0; + T sum2 = 0; + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + sum1 += static_cast(X[index]); + sum2 += static_cast(X[index]) * static_cast(X[index]); + } + sum1 = BlockReduceSum(sum1, m_shared); + sum2 = BlockReduceSum(sum2, v_shared); + if (threadIdx.x == 0) { + const T scale = T(1) / static_cast(N); + sum1 *= scale; + sum2 = max(sum2 * scale - sum1 * sum1, T(0)); + mean[i] = sum1; + rstd[i] = rsqrt(sum2 + static_cast(eps)); + } +} + +template +__global__ void LayerNormForwardCUDAKernel(int64_t N, + T const *X, + T const *mean, + T const *rstd, + T const *gamma, + T const *beta, + T *Y) { + using T_ACC = T; + const int64_t i = blockIdx.x; + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + const T_ACC gamma_v = + gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; + } +} + +template +__global__ void ComputeInternalGradientsCUDAKernel( + int64_t N, T const *dY, T const *X, T const *gamma, T *ds, T *db) { + using T_ACC = T; + __shared__ T_ACC ds_shared[C10_WARP_SIZE]; + __shared__ T_ACC db_shared[C10_WARP_SIZE]; + const int64_t i = blockIdx.x; + T_ACC sum1 = 0; + T_ACC sum2 = 0; + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + const T_ACC gamma_v = + gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); + sum1 += + static_cast(dY[index]) * static_cast(X[index]) * gamma_v; + sum2 += static_cast(dY[index]) * gamma_v; + } + sum1 = BlockReduceSum(sum1, ds_shared); + sum2 = BlockReduceSum(sum2, db_shared); + if (threadIdx.x == 0) { + ds[i] = sum1; + db[i] = sum2; + } +} + +template +__global__ void ComputeGradientFusedParamsCUDAKernel(int64_t M, + int64_t N, + T const *mean, + T const *rstd, + T const *ds, + T const *db, + T *c1, + T *c2) { + using T_ACC = T; + const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < M) { + const T_ACC s = T_ACC(1) / static_cast(N); + const T_ACC a = (db[index] * static_cast(mean[index]) - ds[index]) * + static_cast(rstd[index]) * + static_cast(rstd[index]) * + static_cast(rstd[index]) * s; + c1[index] = a; + c2[index] = -(a * static_cast(mean[index]) + + db[index] * static_cast(rstd[index]) * s); + } +} + +template +__global__ void LayerNormBackwardCUDAKenrel(int64_t N, + T const *dY, + T const *X, + T const *gamma, + T const *a, + T const *b, + T const *c, + T *dX) { + using T_ACC = T; + const int64_t i = blockIdx.x; + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { + const int64_t index = i * N + j; + const T_ACC gamma_v = + gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); + dX[index] = + static_cast(a[i]) * static_cast(dY[index]) * gamma_v + + b[i] * static_cast(X[index]) + c[i]; + } +} + +template +__global__ void GammaBetaBackwardSimpleCUDAKernel(int64_t M, + int64_t N, + T const *dY, + T const *X, + T const *mean, + T const *rstd, + T *dg, + T *db) { + using T_ACC = T; + const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + if (j < N) { + T_ACC sum1 = 0; + T_ACC sum2 = 0; + for (int64_t i = 0; i < M; ++i) { + const int64_t index = i * N + j; + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - + static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + } + if (dg != nullptr) { + dg[j] = sum1; + } + if (db != nullptr) { + db[j] = sum2; + } + } +} + +template +__global__ void GammaBetaBackwardCUDAKernel(int64_t M, + int64_t N, + T const *dY, + T const *X, + T const *mean, + T const *rstd, + T *dg, + T *db) { + using T_ACC = T; + __shared__ T_ACC g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; + __shared__ T_ACC b_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; + const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + T_ACC dg_sum1 = 0; + T_ACC dg_sum2 = 0; + T_ACC db_sum1 = 0; + T_ACC db_sum2 = 0; + if (j < N) { + for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) { + const int64_t i1 = i; + const int64_t i2 = i + blockDim.y; + const int64_t index1 = i1 * N + j; + const int64_t index2 = i2 * N + j; + dg_sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index1]) * + (static_cast(X[index1]) - + static_cast(mean[i1])) * + static_cast(rstd[i1]); + db_sum1 += db == nullptr ? T_ACC(0) : static_cast(dY[index1]); + if (i2 < M) { + dg_sum2 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index2]) * + (static_cast(X[index2]) - + static_cast(mean[i2])) * + static_cast(rstd[i2]); + db_sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index2]); + } + } + } + g_shared[threadIdx.y][threadIdx.x] = dg_sum1; + g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2; + b_shared[threadIdx.y][threadIdx.x] = db_sum1; + b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2; + __syncthreads(); + T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y]; + T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y]; + sum1 = WarpReduceSum(sum1); + sum2 = WarpReduceSum(sum2); + if (threadIdx.x == 0) { + const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; + if (j < N) { + if (dg != nullptr) { + dg[j] = sum1; + } + if (db != nullptr) { + db[j] = sum2; + } + } + } + sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; + sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; + sum1 = WarpReduceSum(sum1); + sum2 = WarpReduceSum(sum2); + if (threadIdx.x == 0) { + const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; + if (j < N) { + if (dg != nullptr) { + dg[j] = sum1; + } + if (db != nullptr) { + db[j] = sum2; + } + } + } +} + +LayerNormPerDeviceState init_kernel(PerDeviceFFHandle const &handle, + Allocator const &allocator, + bool elementwise_affine_, + int64_t effective_batch_size_, + int64_t effective_num_elements_, + float eps_) { + float *mean = + (float *)allocator.allocate(sizeof(float) * effective_batch_size_); + float *rstd = + (float *)allocator.allocate(sizeof(float) * effective_batch_size_); + float *ds = + (float *)allocator.allocate(sizeof(float) * effective_batch_size_); + float *db = + (float *)allocator.allocate(sizeof(float) * effective_batch_size_); + float *scale = + (float *)allocator.allocate(sizeof(float) * effective_batch_size_); + float *bias = + (float *)allocator.allocate(sizeof(float) * effective_batch_size_); + LayerNormPerDeviceState per_device_state = {handle, + elementwise_affine_, + effective_batch_size_, + effective_num_elements_, + eps_, + mean, + rstd, + ds, + db, + scale, + bias, + DataType::FLOAT}; + return per_device_state; +} + template struct ForwardKernel { void operator()(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, GenericTensorAccessorW const &beta) { hipLaunchKernelGGL(HIP_KERNEL_NAME(RowwiseMomentsCUDAKernel), - m->effective_batch_size, + m.effective_batch_size, kCUDABlockReduceNumThreads, 0, stream, - m->effective_num_elements, - m->eps, + m.effective_num_elements, + m.eps, input.get(), - m->mean_ptr, - m->rstd_ptr); + m.mean_ptr, + m.rstd_ptr); hipLaunchKernelGGL(HIP_KERNEL_NAME(LayerNormForwardCUDAKernel), - m->effective_batch_size, + m.effective_batch_size, kCUDANumThreads, 0, stream, - m->effective_num_elements, + m.effective_num_elements, input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean_ptr, + m.rstd_ptr, gamma.get(), beta.get(), output.get()); @@ -84,15 +358,15 @@ struct ForwardKernel { template struct BackwardKernel { void operator()(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, GenericTensorAccessorW const &beta_grad) { - const int64_t M = m->effective_batch_size; - const int64_t N = m->effective_num_elements; + const int64_t M = m.effective_batch_size; + const int64_t N = m.effective_num_elements; hipLaunchKernelGGL(HIP_KERNEL_NAME(ComputeInternalGradientsCUDAKernel), M, kCUDABlockReduceNumThreads, @@ -102,8 +376,8 @@ struct BackwardKernel { output_grad.get(), input.get(), gamma.get(), - m->ds_ptr, - m->db_ptr); + m.ds_ptr, + m.db_ptr); const int64_t B = (M + kCUDANumThreads - 1) / kCUDANumThreads; hipLaunchKernelGGL(HIP_KERNEL_NAME(ComputeGradientFusedParamsCUDAKernel), B, @@ -112,12 +386,12 @@ struct BackwardKernel { stream, M, N, - m->mean_ptr, - m->rstd_ptr, - m->ds_ptr, - m->db_ptr, - m->scale_ptr, - m->bias_ptr); + m.mean_ptr, + m.rstd_ptr, + m.ds_ptr, + m.db_ptr, + m.scale_ptr, + m.bias_ptr); if (gamma_grad.get() != NULL || beta_grad.get() != NULL) { if (M < 512) { // For small batch size, do colwise reduce directly @@ -132,8 +406,8 @@ struct BackwardKernel { N, output_grad.get(), input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean_ptr, + m.rstd_ptr, gamma_grad.get(), beta_grad.get()); } else { @@ -150,8 +424,8 @@ struct BackwardKernel { N, output_grad.get(), input.get(), - m->mean_ptr, - m->rstd_ptr, + m.mean_ptr, + m.rstd_ptr, gamma_grad.get(), beta_grad.get()); } @@ -159,24 +433,24 @@ struct BackwardKernel { } void forward_kernel(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &gamma, GenericTensorAccessorW const &beta) { DataTypeDispatch1{}( - m->data_type, stream, m, input, output, gamma, beta); + m.data_type, stream, m, input, output, gamma, beta); } void backward_kernel(hipStream_t stream, - LayerNormPerDeviceState const *m, + LayerNormPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &gamma, GenericTensorAccessorW const &gamma_grad, GenericTensorAccessorW const &beta_grad) { - DataTypeDispatch1{}(m->data_type, + DataTypeDispatch1{}(m.data_type, stream, m, output_grad, @@ -187,271 +461,6 @@ struct BackwardKernel { beta_grad); } - template - __device__ __forceinline__ T WARP_SHFL_DOWN(T value, - unsigned int delta, - int width = warpSize, - unsigned int mask = 0xffffffff) { -#if 0 -#ifndef __HIP_PLATFORM_HCC__ - return __shfl_down_sync(mask, value, delta, width); -#else - return __shfl_down(value, delta, width); -#endif -#endif - } - - template - __inline__ __device__ T WarpReduceSum(T val) { -#pragma unroll - for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - val += WARP_SHFL_DOWN(val, offset); - } - return val; - } - - template - __inline__ __device__ T BlockReduceSum(T val, T *shared) { - int const lid = threadIdx.x % C10_WARP_SIZE; - int const wid = threadIdx.x / C10_WARP_SIZE; - val = WarpReduceSum(val); - __syncthreads(); - if (lid == 0) { - shared[wid] = val; - } - __syncthreads(); - val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid] : 0; - if (wid == 0) { - val = WarpReduceSum(val); - } - return val; - } - - template - __global__ void - RowwiseMomentsCUDAKernel(int64_t N, T eps, T const *X, T *mean, T *rstd) { - __shared__ T m_shared[C10_WARP_SIZE]; - __shared__ T v_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; - T sum1 = 0; - T sum2 = 0; - for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; - sum1 += static_cast(X[index]); - sum2 += static_cast(X[index]) * static_cast(X[index]); - } - sum1 = BlockReduceSum(sum1, m_shared); - sum2 = BlockReduceSum(sum2, v_shared); - if (threadIdx.x == 0) { - const T scale = T(1) / static_cast(N); - sum1 *= scale; - sum2 = max(sum2 * scale - sum1 * sum1, T(0)); - mean[i] = sum1; - rstd[i] = rsqrt(sum2 + static_cast(eps)); - } - } - - template - __global__ void LayerNormForwardCUDAKernel(int64_t N, - T const *X, - T const *mean, - T const *rstd, - T const *gamma, - T const *beta, - T *Y) { - using T_ACC = T; - const int64_t i = blockIdx.x; - for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; - const T_ACC gamma_v = - gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; - } - } - - template - __global__ void ComputeInternalGradientsCUDAKernel( - int64_t N, T const *dY, T const *X, T const *gamma, T *ds, T *db) { - using T_ACC = T; - __shared__ T_ACC ds_shared[C10_WARP_SIZE]; - __shared__ T_ACC db_shared[C10_WARP_SIZE]; - const int64_t i = blockIdx.x; - T_ACC sum1 = 0; - T_ACC sum2 = 0; - for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; - const T_ACC gamma_v = - gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - sum1 += static_cast(dY[index]) * static_cast(X[index]) * - gamma_v; - sum2 += static_cast(dY[index]) * gamma_v; - } - sum1 = BlockReduceSum(sum1, ds_shared); - sum2 = BlockReduceSum(sum2, db_shared); - if (threadIdx.x == 0) { - ds[i] = sum1; - db[i] = sum2; - } - } - - template - __global__ void ComputeGradientFusedParamsCUDAKernel(int64_t M, - int64_t N, - T const *mean, - T const *rstd, - T const *ds, - T const *db, - T *c1, - T *c2) { - using T_ACC = T; - const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < M) { - const T_ACC s = T_ACC(1) / static_cast(N); - const T_ACC a = - (db[index] * static_cast(mean[index]) - ds[index]) * - static_cast(rstd[index]) * static_cast(rstd[index]) * - static_cast(rstd[index]) * s; - c1[index] = a; - c2[index] = -(a * static_cast(mean[index]) + - db[index] * static_cast(rstd[index]) * s); - } - } - - template - __global__ void LayerNormBackwardCUDAKenrel(int64_t N, - T const *dY, - T const *X, - T const *gamma, - T const *a, - T const *b, - T const *c, - T *dX) { - using T_ACC = T; - const int64_t i = blockIdx.x; - for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; - const T_ACC gamma_v = - gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - dX[index] = - static_cast(a[i]) * static_cast(dY[index]) * gamma_v + - b[i] * static_cast(X[index]) + c[i]; - } - } - - template - __global__ void GammaBetaBackwardSimpleCUDAKernel(int64_t M, - int64_t N, - T const *dY, - T const *X, - T const *mean, - T const *rstd, - T *dg, - T *db) { - using T_ACC = T; - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; - if (j < N) { - T_ACC sum1 = 0; - T_ACC sum2 = 0; - for (int64_t i = 0; i < M; ++i) { - const int64_t index = i * N + j; - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); - } - if (dg != nullptr) { - dg[j] = sum1; - } - if (db != nullptr) { - db[j] = sum2; - } - } - } - - template - __global__ void GammaBetaBackwardCUDAKernel(int64_t M, - int64_t N, - T const *dY, - T const *X, - T const *mean, - T const *rstd, - T *dg, - T *db) { - using T_ACC = T; - __shared__ T_ACC - g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; - __shared__ T_ACC - b_shared[kColwiseReduceTileSize][kColwiseReduceTileSize + 1]; - const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; - T_ACC dg_sum1 = 0; - T_ACC dg_sum2 = 0; - T_ACC db_sum1 = 0; - T_ACC db_sum2 = 0; - if (j < N) { - for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) { - const int64_t i1 = i; - const int64_t i2 = i + blockDim.y; - const int64_t index1 = i1 * N + j; - const int64_t index2 = i2 * N + j; - dg_sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index1]) * - (static_cast(X[index1]) - - static_cast(mean[i1])) * - static_cast(rstd[i1]); - db_sum1 += db == nullptr ? T_ACC(0) : static_cast(dY[index1]); - if (i2 < M) { - dg_sum2 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index2]) * - (static_cast(X[index2]) - - static_cast(mean[i2])) * - static_cast(rstd[i2]); - db_sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index2]); - } - } - } - g_shared[threadIdx.y][threadIdx.x] = dg_sum1; - g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2; - b_shared[threadIdx.y][threadIdx.x] = db_sum1; - b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2; - __syncthreads(); - T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y]; - T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y]; - sum1 = WarpReduceSum(sum1); - sum2 = WarpReduceSum(sum2); - if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; - if (j < N) { - if (dg != nullptr) { - dg[j] = sum1; - } - if (db != nullptr) { - db[j] = sum2; - } - } - } - sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; - sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; - sum1 = WarpReduceSum(sum1); - sum2 = WarpReduceSum(sum2); - if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; - if (j < N) { - if (dg != nullptr) { - dg[j] = sum1; - } - if (db != nullptr) { - db[j] = sum2; - } - } - } - } - } // namespace LayerNorm } // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/src/hip/linear_kernels.cpp b/lib/kernels/src/hip/linear_kernels.cpp index 7d5626b6e8..972af9b9b1 100644 --- a/lib/kernels/src/hip/linear_kernels.cpp +++ b/lib/kernels/src/hip/linear_kernels.cpp @@ -19,69 +19,87 @@ namespace FlexFlow { -LinearPerDeviceState::LinearPerDeviceState(FFHandler handler, int batch_size) - : PerDeviceOpState(handler) { - // Allocate an all-one's vector - float *dram_one_ptr = (float *)malloc(sizeof(float) * batch_size); - for (int i = 0; i < batch_size; i++) { - dram_one_ptr[i] = 1.0f; - } - float *fb_one_ptr; - checkCUDA(hipMalloc(&fb_one_ptr, sizeof(float) * batch_size)); - checkCUDA(hipMemcpy(fb_one_ptr, - dram_one_ptr, - sizeof(float) * batch_size, - hipMemcpyHostToDevice)); - one_ptr = (float const *)fb_one_ptr; - // Allocate descriptors - checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); - checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); -} - namespace Kernels { namespace Linear { -bool use_activation(ActiMode mode) { - switch (mode) { - case AC_MODE_RELU: - case AC_MODE_SIGMOID: - case AC_MODE_TANH: - return true; - case AC_MODE_NONE: - return false; - default: - assert(0); - break; +bool use_activation(std::optional activation) { + if (activation.has_value()) { + switch (activation.value()) { + case Activation::RELU: + case Activation::SIGMOID: + case Activation::TANH: + return true; + case Activation::GELU: + return false; + default: + assert(false && "Unsupported activation for Linear"); + break; + } } return false; } -void init_kernel(LinearPerDeviceState *m, int batch_size, int channel) { - if (use_activation(m->activation)) { - miopenActivationMode_t mode; - switch (m->activation) { - case AC_MODE_RELU: - mode = miopenActivationRELU; +LinearPerDeviceState + init_kernel(PerDeviceFFHandle handle, Allocator allocator, float *one_ptr; + ActiMode activation, + Regularizer regularizer, + bool use_bias, + DataType input_type, + DataType weight_type, + DataType output_type, + int batch_size, + int channel) { + ffTensorDescriptor_t outputTensor; + ffActivationDescriptor_t actiDesc; + checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); + checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); + checkCUDNN(miopenSet4dTensorDescriptor(outputTensor, + ff_to_cudnn_datatype(output_type), + batch_size, + channel, + 1, + 1)); + + miopenActivationMode_t mode; + if (activation.has_value()) { + switch (activation.value()) { + case Activation::RELU: + mode = CUDNN_ACTIVATION_RELU; + break; + case Activation::SIGMOID: + mode = CUDNN_ACTIVATION_SIGMOID; break; - case AC_MODE_SIGMOID: - mode = miopenActivationLOGISTIC; + case Activation::TANH: + mode = CUDNN_ACTIVATION_TANH; + break; + case Activation::GELU: + // mode = CUDNN_ACTIVATION_GELU; //cudnnActivationMode_t does not have + // GELU break; default: // Unsupported activation mode assert(false); } - checkCUDNN(miopenSetActivationDescriptor(m->actiDesc, mode, 0.0, 0.0, 0.0)); - checkCUDNN(miopenSet4dTensorDescriptor(m->outputTensor, - ff_to_cudnn_datatype(m->output_type), - batch_size, - channel, - 1, - 1)); } + checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0)); + // todo: how to use allocator to allocate memory for float * one_ptr, how many + // bytes to allocate? + checkCUDA(hipMalloc(&one_ptr, sizeof(float) * batch_size)); + LinearPerDeviceState per_device_state = {handle, + outputTensor, + actiDesc, + one_ptr, + activation, + regularizer, + use_bias, + input_type, + weight_type, + output_type}; + return per_device_state; } void forward_kernel(hipStream_t stream, - LinearPerDeviceState const *m, + LinearPerDeviceState const &m, void const *input_ptr, void *output_ptr, void const *weight_ptr, @@ -90,19 +108,19 @@ void forward_kernel(hipStream_t stream, int out_dim, int batch_size) { - checkCUDA(hipblasSetStream(m->handle.blas, stream)); - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDA(hipblasSetStream(m.handle.blas, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; - hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type); - hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type); - hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type); + hipblasDatatype_t input_type = ff_to_cuda_datatype(m.input_type); + hipblasDatatype_t weight_type = ff_to_cuda_datatype(m.weight_type); + hipblasDatatype_t output_type = ff_to_cuda_datatype(m.output_type); #if CUDA_VERSION >= 11000 // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; #else hipblasDatatype_t compute_type = HIPBLAS_R_32F; #endif - checkCUDA(hipblasGemmEx(m->handle.blas, + checkCUDA(hipblasGemmEx(m.handle.blas, HIPBLAS_OP_T, HIPBLAS_OP_N, out_dim, @@ -123,7 +141,7 @@ void forward_kernel(hipStream_t stream, HIPBLAS_GEMM_DEFAULT)); // use_bias = True if (bias_ptr != NULL) { - checkCUDA(hipblasGemmEx(m->handle.blas, + checkCUDA(hipblasGemmEx(m.handle.blas, HIPBLAS_OP_T, HIPBLAS_OP_N, out_dim, @@ -133,7 +151,7 @@ void forward_kernel(hipStream_t stream, bias_ptr, weight_type, 1, - m->one_ptr, + m.one_ptr, HIPBLAS_R_32F, 1, &alpha, @@ -143,16 +161,16 @@ void forward_kernel(hipStream_t stream, compute_type, HIPBLAS_GEMM_DEFAULT)); } - if (use_activation(m->activation)) { - checkCUDNN(miopenActivationForward(m->handle.dnn, - m->actiDesc, + if (use_activation(m.activation)) { + checkCUDNN(miopenActivationForward(m.handle.dnn, + m.actiDesc, &alpha, - m->outputTensor, + m.outputTensor, output_ptr, &beta, - m->outputTensor, + m.outputTensor, output_ptr)); - } else if (m->activation == AC_MODE_GELU) { + } else if (m.activation == AC_MODE_GELU) { size_t elements = (size_t)out_dim * (size_t)batch_size; constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI) constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI) @@ -165,15 +183,13 @@ void forward_kernel(hipStream_t stream, B, C, (float *)output_ptr); - } else if (m->activation == AC_MODE_NONE) { - // Do nothing } else { - assert(false && "Unsupported activation for Linear"); + // Do nothing } } void backward_kernel(hipStream_t stream, - LinearPerDeviceState const *m, + LinearPerDeviceState const &m, void const *input_ptr, void *input_grad_ptr, void const *output_ptr, @@ -185,13 +201,13 @@ void backward_kernel(hipStream_t stream, int out_dim, int batch_size) { - checkCUDA(hipblasSetStream(m->handle.blas, stream)); - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDA(hipblasSetStream(m.handle.blas, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f; - hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type); - hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type); - hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type); + hipblasDatatype_t input_type = ff_to_cuda_datatype(m.input_type); + hipblasDatatype_t weight_type = ff_to_cuda_datatype(m.weight_type); + hipblasDatatype_t output_type = ff_to_cuda_datatype(m.output_type); #if CUDA_VERSION >= 11000 // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; @@ -199,19 +215,21 @@ void backward_kernel(hipStream_t stream, hipblasDatatype_t compute_type = HIPBLAS_R_32F; #endif int output_size = out_dim * batch_size; - if (m->activation == AC_MODE_RELU) { - relu_backward_kernel( - m->output_type, output_grad_ptr, output_ptr, output_size, stream); - } else if (m->activation == AC_MODE_SIGMOID) { - sigmoid_backward_kernel( - m->output_type, output_grad_ptr, output_ptr, output_size, stream); - } else { - // TODO: only support relu and sigmoid for now - assert(m->activation == AC_MODE_NONE); + if (m.activation.has_value()) { + if (m.activation == Activation::RELU) { + relu_backward_kernel( + m.output_type, output_grad_ptr, output_ptr, output_size, stream); + } else if (m.activation == Activation::SIGMOID) { + sigmoid_backward_kernel( + m.output_type, output_grad_ptr, output_ptr, output_size, stream); + } else { + // TODO: only support relu and sigmoid for now + assert(false && "Unsupported activation for Linear"); + } } // Compute weight gradiant // NOTE: we use alpha=1 for kernel_grad to accumulate gradients - checkCUDA(hipblasGemmEx(m->handle.blas, + checkCUDA(hipblasGemmEx(m.handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_T, in_dim, @@ -230,18 +248,44 @@ void backward_kernel(hipStream_t stream, in_dim, compute_type, HIPBLAS_GEMM_DEFAULT)); - // Compute bias gradiant + + if (m.regularizer == std::nullopt) { + // do nothing + } else { + RegularizerAttrs regularizer_attrs = m.regularizer.value(); + if (std::holds_alternative(regularizer_attrs)) { + L2RegularizerAttrs l2_attrs = + std::get(regularizer_attrs); + float lambda = l2_attrs.lambda; + checkCUDA(hipblasSgeam(m.handle.blas, + HIPBLAS_OP_N, + HIPBLAS_OP_N, + in_dim, + out_dim, + &alpha, + (float *)kernel_grad_ptr, + in_dim, + &(m.kernel_reg_lambda), + (float *)kernel_ptr, + in_dim, + (float *)kernel_grad_ptr, + in_dim)); + } else { + assert(false && "Only L2 regularization is supported"); + } + } + // compute bias gradient // NOTE: we use alpha=1 for bias_grad to accumulate gradients // use_bias = True if (bias_grad_ptr != NULL) { - checkCUDA(hipblasGemmEx(m->handle.blas, + checkCUDA(hipblasGemmEx(m.handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_T, 1, out_dim, batch_size, &alpha, - m->one_ptr, + m.one_ptr, HIPBLAS_R_32F, 1, output_grad_ptr, @@ -257,7 +301,7 @@ void backward_kernel(hipStream_t stream, // Compute data gradiant // NOTE: we use alpha=1 for input_grad to accumulate gradients if (input_grad_ptr != NULL) { - checkCUDA(hipblasGemmEx(m->handle.blas, + checkCUDA(hipblasGemmEx(m.handle.blas, HIPBLAS_OP_N, HIPBLAS_OP_N, in_dim,