Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
oOTigger committed Nov 22, 2024
1 parent 878cff1 commit 42f1fce
Show file tree
Hide file tree
Showing 67 changed files with 934 additions and 453 deletions.
3 changes: 1 addition & 2 deletions lib/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ file(GLOB_RECURSE SRC
CONFIGURE_DEPENDS
LIST_DIRECTORIES False
src/*.cc
src/cuda/cuda_helper.cu
src/cuda/ops/*.cu
src/cuda/*.cu
)

add_library(
Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/batch_norm_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ void forward_kernel(ffStream_t stream,

void backward_kernel(ffStream_t stream,
BatchNormPerDeviceState const &m,
float const *input_ptr,
float *output_grad_ptr,
float const *output_ptr,
float *output_grad_ptr,
float const *input_ptr,
float *input_grad_ptr,
float const *scale_ptr,
float *scale_grad_ptr,
Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/cast_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ void forward_kernel(ffStream_t stream,
GenericTensorAccessorW const &output);

void backward_kernel(ffStream_t stream,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &input);

} // namespace FlexFlow::Kernels::Cast

Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/cast_kernels_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace FlexFlow::Kernels::Cast {
void cpu_forward_kernel(GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);

void cpu_backward_kernel(GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
void cpu_backward_kernel(GenericTensorAccessorR const &output,
GenericTensorAccessorW const &input);

} // namespace FlexFlow::Kernels::Cast

Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/conv_2d_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ void forward_kernel(ffStream_t stream,

void backward_kernel(ffStream_t stream,
Conv2DPerDeviceState const &m,
float const *input_ptr,
float *input_grad_ptr,
float const *output_ptr,
float *output_grad_ptr,
float const *input_ptr,
float *input_grad_ptr,
float const *filter_ptr,
float *filter_grad_ptr,
float *bias_grad_ptr,
Expand Down
6 changes: 3 additions & 3 deletions lib/kernels/include/kernels/element_unary_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ void backward_kernel(ffStream_t stream,
ElementUnaryPerDeviceState const &device_state,
ElementUnaryAttrs const &attrs,
PerDeviceFFHandle const &handle,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output,
GenericTensorAccessorR const &output_grad);
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &input_grad);

} // namespace Kernels::ElementUnary
} // namespace FlexFlow
Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/embedding_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ void forward_kernel(ffStream_t stream,
int out_dim,
int batch_size);
void backward_kernel(ffStream_t stream,
GenericTensorAccessorR const &input,
GenericTensorAccessorR const &output,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &weight_grad,
DataType input_data_type,
DataType output_data_type,
DataType input_data_type,
std::optional<AggregateOp> aggr,
int in_dim,
int out_dim,
Expand Down
7 changes: 4 additions & 3 deletions lib/kernels/include/kernels/flat_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ namespace FlexFlow::Kernels::Flat {
void forward_kernel(ffStream_t stream,
GenericTensorAccessorR input,
float *output_ptr);
void backward_kernel(ffStream_t stream,

void backward_kernel(cudaStream_t stream,
GenericTensorAccessorR input,
float *input_grad_ptr,
float const *output_grad_ptr);
float const *output_grad_ptr,
float *input_grad_ptr);

} // namespace FlexFlow::Kernels::Flat

Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/linear_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ void forward_kernel(ffStream_t stream,

void backward_kernel(ffStream_t stream,
LinearPerDeviceState const &m,
float const *input_ptr,
float *input_grad_ptr,
float const *output_ptr,
float *output_grad_ptr,
float const *input_ptr,
float *input_grad_ptr,
float const *kernel_ptr,
float *kernel_grad_ptr,
float *bias_ptr,
Expand Down
2 changes: 1 addition & 1 deletion lib/kernels/include/kernels/loss_function_kernels.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LOSS_FUNCTION_KERNELS_H
#define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LOSS_FUNCTION_KERNELS_H

#include "kernels/device.h"
#include "device.h"

namespace FlexFlow {

Expand Down
29 changes: 14 additions & 15 deletions lib/kernels/include/kernels/metrics_kernels.h
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
#ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_METRICS_KERNELS_H
#define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_METRICS_KERNELS_H

#include "perf_metrics.h"
#include "kernels/perf_metrics.h"
#include "pcg/metric.h"

namespace FlexFlow {

void update_metrics_sparse_label_kernel(ffStream_t,
MetricsAttrs const &,
float const *logit_ptr,
int const *label_ptr,
int num_samples,
int num_classes,
PerfMetrics &perf_zc);
void update_metrics_label_kernel(ffStream_t,
MetricsAttrs const &,
float const *logit_ptr,
float const *label_ptr,
int num_samples,
int num_classes,
PerfMetrics &perf_zc);
void update_metrics_sparse_label_kernel_wrapper(float const *logit_ptr,
int const *label_ptr,
MetricsAttrs const *me,
int num_effective_samples,
int num_classes,
PerfMetrics &perf_zc);

void update_metrics_label_kernel_wrapper(float const *logit_ptr,
float const *label_ptr,
MetricsAttrs const *me,
int num_samples,
int num_classes,
PerfMetrics &perf_zc);
} // namespace FlexFlow

#endif
124 changes: 81 additions & 43 deletions lib/kernels/include/kernels/optimizer_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,91 @@
#define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_OPTIMIZER_KERNELS_H

#include "device.h"
#include "kernels/ff_handle.h"
#include "kernels/nccl.h"
#include "kernels/per_device_op_state.dtg.h"

namespace FlexFlow {

void sgd_ps_update_task_gpu(ffStream_t,
float lr,
float momentum,
bool nesterov,
__global__ void sgd_update(size_t count,
float lr,
float weight_decay,
float momentum,
bool nesterov,
float const *WGrad,
float *V,
float *W);

class SGDOptimizer {
public:
static __host__ void ps_update_task_gpu(SGDOptimizer const *op,
float const *w_grad_ptr,
size_t size,
int num_replicas,
float *w_ptr,
float *v_ptr);

#ifdef FF_USE_NCCL
static __host__ void nccl_update_task_gpu(SGDOptimizer const *op,
PerDeviceOpState const *meta,
float const *w_grad_ptr,
size_t size,
float *w_ptr,
float *v_ptr);
#endif

public:
float lr;
float weight_decay;
float momentum;
bool nesterov;
};

__global__ void
add_kernel(int count, float scale, float const *src, float *dst);

__global__ void scale_kernel(int count, float a, float b, float *ptr);

__global__ void adam_update(int count,
float alpha_t,
float beta1,
float beta2,
float weight_decay,
float const *weight_grad_ptr,
size_t size,
int num_replicas,
float *weight_ptr,
float *sgd_v_ptr);

void sgd_nccl_update_task_gpu(ffStream_t,
float lr,
float momentum,
bool nesterov,
float weight_decay PerDeviceFFHandle const &,
float const *weight_grad_ptr,
size_t size,
float *weight_ptr,
float *sgd_v_ptr);

void adam_ps_update_task_gpu(ffStream_t,
float alpha_t,
float beta1,
float beta2,
float weight_decay,
float epsilon,
float const *weight_grad_ptr,
float *adam_m_ptr,
float *adam_v_ptr,
float *weight_ptr);

void adam_nccl_update_task_gpu(ffStream_t,
float alpha_t,
float beta1,
float beta2,
float weight_decay,
float epsilon,
PerDeviceFFHandle const &,
float const *weight_grad_ptr,
float *adam_m_ptr,
float *adam_v_ptr,
float *weight_ptr);
float epsilon,
float const *WGrad,
float *M,
float *V,
float *W);

} // namespace FlexFlow
class AdamOptimizer {
public:
static __host__ void ps_update_task_gpu(AdamOptimizer const *op,
float const *w_grad_ptr,
size_t size,
int num_replicas,
float *w_ptr,
float *v_ptr,
float *m_ptr);

#ifdef FF_USE_NCCL
static __host__ void nccl_update_task_gpu(AdamOptimizer const *op,
PerDeviceOpState const *meta,
float const *w_grad_ptr,
size_t size,
float *w_ptr,
float *v_ptr,
float *m_ptr);
#endif

public:
float alpha;
float alpha_t;
float beta1;
float beta2;
float weight_decay;
float epsilon;
};

} // namespace FlexFlow

#endif // _FLEXFLOW_KERNELS_INCLUDE_KERNELS_OPTIMIZER_KERNELS_H
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/partition_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ void forward_kernel(ffStream_t stream,

void backward_kernel(ffStream_t stream,
RepartitionPerDeviceState const &m,
GenericTensorAccessorW const &output_grad,
GenericTensorAccessorR const &input_grad);
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const &input_grad);

} // namespace Kernels::Repartition
} // namespace FlexFlow
Expand Down
9 changes: 5 additions & 4 deletions lib/kernels/include/kernels/pool_2d_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ void forward_kernel(ffStream_t stream,
void const *input_ptr,
void *output_ptr);

void backward_kernel(ffStream_t stream,
void backward_kernel(cudaStream_t stream,
Pool2DPerDeviceState const &m,
void const *input_ptr,
void *input_grad_ptr,
void const *output_ptr,
void const *output_grad_ptr);
void const *output_grad_ptr,
void const *input_ptr,
void *input_grad_ptr);


} // namespace Kernels::Pool2D
} // namespace FlexFlow
Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/reduction_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ void forward_kernel(ffStream_t stream,
size_t num_replicas);

void backward_kernel(ffStream_t stream,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output);
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &input);

} // namespace FlexFlow::Kernels::Reduction

Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/reshape_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ void forward_kernel(ffStream_t stream,

void backward_kernel(ffStream_t stream,
ReshapePerDeviceState const &per_device_state,
GenericTensorAccessorW const &input,
GenericTensorAccessorR const &output);
GenericTensorAccessorR const &output,
GenericTensorAccessorW const &input);

} // namespace Kernels::Reshape
} // namespace FlexFlow
Expand Down
2 changes: 1 addition & 1 deletion lib/kernels/include/kernels/softmax_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ void forward_kernel(ffStream_t stream,
float *output_ptr);

void backward_kernel(ffStream_t stream,
float *input_grad_ptr,
float const *output_grad_ptr,
float *input_grad_ptr,
size_t num_elements);

} // namespace Kernels::Softmax
Expand Down
4 changes: 2 additions & 2 deletions lib/kernels/include/kernels/transpose_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ void forward_kernel(cudaStream_t stream,

void backward_kernel(cudaStream_t stream,
TransposePerDeviceState const &m,
GenericTensorAccessorW const &in_grad,
GenericTensorAccessorR const &out_grad);
GenericTensorAccessorR const &out_grad,
GenericTensorAccessorW const &in_grad);

} // namespace Kernels::Transpose
} // namespace FlexFlow
Expand Down
14 changes: 7 additions & 7 deletions lib/kernels/src/cpu/cast_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ struct CPUForwardKernel {

template <DataType IDT, DataType ODT>
struct CPUBackwardKernel {
void operator()(GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
size_t volume = input.shape.get_volume();
void operator()(GenericTensorAccessorR const &output,
GenericTensorAccessorW const &input) {
size_t volume = output.shape.get_volume();
cpu_cast_backward(
input.get<IDT>(), output.get<ODT>(), volume, cast_to<ODT>(1.0f));
output.get<IDT>(), input.get<ODT>(), volume, cast_to<ODT>(1.0f));
}
};

Expand All @@ -42,10 +42,10 @@ void cpu_forward_kernel(GenericTensorAccessorR const &input,
input.data_type, output.data_type, input, output);
}

void cpu_backward_kernel(GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output) {
void cpu_backward_kernel(GenericTensorAccessorR const &output,
GenericTensorAccessorW const &input) {
DataTypeDispatch2<CPUBackwardKernel>{}(
input.data_type, output.data_type, input, output);
output.data_type, input.data_type, output, input);
}

} // namespace FlexFlow::Kernels::Cast
Loading

0 comments on commit 42f1fce

Please sign in to comment.