refactor for softmax, split, topk, transpose (#1404)
* refactor for softmax, split, topk, transpose

* fix

---------

Co-authored-by: Reyna Abhyankar <[email protected]>
Bob-Chen222 and reyna-abhyankar authored Jun 5, 2024
1 parent 89cbe93 commit af1caf5
Showing 4 changed files with 94 additions and 79 deletions.
34 changes: 15 additions & 19 deletions lib/kernels/src/hip/softmax_kernels.cpp
@@ -14,40 +14,36 @@
  */
 
 #include "kernels/softmax_kernels.h"
-#include "kernels/hip_helper.h"
+#include "device.h"
 #include <hip/hip_runtime.h>
 
 namespace FlexFlow {
-// declare Legion names
-using Legion::Domain;
-
-SoftmaxPerDeviceState::SoftmaxPerDeviceState(FFHandler handler,
-                                             Softmax const *softmax,
-                                             Domain const &input_domain)
-    : PerDeviceOpState(handler) {
-  checkCUDNN(miopenCreateTensorDescriptor(&inputTensor));
-  checkCUDNN(cudnnSetTensorDescriptorFromDomain(inputTensor, input_domain));
-  dim = softmax->dim;
-  profiling = softmax->profiling;
-  std::strcpy(op_name, softmax->name);
-}
 
 namespace Kernels {
 namespace Softmax {
 
+SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &handle, int dim) {
+  ffTensorDescriptor_t inputTensor;
+
+  checkCUDNN(miopenCreateTensorDescriptor(&inputTensor));
+
+  SoftmaxPerDeviceState per_device_state = {handle, inputTensor, dim};
+  return per_device_state;
+}
+
 void forward_kernel(hipStream_t stream,
-                    SoftmaxPerDeviceState const *m,
+                    SoftmaxPerDeviceState const &m,
                     float const *input_ptr,
                     float *output_ptr) {
-  checkCUDNN(miopenSetStream(m->handle.dnn, stream));
+  checkCUDNN(miopenSetStream(m.handle.dnn, stream));
 
   float alpha = 1.0f, beta = 0.0f;
-  checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn,
+  checkCUDNN(miopenSoftmaxForward_V2(m.handle.dnn,
                                      &alpha,
-                                     m->inputTensor,
+                                     m.inputTensor,
                                      input_ptr,
                                      &beta,
-                                     m->inputTensor,
+                                     m.inputTensor,
                                      output_ptr,
                                      MIOPEN_SOFTMAX_ACCURATE,
                                      MIOPEN_SOFTMAX_MODE_CHANNEL));
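For orientation, a minimal sketch of how the refactored softmax API might be driven. This is not part of the commit: it assumes the declarations in kernels/softmax_kernels.h, an already-initialized PerDeviceFFHandle, and device buffers allocated by the caller.

// Hypothetical call site (not part of this commit), assuming the
// declarations in kernels/softmax_kernels.h and the HIP runtime.
#include "kernels/softmax_kernels.h"
#include <hip/hip_runtime.h>

using namespace FlexFlow;

void run_softmax(PerDeviceFFHandle const &handle,
                 float const *d_input, // device buffer, assumed allocated
                 float *d_output,      // device buffer, assumed allocated
                 int dim) {
  hipStream_t stream;
  hipStreamCreate(&stream);

  // Per-device state is now produced by a free init_kernel function
  // instead of a Legion-aware constructor...
  SoftmaxPerDeviceState state = Kernels::Softmax::init_kernel(handle, dim);

  // ...and the kernel takes it by const reference rather than pointer.
  Kernels::Softmax::forward_kernel(stream, state, d_input, d_output);

  hipStreamSynchronize(stream);
  hipStreamDestroy(stream);
}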
4 changes: 1 addition & 3 deletions lib/kernels/src/hip/split_kernels.cpp
@@ -14,12 +14,10 @@
  */
 
 #include "kernels/split_kernels.h"
-#include "kernels/hip_helper.h"
+#include "device.h"
 #include <hip/hip_runtime.h>
 
 namespace FlexFlow {
-// declare Legion names
-using Legion::coord_t;
 
 namespace Kernels {
 namespace Split {
16 changes: 8 additions & 8 deletions lib/kernels/src/hip/topk_kernels.cpp
@@ -14,15 +14,10 @@
  */
 
 #include "kernels/topk_kernels.h"
-#include "kernels/hip_helper.h"
+#include "device.h"
 #include <hip/hip_runtime.h>
 
 namespace FlexFlow {
-// declare Legion names
-using Legion::coord_t;
-
-TopKPerDeviceState::TopKPerDeviceState(FFHandler handler)
-    : PerDeviceOpState(handler) {}
 
 namespace Kernels {
 namespace TopK {
@@ -36,6 +31,11 @@ struct Entry {
   T value;
 };
 
+TopKPerDeviceState init_kernel(bool sorted) {
+  TopKPerDeviceState per_device_state = {sorted};
+  return per_device_state;
+}
+
 template <typename T>
 struct LinearData {
   typedef Entry<T> Entry;
@@ -371,7 +371,7 @@ __global__ void topk_forward_kernel(T const *__restrict__ input,
 }
 
 void forward_kernel(hipStream_t stream,
-                    TopKPerDeviceState const *m,
+                    TopKPerDeviceState const &m,
                     float const *input_ptr,
                     float *output_ptr,
                     int *indices_ptr,
@@ -428,7 +428,7 @@ __global__ void topk_backward_kernel(T const *__restrict__ value_grad_ptr,
 }
 
 void backward_kernel(hipStream_t stream,
-                     TopKPerDeviceState const *m,
+                     TopKPerDeviceState const &m,
                      float const *value_grad_ptr,
                      int const *indices_ptr,
                      float *in_grad_ptr,
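The same shape of change recurs in every file touched here. A self-contained sketch of the pattern, with illustrative names (DemoPerDeviceState and these free functions are hypothetical stand-ins, not FlexFlow APIs):

// Standalone illustration of the refactoring pattern above: per-device
// state becomes a plain aggregate built by a free init_kernel function,
// and kernels receive it by const reference instead of pointer.
struct DemoPerDeviceState {
  bool sorted; // mirrors TopKPerDeviceState in the diff above
};

DemoPerDeviceState init_kernel(bool sorted) {
  // Aggregate initialization: no FFHandler, no PerDeviceOpState base class.
  DemoPerDeviceState per_device_state = {sorted};
  return per_device_state;
}

void forward_kernel(DemoPerDeviceState const &m) {
  // Members are read as m.sorted rather than m->sorted.
  (void)m.sorted;
}

int main() {
  DemoPerDeviceState m = init_kernel(/*sorted=*/true);
  forward_kernel(m);
  return 0;
}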
119 changes: 70 additions & 49 deletions lib/kernels/src/hip/transpose_kernels.cpp
@@ -14,13 +14,12 @@
  */
 
 #include "kernels/transpose_kernels.h"
-#include "kernels/hip_helper.h"
+#include "device.h"
+#include "kernels/accessor.h"
+#include "utils/exception.h"
 #include <hip/hip_runtime.h>
 
 namespace FlexFlow {
-// declare Legion names
-using Legion::coord_t;
-using Legion::Domain;
 
 struct TransposeStrides {
   int num_dim;
@@ -31,81 +30,103 @@ struct TransposeStrides {
 namespace Kernels {
 namespace Transpose {
 
+TransposePerDeviceState init_kernel(int num_dim,
+                                    std::vector<ff_dim_t> const &perm) {
+  int const length = perm.size();
+
+  std::vector<int> perm_vector;
+  assert(length <= MAX_TENSOR_DIM);
+  for (int i = 0; i < length; ++i) {
+    perm_vector.push_back(perm[i].value());
+  }
+
+  return {num_dim, perm_vector};
+}
+
+__global__ void transpose_simple_kernel(std::size_t volume,
+                                        float const *in_ptr,
+                                        float *out_ptr,
+                                        const TransposeStrides info,
+                                        float const beta) {
+  CUDA_KERNEL_LOOP(o_idx, volume) {
+    coord_t i_idx = 0;
+    coord_t t = o_idx;
+    for (int i = info.num_dim - 1; i >= 0; i--) {
+      coord_t ratio = t / info.out_strides[i];
+      t -= ratio * info.out_strides[i];
+      i_idx += ratio * info.in_strides[info.perm[i]];
+    }
+    out_ptr[o_idx] = out_ptr[o_idx] * beta + in_ptr[i_idx];
+  }
+}
+
 void forward_kernel(hipStream_t stream,
-                    TransposePerDeviceState const *m,
-                    float const *input_ptr,
-                    float *output_ptr,
-                    Domain in_domain,
-                    Domain out_domain) {
+                    TransposePerDeviceState const &m,
+                    GenericTensorAccessorR const &input,
+                    GenericTensorAccessorW const &output) {
 
   TransposeStrides info;
-  info.num_dim = out_domain.get_dim();
-  assert(info.num_dim == m->num_dim);
+  info.num_dim = input.shape.num_dims();
+  assert(info.num_dim == m.num_dim);
   for (int i = 0; i < info.num_dim; i++) {
-    int in_dim_size = (in_domain.hi()[i] - in_domain.lo()[i] + 1);
-    int out_dim_size = (out_domain.hi()[i] - out_domain.lo()[i] + 1);
-    info.in_strides[i] = (i == 0) ? 1 : info.in_strides[i - 1] * in_dim_size;
-    info.out_strides[i] = (i == 0) ? 1 : info.out_strides[i - 1] * out_dim_size;
-    info.perm[i] = m->perm[i];
+    if (i == 0) {
+      info.in_strides[i] = 1;
+      info.out_strides[i] = 1;
+    } else {
+      int in_dim_size = input.shape[legion_dim_t(i)] + 1;
+      int out_dim_size = output.shape[legion_dim_t(i)] + 1;
+      info.in_strides[i] = info.in_strides[i - 1] * in_dim_size;
+      info.out_strides[i] = info.out_strides[i - 1] * out_dim_size;
+    }
+    info.perm[i] = m.perm[i];
   }
 
   hipLaunchKernelGGL(transpose_simple_kernel,
-                     GET_BLOCKS(out_domain.get_volume()),
+                     GET_BLOCKS(output.shape.get_volume()),
                      CUDA_NUM_THREADS,
                      0,
                      stream,
-                     out_domain.get_volume(),
-                     input_ptr,
-                     output_ptr,
+                     output.shape.get_volume(),
+                     input.get_float_ptr(),
+                     output.get_float_ptr(),
                      info,
                      0.0f /*beta*/);
 }
 
 void backward_kernel(hipStream_t stream,
-                     TransposePerDeviceState const *m,
-                     float *input_grad_ptr,
-                     float const *output_grad_ptr,
-                     Domain in_grad_domain,
-                     Domain out_grad_domain) {
+                     TransposePerDeviceState const &m,
+                     GenericTensorAccessorW const &in_grad,
+                     GenericTensorAccessorR const &out_grad) {
 
   TransposeStrides info;
-  info.num_dim = in_grad_domain.get_dim();
-  assert(info.num_dim == m->num_dim);
+  info.num_dim = in_grad.shape.num_dims();
+  assert(info.num_dim == m.num_dim);
   for (int i = 0; i < info.num_dim; i++) {
-    int in_dim_size = (out_grad_domain.hi()[i] - out_grad_domain.lo()[i] + 1);
-    int out_dim_size = (in_grad_domain.hi()[i] - in_grad_domain.lo()[i] + 1);
-    info.in_strides[i] = (i == 0) ? 1 : info.in_strides[i - 1] * in_dim_size;
-    info.out_strides[i] = (i == 0) ? 1 : info.out_strides[i - 1] * out_dim_size;
-    info.perm[m->perm[i]] = i;
+    if (i == 0) {
+      info.in_strides[i] = 1;
+      info.out_strides[i] = 1;
+    } else {
+      int in_dim_size = out_grad.shape[legion_dim_t(i)] + 1;
+      int out_dim_size = in_grad.shape[legion_dim_t(i)] + 1;
+      info.in_strides[i] = info.in_strides[i - 1] * in_dim_size;
+      info.out_strides[i] = info.out_strides[i - 1] * out_dim_size;
+    }
+    info.perm[m.perm[i]] = i;
   }
   hipLaunchKernelGGL(transpose_simple_kernel,
-                     GET_BLOCKS(in_grad_domain.get_volume()),
+                     GET_BLOCKS(in_grad.shape.get_volume()),
                      CUDA_NUM_THREADS,
                      0,
                      stream,
-                     in_grad_domain.get_volume(),
-                     output_grad_ptr,
-                     input_grad_ptr,
+                     in_grad.shape.get_volume(),
+                     out_grad.get_float_ptr(),
+                     in_grad.get_float_ptr(),
                      info,
                      1.0f /*beta*/);
 }
 
-__global__ void transpose_simple_kernel(coord_t volume,
-                                        float const *in_ptr,
-                                        float *out_ptr,
-                                        const TransposeStrides info,
-                                        float const beta) {
-  CUDA_KERNEL_LOOP(o_idx, volume) {
-    coord_t i_idx = 0;
-    coord_t t = o_idx;
-    for (int i = info.num_dim - 1; i >= 0; i--) {
-      coord_t ratio = t / info.out_strides[i];
-      t -= ratio * info.out_strides[i];
-      i_idx += ratio * info.in_strides[info.perm[i]];
-    }
-    out_ptr[o_idx] += out_ptr[o_idx] * beta + in_ptr[i_idx];
-  }
-}
-
 } // namespace Transpose
 } // namespace Kernels
 } // namespace FlexFlow
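The stride-and-permutation arithmetic in transpose_simple_kernel can be checked on the host. The runnable sketch below mirrors that index math under the assumption that dimension 0 is innermost; transpose_host and its stride convention are illustrative, not FlexFlow code.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Host-side sketch of the kernel's index math: output dimension i takes
// its extent and coordinate from input dimension perm[i].
std::vector<float> transpose_host(std::vector<float> const &in,
                                  std::vector<std::size_t> const &in_dims,
                                  std::vector<int> const &perm) {
  int num_dim = in_dims.size();
  assert((int)perm.size() == num_dim);

  std::vector<std::size_t> out_dims(num_dim);
  for (int i = 0; i < num_dim; i++) {
    out_dims[i] = in_dims[perm[i]];
  }

  // Stride of dim i is the product of the extents of all lower dims.
  std::vector<std::size_t> in_strides(num_dim), out_strides(num_dim);
  for (int i = 0; i < num_dim; i++) {
    in_strides[i] = (i == 0) ? 1 : in_strides[i - 1] * in_dims[i - 1];
    out_strides[i] = (i == 0) ? 1 : out_strides[i - 1] * out_dims[i - 1];
  }

  std::size_t volume = 1;
  for (std::size_t d : in_dims) {
    volume *= d;
  }

  std::vector<float> out(volume);
  for (std::size_t o_idx = 0; o_idx < volume; o_idx++) {
    // Peel output coordinates off from the outermost dim down, and
    // accumulate the matching input offset through perm, as the kernel does.
    std::size_t i_idx = 0;
    std::size_t t = o_idx;
    for (int i = num_dim - 1; i >= 0; i--) {
      std::size_t coord = t / out_strides[i];
      t -= coord * out_strides[i];
      i_idx += coord * in_strides[perm[i]];
    }
    out[o_idx] = in[i_idx];
  }
  return out;
}

int main() {
  // 2x3 matrix with dim 0 innermost, transposed to 3x2.
  std::vector<float> in = {1, 2, 3, 4, 5, 6}; // rows {1,2,3} and {4,5,6}
  std::vector<float> out = transpose_host(in, {3, 2}, {1, 0});
  for (float v : out) {
    std::cout << v << ' '; // prints 1 4 2 5 3 6
  }
  std::cout << '\n';
  return 0;
}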
