Skip to content

Commit

Permalink
Force the kernel selection to be on registerPrecisionC = FP32 only.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Aug 19, 2024
1 parent 0181946 commit 34bad96
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 6 deletions.
1 change: 1 addition & 0 deletions bin/nnc/adversarial_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ void runTest(TestDescriptor descriptor)
};
gemmDesc.transposeState = simd::uchar3 { descriptor.transposeState[0], descriptor.transposeState[1], descriptor.transposeState[0] };
gemmDesc.useBias = descriptor.useBias;
gemmDesc.registerPrecisionC = GEMMOperandPrecision::FP32;

// Test the kernel.
auto statistic = profileProblemSize(gemmDesc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
return CCV_NNC_EXEC_INVALID;
}

int is_upcast = ((cmd.info.blas.flags & CCV_NNC_GEMM_32F) && q->info.datatype == CCV_16F); // See the TODO: comment.
int attention_is_batched = (batch_size > 1);
ccv_nnc_mfa_attention_params_t params = {
.data_type = mtl_data_type,
Expand All @@ -166,7 +167,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.alpha = cmd.info.scaled_dot_product_attention.scale,
.batched = (attention_is_batched ? 1 : 0),
.masked = (attn_mask != NULL ? 1 : 0),
.upcast = (cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_32F),
.upcast = (cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_32F), // TODO: This default to FP32 after v2 introduction.

.batch_dims_q = { 0 },
.batch_dims_mask = { 0 },
Expand Down Expand Up @@ -317,7 +318,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.B_trans = true,
.D_trans = false,
.fused_bias = (bias ? 1 : 0),
.register_float = 0,
.register_float = (is_upcast ? 1 : 0),

.batch_dimension = 1,
.batch_stride_a = 0,
Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/mfa/v2/GEMMDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ std::pair<GEMMKernelDescriptor, PipelineValue<GEMMKernel> *> GEMMDescriptor::fin
}

// Set the device and examine the block dimensions.
auto blockDimensionsAndPaddedBlockDimensions = GEMMKernelDescriptor::getBlockDimensions(device, dprops.coreCount, this->matrixDimensions, this->batchDimension, this->memoryPrecisions, this->transposeState);
auto blockDimensionsAndPaddedBlockDimensions = GEMMKernelDescriptor::getBlockDimensions(device, dprops.coreCount, this->matrixDimensions, this->batchDimension, this->memoryPrecisions, registerPrecisionC, this->transposeState);
std::optional<bool> preferAsyncStore = std::nullopt;
bool preferAsyncLoad;
simd::ushort2 splits;
Expand Down
4 changes: 2 additions & 2 deletions lib/nnc/mfa/v2/GEMMKernelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ GEMMKernelDescriptor::GEMMKernelDescriptor(simd::ushort3 blockDimensions, GEMMOp
this->useBias = useBias;
}

std::pair<simd::ushort3, std::optional<simd::ushort3>> GEMMKernelDescriptor::getBlockDimensions(MTL::Device* const mtlDevice, const uint32_t coreCount, const simd::uint3 matrixDimensions, const int64_t batchDimension, const GEMMOperandPrecisions memoryPrecisions, const simd::uchar3 transposeState) noexcept {
std::pair<simd::ushort3, std::optional<simd::ushort3>> GEMMKernelDescriptor::getBlockDimensions(MTL::Device* const mtlDevice, const uint32_t coreCount, const simd::uint3 matrixDimensions, const int64_t batchDimension, const GEMMOperandPrecisions memoryPrecisions, const GEMMOperandPrecision registerPrecisionC, const simd::uchar3 transposeState) noexcept {
if (mtlDevice->supportsFamily(MTL::GPUFamily(1009))) {
if (!transposeState[0] && transposeState[1])
if (!transposeState[0] && transposeState[1] && registerPrecisionC == GEMMOperandPrecision::FP32)
{
unsigned short paddedAK = (memoryPrecisions.A == GEMMOperandPrecision::FP32) ? 8 : 32;
unsigned short paddedBK = (memoryPrecisions.B == GEMMOperandPrecision::FP32) ? 8 : 32;
Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/mfa/v2/GEMMKernelDescriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ struct GEMMKernelDescriptor {
///
/// This function initializes the 'blockDimensions' and
/// 'paddedBlockDimensions' properties.
static std::pair<simd::ushort3, std::optional<simd::ushort3>> getBlockDimensions(MTL::Device* const mtlDevice, const uint32_t coreCount, const simd::uint3 matrixDimensions, const int64_t batchDimension, const GEMMOperandPrecisions memoryPrecisions, const simd::uchar3 transposeState) noexcept;
static std::pair<simd::ushort3, std::optional<simd::ushort3>> getBlockDimensions(MTL::Device* const mtlDevice, const uint32_t coreCount, const simd::uint3 matrixDimensions, const int64_t batchDimension, const GEMMOperandPrecisions memoryPrecisions, const GEMMOperandPrecision registerPrecisionC, const simd::uchar3 transposeState) noexcept;

bool operator==(const GEMMKernelDescriptor& rhs) const;
};
Expand Down

0 comments on commit 34bad96

Please sign in to comment.