From f19a22ff124fec4a8ad7051a3ae311907fcc8b05 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 5 May 2023 16:24:26 +0200 Subject: [PATCH 1/4] Add Compute_options and L2_options --- cpp/include/raft/distance/distance_types.hpp | 54 ++++++++++++++++++++ docs/source/cpp_api/distance.rst | 3 ++ 2 files changed, 57 insertions(+) diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index d17ef358ee..cdc2a564b0 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -19,6 +19,60 @@ namespace raft { namespace distance { +/** + * @brief Describes how precise and fast distance should be computed. + */ +enum class Compute_options { + /** The choice of speed and accuracy is left to the implementation. + * + * This will use Fast_Similar_Precision by default. If the environment + * variable `NVIDIA_TF32_OVERRIDE` is set, this will default to + * Fast_Reduced_Precision. + * + * */ + Unspecified, + /** Use the most numerically accurate option. + * */ + Precise, + /** Use fast computation with similar precision. + * + * - If possible, expand the norm computation for two points into the sum of + * norms minus an inner product: + * + * || x - y ||^2 = || x ||^2 + || y ||^2 - 2 < x, y > + * + * The inner product becomes a matrix multiplication for many points. + * + * - If possible, execute the matrix multiplication using 3xtfloat, as + * described in [0]. + * + * [0] Ootomo H, Yokota R. Recovering single precision accuracy from Tensor Cores + * while surpassing the FP32 theoretical peak performance. The International + * Journal of High Performance Computing Applications. 2022;36(4):475-491. + * doi:10.1177/10943420221090256 + * + * */ + Fast_Similar_Precision, + /** Use reduced precision to speed up computation. + * + * 1. Use inner product expansion, as described above. + * 2. Use tensor float precision instead of fp32 precision. 
+ * + * */ + Fast_Reduced_Precision +}; + +/** + * @brief Describes how the L2 norm should be computed. + * + */ +struct L2_options { + /** If true, compute squared L2 norm. */ + bool squared; + /** Specify speed and precision of computation. */ + Compute_options compute_options; +}; + /** enum to tell how to compute distance */ enum DistanceType : unsigned short { diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index fd81295def..dd516c1cbb 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -15,6 +15,9 @@ Distance Types namespace *raft::distance* +.. doxygenenum:: raft::distance::Compute_options +.. doxygenstruct:: raft::distance::L2_options + .. doxygenenum:: raft::distance::DistanceType :project: RAFT From 789f95af0502e0af4ad600fd9f85ff6caf472581 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 5 May 2023 16:27:53 +0200 Subject: [PATCH 2/4] Dispatch L2 distances using L2_options The instance with rbf_fin_op caused some headache: Because the handling of L2 expanded and unexpanded is unified in ``distance_impl_l2_with_options``, an instance of the CUTLASS distance kernel for rbf_fin_op was instantiated. For some reason, CUTLASS did not accept this as a valid argument and produced a very large error message. I could not get rbf_fin_op into a state acceptable to CUTLASS: I included a default constructor, put const on every method, but to no avail. The current solution is to avoid CUTLASS when a final op other than raft::identity_op is used. 
--- cpp/include/raft/distance/detail/distance.cuh | 160 +++++++++++------- 1 file changed, 103 insertions(+), 57 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 7493c4e558..ba825c033c 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -434,8 +434,9 @@ template -void distance_impl_l2_expanded( // NOTE: different name - bool perform_sqrt, // dispatch on sqrt +void distance_impl_l2_with_options( // NOTE: different name + raft::resources const& handle, + L2_options dist_options, const DataT* x, const DataT* y, OutT* out, @@ -445,34 +446,57 @@ void distance_impl_l2_expanded( // NOTE: different name AccT* workspace, size_t worksize, FinOpT fin_op, - cudaStream_t stream, bool is_row_major) { + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + bool perform_sqrt = !dist_options.squared; + constexpr bool fin_op_is_cutlass_compatible = std::is_same_v; + // raft distance support inputs as float/double and output as uint8_t/float/double. 
static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - DataT* x_norm = workspace; - DataT* y_norm = workspace; - if (x != y) { - y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + bool expanded = dist_options.compute_options != Compute_options::Precise; + + if (expanded && fin_op_is_cutlass_compatible) { + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + DataT* x_norm = workspace; + DataT* y_norm = workspace; + if (x != y) { + y_norm += m; + raft::linalg::rowNorm( + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + raft::linalg::rowNorm( + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + } else { + raft::linalg::rowNorm( + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + } + + ops::l2_exp_distance_op distance_op{perform_sqrt}; + // Use if constexpr to prevent instantiation of CUTLASS templates with final + // operations like rbf_fin_op, which are somehow not compatible with + // CUTLASS (they lead to 5 page errors). + if constexpr (fin_op_is_cutlass_compatible) { + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + } else { + // Unreachable (see outer if condition). 
+ } } else { - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - } + ops::l2_unexp_distance_op l2_op(perform_sqrt); - ops::l2_exp_distance_op distance_op{perform_sqrt}; - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + // The unexpanded L2 does not require the norms of a and b to be calculated. + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + pairwise_matrix_dispatch( + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + } } template @@ -490,10 +514,19 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - bool perform_sqrt = false; - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_impl_l2_expanded( - perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); + bool squared = true; + distance_impl_l2_with_options(handle, + L2_options{squared, Compute_options::Fast_Similar_Precision}, + x, + y, + out, + m, + n, + k, + workspace, + worksize, + fin_op, + is_row_major); } template @@ -511,10 +544,19 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - bool perform_sqrt = true; - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_impl_l2_expanded( - perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); + bool squared = false; + distance_impl_l2_with_options(handle, + L2_options{squared, Compute_options::Fast_Similar_Precision}, + x, + y, + out, + m, + n, + k, + workspace, + worksize, + fin_op, + is_row_major); } template @@ -526,23 +568,25 @@ void distance_impl(raft::resources const& handle, IdxT m, IdxT n, IdxT k, - AccT*, // workspace unused - size_t, // worksize unused + AccT* workspace, + size_t worksize, FinOpT fin_op, bool is_row_major, DataT) // metric_arg unused { - bool perform_sqrt = false; - 
ops::l2_unexp_distance_op l2_op(perform_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + bool squared = true; + distance_impl_l2_with_options(handle, + L2_options{squared, Compute_options::Precise}, + x, + y, + out, + m, + n, + k, + workspace, + worksize, + fin_op, + is_row_major); } template @@ -554,23 +598,25 @@ void distance_impl(raft::resources const& handle, IdxT m, IdxT n, IdxT k, - AccT*, // workspace unused - size_t, // worksize unused + AccT* workspace, + size_t worksize, FinOpT fin_op, bool is_row_major, DataT) // metric_arg unused { - bool perform_sqrt = true; - ops::l2_unexp_distance_op l2_op(perform_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + bool squared = false; + distance_impl_l2_with_options(handle, + L2_options{squared, Compute_options::Precise}, // unexpanded metric: stay on the l2_unexp path, as before this patch + x, + y, + out, + m, + n, + k, + workspace, + worksize, + fin_op, + is_row_major); } template @@ -582,8 +628,8 @@ void distance_impl(raft::resources const& handle, IdxT m, IdxT n, IdxT k, - AccT*, // workspace unused - size_t, // worksize unused + AccT* workspace, + size_t worksize, FinOpT fin_op, bool is_row_major, DataT) // metric_arg unused From 7324cdd20bea1ae0e55a0f77a5bdc1e10179916f Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 5 May 2023 18:47:23 +0200 Subject: [PATCH 3/4] Enable 1xtfloat dispatch for sm80 --- .../detail/pairwise_distance_cutlass_base.cuh | 135 ++++++++++-------- 
.../distance/detail/pairwise_distance_gemm.h | 28 ++-- .../detail/pairwise_matrix/dispatch-inl.cuh | 8 +- .../detail/pairwise_matrix/dispatch_sm80.cuh | 32 +++-- 4 files changed, 117 insertions(+), 86 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index efcd5d9389..cf529e0a4c 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -38,6 +38,7 @@ #include #include +#include // Compute_options #include "./pairwise_distance_epilogue_elementwise.h" #include "./pairwise_distance_gemm.h" @@ -64,20 +65,22 @@ template -std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) +std::enable_if_t::value> cutlassDistanceKernel( + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + Compute_options compute_options, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); @@ -110,20 +113,6 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da // default initialize problem size with row major inputs auto problem_size = cutlass::gemm::GemmCoord(n, m, k); - - using cutlassDistKernel = - typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; - - using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - if constexpr (isRowMajor) { a = y; b = x; @@ -137,41 +126,65 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da gemm_ldb = ldb; } - typename cutlassDist::Arguments arguments{ - mode, problem_size, batch_count, epilog_op_param, a, b, 
- xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - ldd // stride Output matrix + // lambda takes a compile-time constant boolean to determine whether to + // execute with 1xtfloat or 3xtfloat. + auto lambda = [&](auto use_1xtfloat) { + using cutlassDistKernel = + typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; + + using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; + + typename cutlassDist::Arguments arguments{ + mode, problem_size, batch_count, epilog_op_param, a, b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + cutlass::Status status = cutlassDist_op.can_implement(arguments); + CUTLASS_CHECK(status); + // 
Initialize CUTLASS kernel with arguments and workspace pointer + status = cutlassDist_op.initialize(arguments, workspace.data(), stream); + CUTLASS_CHECK(status); + // Launch initialized CUTLASS kernel + status = cutlassDist_op(); + CUTLASS_CHECK(status); }; - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - cutlass::Status status = cutlassDist_op.can_implement(arguments); - CUTLASS_CHECK(status); - // Initialize CUTLASS kernel with arguments and workspace pointer - status = cutlassDist_op.initialize(arguments, workspace.data(), stream); - CUTLASS_CHECK(status); - // Launch initialized CUTLASS kernel - status = cutlassDist_op(); - CUTLASS_CHECK(status); + if (compute_options == Compute_options::Fast_Reduced_Precision) { + lambda(std::true_type{}); + } else { + lambda(std::false_type{}); + } } }; // namespace detail diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h index 8dcccfc14f..a48571aab5 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -53,7 +53,9 @@ template < /// Number of stages used in the pipelined mainloop int Stages, /// data layout row/column major of inputs - bool isRowMajor> + bool isRowMajor, + /// Whether to use 3xtfloat or 1xtfloat: + bool use_1xtfloat> struct PairwiseDistanceGemm { // This struct is specialized for fp32/3xTF32 @@ -69,7 +71,10 @@ struct PairwiseDistanceGemm { cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; + using Operator = + std::conditional_t; // This implies 3xtfloat // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU // SM @@ -147,16 +152,19 @@ template < /// Number of stages used in the pipelined mainloop int Stages, /// data layout row/column major of inputs - bool isRowMajor> -struct PairwiseDistanceGemm +struct PairwiseDistanceGemm { + isRowMajor, + use_1xtfloat> { // using Transform = cutlass::ComplexTransform::kNone; // Threadblock-level tile size (concept: GemmShape) using ThreadblockShape = @@ -168,7 +176,9 @@ struct PairwiseDistanceGemm; - // Operation performed by GEMM + // Operation performed by GEMM. Regardless of the value of use_1xtfloat, we do + // OpMultiplyAdd with a TensorOp. So we are using tensor cores for double + // precision. using Operator = cutlass::arch::OpMultiplyAdd; // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU // SM @@ -236,4 +246,4 @@ struct PairwiseDistanceGemm // ops::has_cutlass_op #include // dispatch_sm60 #include // pairwise_matrix_params -#include // raft::util::arch::SM_* +#include // raft::distance::Compute_options +#include // raft::util::arch::SM_* // NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. // Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). 
@@ -56,6 +57,7 @@ template void pairwise_matrix_sm80_dispatch(OpT, + Compute_options, pairwise_matrix_params, SM_compat_t, cudaStream_t); @@ -67,6 +69,7 @@ template void pairwise_matrix_dispatch(OpT distance_op, + // Compute_options compute_options, TODO. IdxT m, IdxT n, IdxT k, @@ -118,7 +121,8 @@ void pairwise_matrix_dispatch(OpT distance_op, if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + pairwise_matrix_sm80_dispatch( + distance_op, Compute_options::Fast_Similar_Precision, params, cutlass_range, stream); } else { // Reuse kernel wrapper that we obtained above. This avoids performing the // dispatch twice. diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh index dc30f2f239..83c2814682 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh @@ -18,6 +18,7 @@ #include // std::min #include // cutlassDistanceKernel #include // dispatch_layout +#include // raft::distance::Compute_options namespace raft::distance::detail { @@ -28,6 +29,7 @@ template void pairwise_matrix_sm80_dispatch(OpT distance_op, + Compute_options compute_options, pairwise_matrix_params params, SM_compat_t sm_compat_range, cudaStream_t stream) @@ -44,20 +46,22 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); using AccT = typename OpT::AccT; - cutlassDistanceKernel(params.x, - params.y, - params.x_norm, - params.y_norm, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.out, - params.fin_op, - distance_op, - stream); + cutlassDistanceKernel( + params.x, + params.y, + params.x_norm, + params.y_norm, + params.m, + params.n, + params.k, + params.ldx, 
+ params.ldy, + params.ld_out, + params.out, + params.fin_op, + distance_op, + compute_options, + stream); }; // Dispatch_layout calls f with appropriate compile time constants based on From 6e06bc40865f0e7ffc571165b5539a89fc20f5fd Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 5 May 2023 18:48:47 +0200 Subject: [PATCH 4/4] Add throughput benchmark for CUTLASS Peak T ops/s = 74 T/s (1x tfloat) Peak T ops/s = 22 T/s (3x tfloat) This roughly corresponds to: (assuming 2 flops / core op) Peak T ops/s = 144 Tflop/s (1x tfloat) Peak T ops/s = 33 Tflop/s (3x tfloat) --- .../distance/tune_pairwise/bench_cutlass.cu | 149 ++++++++++++++++++ .../distance/tune_pairwise/kernel_cutlass.cu | 48 ++++++ .../distance/tune_pairwise/kernel_cutlass.cuh | 38 +++++ 3 files changed, 235 insertions(+) create mode 100644 cpp/bench/prims/distance/tune_pairwise/bench_cutlass.cu create mode 100644 cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cu create mode 100644 cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cuh diff --git a/cpp/bench/prims/distance/tune_pairwise/bench_cutlass.cu b/cpp/bench/prims/distance/tune_pairwise/bench_cutlass.cu new file mode 100644 index 0000000000..a55d9864dd --- /dev/null +++ b/cpp/bench/prims/distance/tune_pairwise/bench_cutlass.cu @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Tuning benchmarks. +// +// Goals: +// +// 1. 
Fast compile times to maintain iteration speed. +// 2. Create benchmarks that can inform the design of the kernels. +// +// Non-goals: +// +// 1. Measure every distance operation. Instead, measure just one distance +// operation at a time. +// 2. Be useful for finding performance regressions. This is handled by the +// normal benchmarks. +// +// So far, both goals are partly achieved. +// +// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not. +// When the internals of a pairwise distance kernel are changed, this file is not +// recompiled. +// +// RE (2), BENCHMARKS WITH INTENT: this file contains a benchmark to check the +// maximal throughput of a kernel. Measuring other things, like performance on +// skinny or wide matrices, is not yet implemented. + +#include "kernel_cutlass.cuh" // launch_kernel +#include // std::min +#include // RAFT_BENCH_REGISTER +#include // pairwise_matrix_params +#include // rmm::device_uvector +#include // std::vector + +namespace raft::bench::distance::tune_cutlass { + +// Max throughput benchmark. +// +// Goal: Measure the maximum distances/sec that can be computed. +// +// To achieve this, we make sure that: +// +// - Input data size is a multiple of the block tile size. +// +// - Perfect distribution of work between SMs, i.e. the number of block tiles is +// a large multiple (num_waves) of the number of blocks (#SMs * occupancy). +// +// - Multiple iterations over Kblk are executed (num_k_iters). 
+struct throughput_param { + int m, n, k; + bool use_1x_tfloat; +}; + +const std::vector throughput_params{ + {1024, 1024, 1024, true}, + {1024, 1024, 1 << 11, true}, + {1024, 1024, 1 << 12, true}, + {1024, 1024, 1 << 13, true}, + {1024, 1 << 14, 1024, true}, + {1024, 1 << 14, 1 << 11, true}, + {1024, 1 << 14, 1 << 12, true}, + {1024, 1 << 14, 1 << 13, true}, + + {1024, 1024, 1024, false}, + {1024, 1024, 1 << 11, false}, + {1024, 1024, 1 << 12, false}, + {1024, 1024, 1 << 13, false}, + {1024, 1 << 14, 1024, false}, + {1024, 1 << 14, 1 << 11, false}, + {1024, 1 << 14, 1 << 12, false}, + {1024, 1 << 14, 1 << 13, false}, +}; + +struct throughput_cutlass : public fixture { + const throughput_param p; + + throughput_cutlass(const throughput_param& p_) : p(p_) {} + + void run_benchmark(::benchmark::State& state) override + { + size_t m = p.m; + size_t n = p.n; + size_t k = p.k; + + // DataT, OutT, IdxT, etc, are defined in kernel_cutlass.cuh + rmm::device_uvector x_vec(m * k, stream); + rmm::device_uvector y_vec(n * k, stream); + rmm::device_uvector x_norm_vec(m, stream); + rmm::device_uvector y_norm_vec(n, stream); + rmm::device_uvector out_vec(m * n, stream); + + auto x = x_vec.data(); + auto y = y_vec.data(); + auto x_norm = x_norm_vec.data(); + auto y_norm = y_norm_vec.data(); + auto out = out_vec.data(); + FinOpT fin_op{}; + + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = row_major ? k : m; + IdxT ldy = row_major ? k : n; + IdxT ld_out = row_major ? n : m; + + // Template parameters of pairwise_matrix_params are defined in kernel_cutlass.cuh + pairwise_matrix_params kparams{ + IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; + + // Run benchmark + loop_on_state(state, [&]() { launch_kernel(kparams, p.use_1x_tfloat, stream); }); + + // Report metrics. We don't report flop/s because we do not know for each + // distance operation how many flops it costs. 
For L2_unexp and l1, we can + // double this number to get the flop/s. For l2 expanded, core_ops/s should + // equal flop/s (modulo the sqrt and subtracting from the norm). + size_t num_core_ops = m * n * k; + size_t read_elts = n * k + m * k; + size_t write_elts = m * n; + + state.counters["m"] = benchmark::Counter(m); + state.counters["n"] = benchmark::Counter(n); + state.counters["k"] = benchmark::Counter(k); + state.counters["1xtfloat"] = benchmark::Counter(p.use_1x_tfloat); + + state.counters["core_ops/s"] = benchmark::Counter(num_core_ops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + } +}; + +RAFT_BENCH_REGISTER(throughput_cutlass, "", throughput_params); + +} // namespace raft::bench::distance::tune_cutlass diff --git a/cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cu b/cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cu new file mode 100644 index 0000000000..834cb7d6be --- /dev/null +++ b/cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernel_cutlass.cuh" +#include // distance_op +#include +#include +#include // Compute_options +#include // raft::util::arch::SM_compute_arch + +namespace raft::bench::distance::tune_cutlass { + +// Distance op +using OpT = raft::distance::detail::ops::l2_exp_distance_op; + +constexpr bool perform_sqrt = false; +OpT distance_op{perform_sqrt}; + +// Architecture +namespace arch = raft::util::arch; +constexpr auto sm_compat_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + +void launch_kernel(pairwise_matrix_params params, bool use_1x_tfloat, cudaStream_t stream) +{ + raft::distance::detail::pairwise_matrix_sm80_dispatch( + distance_op, + use_1x_tfloat ? raft::distance::Compute_options::Fast_Reduced_Precision + : raft::distance::Compute_options::Fast_Similar_Precision, + params, + sm_compat_range, + stream); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace raft::bench::distance::tune_cutlass diff --git a/cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cuh b/cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cuh new file mode 100644 index 0000000000..ff5396922c --- /dev/null +++ b/cpp/bench/prims/distance/tune_pairwise/kernel_cutlass.cuh @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include // raft::identity_op +#include // pairwise_matrix_params + +namespace raft::bench::distance::tune_cutlass { + +// Launch one specific kernel with the following template parameters +constexpr bool row_major = true; +using DataT = float; +using AccT = float; +using OutT = DataT; +using IdxT = int; + +using FinOpT = raft::identity_op; + +using pairwise_matrix_params = + raft::distance::detail::pairwise_matrix_params; + +void launch_kernel(pairwise_matrix_params params, bool use_1x_tfloat, cudaStream_t stream); + +} // namespace raft::bench::distance::tune_cutlass