Skip to content

Commit

Permalink
Support backward of cmul.
Browse files · Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 29, 2024
1 parent e4d13c2 commit d69ee17
Show file tree
Hide file tree
Showing 10 changed files with 1,231 additions and 713 deletions.
2 changes: 1 addition & 1 deletion lib/nnc/cmd/blas/ccv_nnc_blas.c
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ REGISTER_COMMAND(CCV_NNC_CMUL_FORWARD)(ccv_nnc_cmd_registry_t* const registry)
}

REGISTER_COMMAND(CCV_NNC_CMUL_BACKWARD)(ccv_nnc_cmd_registry_t* const registry)
FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c, gpu/ccv_nnc_cmul_gpu_ref.cu)
FIND_BACKEND(ccv_nnc_cmul_cpu_ref.c, gpu/ccv_nnc_cmul_gpu_ref.cu, mps/ccv_nnc_cmul_mps.m)
{
registry->flags = CCV_NNC_CMD_ATTR_NULL_IS_ONES;
registry->bitmask = _ccv_nnc_cmul_back_bitmask;
Expand Down
399 changes: 396 additions & 3 deletions lib/nnc/cmd/blas/mps/ccv_nnc_cmul_mps.m

Large diffs are not rendered by default.

1,386 changes: 694 additions & 692 deletions lib/nnc/cmd/ccv_nnc_cmd.inc

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_cmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ void ccv_nnc_mfa_encode_cmul(ccv_nnc_mfa_context_t* context, ccv_nnc_mfa_cmul_pa
CCV_NNC_MFA_PRECONDITION(num_tensors == 3);

CMulDescriptor descriptor;
descriptor.conjugate = params.conjugate ? 1 : 0;
descriptor.memoryPrecision = (params.data_type == MTL::DataTypeFloat) ? GEMMOperandPrecision::FP32 : GEMMOperandPrecision::FP16;
descriptor.stridesA[0] = params.astride[0];
descriptor.stridesA[1] = params.astride[1];
Expand Down
1 change: 1 addition & 0 deletions lib/nnc/mfa/ccv_nnc_mfa_cmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define GUARD_ccv_nnc_mfa_cmul_hpp

typedef struct {
uint8_t conjugate;
uint64_t data_type;
uint32_t astride[3];
uint32_t bstride[3];
Expand Down
4 changes: 3 additions & 1 deletion lib/nnc/mfa/v2/CMulDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
bool CMulDescriptor::operator==(const CMulDescriptor& rhs) const {
return
memoryPrecision == rhs.memoryPrecision &&
conjugate == rhs.conjugate &&
value == rhs.value &&
simd_all(stridesA == rhs.stridesA) &&
simd_all(stridesB == rhs.stridesB) &&
Expand Down Expand Up @@ -43,8 +44,9 @@ std::pair<CMulKernelDescriptor, PipelineValue<CMulKernel> *> CMulDescriptor::fin
};

CMulKernelDescriptor kernelDesc;
kernelDesc.memoryPrecision = memoryPrecision;
kernelDesc.conjugate = conjugate;
kernelDesc.value = value;
kernelDesc.memoryPrecision = memoryPrecision;

// WARNING: The owner must explicitly retain the compute pipeline.
auto createPipeline =
Expand Down
9 changes: 6 additions & 3 deletions lib/nnc/mfa/v2/CMulDescriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
#include "GEMMOperandPrecision.hpp"

struct CMulKernelDescriptor {
  /// Non-zero when the generated kernel multiplies by the complex conjugate
  /// of the second operand (computes a * conj(b): real = a0*b0 + a1*b1,
  /// imag = -a0*b1 + a1*b0), as used by the cmul backward pass.
  uint8_t conjugate;
  /// Shader shape selector: 0 = flat 1-D indexing, 1 = 2-D with per-row
  /// strides, 2 = 3-D, any other value = 4-D with strided addressing.
  uint8_t value;
  /// Element precision of all three operands (FP32 or FP16).
  GEMMOperandPrecision memoryPrecision;
  // All three members participate in equality so pipeline caching keyed on
  // this descriptor distinguishes conjugated from non-conjugated kernels.
  // NOTE(review): confirm the std::hash<CMulKernelDescriptor> specialization
  // also folds in `conjugate`, otherwise distinct keys may collide silently
  // (legal but wasteful for the cache).
  constexpr bool operator==(const CMulKernelDescriptor &rhs) const { return value == rhs.value && memoryPrecision == rhs.memoryPrecision && conjugate == rhs.conjugate; }
};

template<>
Expand All @@ -22,7 +23,9 @@ struct std::hash<CMulKernelDescriptor>
struct CMulKernel;

struct CMulDescriptor {
unsigned int value;
uint8_t conjugate;

uint8_t value;

GEMMOperandPrecision memoryPrecision;

Expand Down
132 changes: 123 additions & 9 deletions lib/nnc/mfa/v2/CMulKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

CMulKernel::CMulKernel(CMulKernelDescriptor descriptor, MTL::Device *const device) {

memoryPrecision = descriptor.memoryPrecision;
conjugate = descriptor.conjugate;

value = descriptor.value;

memoryPrecision = descriptor.memoryPrecision;

source = createSource();

threadgroupMemoryAllocation = createThreadgroupMemoryAllocation();
Expand Down Expand Up @@ -38,8 +40,119 @@ unsigned short CMulKernel::createThreadgroupMemoryAllocation() const noexcept {

std::string CMulKernel::createSource() const noexcept {
std::string shader = createConstants() + "\n";
if (value == 0) {
shader += R"(
if (conjugate) {
if (value == 0) {
shader += R"(
#include <metal_stdlib>
using namespace metal;
kernel void cmul(
device real *src0 [[buffer(0)]],
device real *src1 [[buffer(1)]],
device real *destination [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint idx = tpig.x;
if (idx >= dim0)
return;
const float a0 = (float)src0[idx * 2];
const float a1 = (float)src0[idx * 2 + 1];
const float b0 = (float)src1[idx * 2];
const float b1 = (float)src1[idx * 2 + 1];
destination[idx * 2] = (real)(a0 * b0 + a1 * b1);
destination[idx * 2 + 1] = (real)(-a0 * b1 + a1 * b0);
}
)";
} else if (value == 1) {
shader += R"(
#include <metal_stdlib>
using namespace metal;
kernel void cmul(
device real *src0 [[buffer(0)]],
device real *src1 [[buffer(1)]],
device real *destination [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint x = tpig.x;
const uint y = tpig.y;
if (y >= dim1 || x >= dim0)
return;
const uint ida = y * astride0 + x * 2;
const uint idb = y * bstride0 + x * 2;
const uint idc = y * cstride0 + x * 2;
const float a0 = (float)src0[ida];
const float a1 = (float)src0[ida + 1];
const float b0 = (float)src1[idb];
const float b1 = (float)src1[idb + 1];
destination[idc] = (real)(a0 * b0 + a1 * b1);
destination[idc + 1] = (real)(-a0 * b1 + a1 * b0);
}
)";
} else if (value == 2) {
shader += R"(
#include <metal_stdlib>
using namespace metal;
kernel void cmul(
device real *src0 [[buffer(0)]],
device real *src1 [[buffer(1)]],
device real *destination [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint x = tpig.x;
const uint y = tpig.y;
const uint z = tpig.z;
if (y >= dim1 || x >= dim0)
return;
const uint ida = z * astride1 + y * astride0 + x * 2;
const uint idb = z * bstride1 + y * bstride0 + x * 2;
const uint idc = z * cstride1 + y * cstride0 + x * 2;
const float a0 = (float)src0[ida];
const float a1 = (float)src0[ida + 1];
const float b0 = (float)src1[idb];
const float b1 = (float)src1[idb + 1];
destination[idc] = (real)(a0 * b0 + a1 * b1);
destination[idc + 1] = (real)(-a0 * b1 + a1 * b0);
}
)";
} else {
shader += R"(
#include <metal_stdlib>
using namespace metal;
kernel void cmul(
device real *src0 [[buffer(0)]],
device real *src1 [[buffer(1)]],
device real *destination [[buffer(2)]],
uint3 tpig [[thread_position_in_grid]]
) {
const uint x = tpig.x;
const uint y = tpig.y;
const uint z = tpig.z;
if (y >= dim1 || x >= dim0)
return;
const int u = z % dim2;
const int v = z / dim2;
const uint ida = v * astride2 + u * astride1 + y * astride0 + x * 2;
const uint idb = v * bstride2 + u * bstride1 + y * bstride0 + x * 2;
const uint idc = v * cstride2 + u * cstride1 + y * cstride0 + x * 2;
const float a0 = (float)src0[ida];
const float a1 = (float)src0[ida + 1];
const float b0 = (float)src1[idb];
const float b1 = (float)src1[idb + 1];
destination[idc] = (real)(a0 * b0 + a1 * b1);
destination[idc + 1] = (real)(-a0 * b1 + a1 * b0);
}
)";
}
} else {
if (value == 0) {
shader += R"(
#include <metal_stdlib>
using namespace metal;
Expand All @@ -61,8 +174,8 @@ kernel void cmul(
destination[idx * 2 + 1] = (real)(a0 * b1 + a1 * b0);
}
)";
} else if (value == 1) {
shader += R"(
} else if (value == 1) {
shader += R"(
#include <metal_stdlib>
using namespace metal;
Expand All @@ -88,8 +201,8 @@ kernel void cmul(
destination[idc + 1] = (real)(a0 * b1 + a1 * b0);
}
)";
} else if (value == 2) {
shader += R"(
} else if (value == 2) {
shader += R"(
#include <metal_stdlib>
using namespace metal;
Expand All @@ -116,8 +229,8 @@ kernel void cmul(
destination[idc + 1] = (real)(a0 * b1 + a1 * b0);
}
)";
} else {
shader += R"(
} else {
shader += R"(
#include <metal_stdlib>
using namespace metal;
Expand Down Expand Up @@ -146,6 +259,7 @@ kernel void cmul(
destination[idc + 1] = (real)(a0 * b1 + a1 * b0);
}
)";
}
}
return shader;
}
Expand Down
6 changes: 4 additions & 2 deletions lib/nnc/mfa/v2/CMulKernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ struct CMulKernel {
/// The number of threads per group.
MTL::Size threadgroupSize;

GEMMOperandPrecision memoryPrecision;
uint8_t conjugate;

uint8_t value;

unsigned int value;
GEMMOperandPrecision memoryPrecision;

CMulKernel(CMulKernelDescriptor descriptor, MTL::Device *const device);

Expand Down
4 changes: 2 additions & 2 deletions test/int/nnc/cublas.tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -3098,7 +3098,7 @@ TEST_CASE("cmul in float, broadcast semantics")

TEST_CASE("cmul gradient in float")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CMUL_BACKWARD, CCV_NNC_BACKEND_GPU_REF));
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CMUL_BACKWARD, CCV_NNC_BACKEND_GPU_REF) || ccv_nnc_cmd_ok(CCV_NNC_CMUL_BACKWARD, CCV_NNC_BACKEND_MPS));
ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new();
ccv_nnc_tensor_symbol_t a = ccv_nnc_tensor_symbol_new(symbolic_graph, GPU_TENSOR_NCHW(000, 32F, 20, 10), "a");
ccv_nnc_tensor_symbol_t b = ccv_nnc_tensor_symbol_new(symbolic_graph, GPU_TENSOR_NCHW(000, 32F, 20, 10), "b");
Expand Down Expand Up @@ -3157,7 +3157,7 @@ TEST_CASE("cmul gradient in float")

TEST_CASE("cmul gradient in half precision")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CMUL_BACKWARD, CCV_NNC_BACKEND_GPU_REF));
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_CMUL_BACKWARD, CCV_NNC_BACKEND_GPU_REF) || ccv_nnc_cmd_ok(CCV_NNC_CMUL_BACKWARD, CCV_NNC_BACKEND_MPS));
ccv_nnc_symbolic_graph_t* const symbolic_graph = ccv_nnc_symbolic_graph_new();
ccv_nnc_tensor_symbol_t a = ccv_nnc_tensor_symbol_new(symbolic_graph, GPU_TENSOR_NCHW(000, 16F, 20, 10), "a");
ccv_nnc_tensor_symbol_t b = ccv_nnc_tensor_symbol_new(symbolic_graph, GPU_TENSOR_NCHW(000, 16F, 20, 10), "b");
Expand Down

0 comments on commit d69ee17

Please sign in to comment.