diff --git a/docs/building_the_project_with_dpcpp.rst b/docs/building_the_project_with_dpcpp.rst index 2fea9395f..0e46e8dc0 100644 --- a/docs/building_the_project_with_dpcpp.rst +++ b/docs/building_the_project_with_dpcpp.rst @@ -287,6 +287,9 @@ portBLAS relies heavily on JIT compilation. This may cause time-outs on some systems. To avoid this issue, use ahead-of-time compilation through tuning targets or ``sycl-targets``. +The ``sycl::half`` type can be supported by setting +``-DPORTBLAS_ENABLE_HALF=ON``. + .. _build_for_portfft_dpcpp: Building for portFFT diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0beadc3ec..31c1b49ad 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -45,6 +45,11 @@ foreach(domain ${TARGET_DOMAINS}) add_subdirectory(${domain}) endforeach() +if (PORTBLAS_ENABLE_HALF) + # Set the variable used for C++ macro + set(ENABLE_PORTBLAS_HALF ON) +endif() + # Generate header with enabled backends for testing configure_file(config.hpp.in "${CMAKE_CURRENT_BINARY_DIR}/oneapi/mkl/config.hpp.configured") file(GENERATE diff --git a/src/blas/backends/portblas/CMakeLists.txt b/src/blas/backends/portblas/CMakeLists.txt index 03fddbb38..730cca92b 100644 --- a/src/blas/backends/portblas/CMakeLists.txt +++ b/src/blas/backends/portblas/CMakeLists.txt @@ -20,9 +20,8 @@ set(LIB_NAME onemkl_blas_portblas) set(LIB_OBJ ${LIB_NAME}_obj) -if(NOT DEFINED PORTBLAS_TUNING_TARGET) - option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "") -endif() +option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "") +option(PORTBLAS_ENABLE_HALF "Enable half support with the portBLAS backend" OFF) # Parse compiler flags and return a list of SYCL targets # The list is empty if no targets are set @@ -152,6 +151,9 @@ if (NOT PORTBLAS_FOUND) # Following variable TUNING_TARGET will be used in portBLAS internal configuration set(TUNING_TARGET ${PORTBLAS_TUNING_TARGET}) set(BLAS_ENABLE_COMPLEX ON) + if (PORTBLAS_ENABLE_HALF) + set(BLAS_ENABLE_HALF ON) + endif() # Set the policy to forward variables to portBLAS configure step set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps") diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx index 28c7ee5dc..1e11e8624 100644 --- a/src/blas/backends/portblas/portblas_batch.cxx +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -210,7 +210,12 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for complex"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +#else + throw unimplemented("blas", "gemm_batch", " for half"); +#endif } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -219,7 +224,12 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +#else + throw unimplemented("blas", "gemm_batch", " for half"); +#endif } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -228,7 +238,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); + throw unimplemented("blas", "gemm_batch", " for int8"); } void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, @@ -237,7 +247,7 @@ void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl:: sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); + throw unimplemented("blas", "gemm_batch", " for int8"); } void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, @@ -686,7 +696,12 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const float **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -695,7 +710,12 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const double **b, std::int64_t *ldb, double *beta, double **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -705,7 +725,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, std::complex *beta, std::complex **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -715,7 +735,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, std::complex *beta, std::complex **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -724,7 +744,16 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -733,7 +762,16 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + if (group_count != 1) { + throw unimplemented("blas", "gemm_batch", " using group API and group_count != 1"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm_batched, queue, transa[0], transb[0], m[0], n[0], k[0], + alpha[0], a[0], lda[0], b[0], ldb[0], beta[0], c[0], ldc[0], group_size[0], + ::blas::gemm_batch_type_t::strided, dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -742,7 +780,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, @@ -751,7 +789,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -785,7 +823,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t ldb, std::int64_t stride_b, std::complex beta, std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -795,7 +833,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t ldb, std::int64_t stride_b, std::complex beta, std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using complex"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -805,7 +843,13 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -815,7 +859,13 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); +#else + throw unimplemented("blas", "gemm_batch", " for USM using half"); +#endif } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -825,7 +875,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, @@ -835,7 +885,7 @@ sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", " for USM"); + throw unimplemented("blas", "gemm_batch", " for USM using int8"); } sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, diff --git a/src/blas/backends/portblas/portblas_level3_half.cpp b/src/blas/backends/portblas/portblas_level3_half.cpp index 0e42528fa..b3c2a0837 100644 --- a/src/blas/backends/portblas/portblas_level3_half.cpp +++ b/src/blas/backends/portblas/portblas_level3_half.cpp @@ -23,6 +23,7 @@ #include #endif +#include "portblas_common.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" @@ -32,19 +33,33 @@ namespace blas { namespace portblas { namespace column_major { +constexpr bool is_column_major() { + return true; +} + // BUFFER void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " half"); +#endif } void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " for different argument data types"); +#endif } // USM @@ -53,31 +68,56 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl: const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, sycl::half beta, sycl::half *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } } // namespace column_major + namespace row_major { +constexpr bool is_column_major() { + return false; +} + // BUFFER void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " half"); +#endif } void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +#else throw unimplemented("blas", "gemm", " for different argument data types"); +#endif } // USM @@ -86,14 +126,24 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl: const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, sycl::half beta, sycl::half *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, const std::vector &dependencies) { +#ifdef ENABLE_PORTBLAS_HALF + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +#else throw unimplemented("blas", "gemm", " for USM"); +#endif } } // namespace row_major diff --git a/src/config.hpp.in b/src/config.hpp.in index 5698abf9b..0c23e04da 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -32,6 +32,7 @@ #cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_CPU #cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_GPU #cmakedefine ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU +#cmakedefine ENABLE_PORTBLAS_HALF #cmakedefine ENABLE_PORTFFT_BACKEND #cmakedefine ENABLE_ROCBLAS_BACKEND #cmakedefine ENABLE_ROCFFT_BACKEND diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index a651f9ae3..c697e644a 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -364,58 +364,79 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } class GemmBatchUsmTests - : public ::testing::TestWithParam> {}; + : public ::testing::TestWithParam> {}; TEST_P(GemmBatchUsmTests, RealHalfPrecision) { + int group_count = std::get<2>(GetParam()); EXPECT_TRUEORSKIP((test( - std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) { - EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), - std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) { - EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), - std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) { + int group_count = std::get<2>(GetParam()); EXPECT_TRUEORSKIP((test( - std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP( - (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP(( - test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) { + int group_count = std::get<2>(GetParam()); EXPECT_TRUEORSKIP( (test, std::complex, std::complex, std::complex>( - std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); - EXPECT_TRUEORSKIP( - (test, std::complex, std::complex, - std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); + int group_count = std::get<2>(GetParam()); + EXPECT_TRUEORSKIP(( + test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), group_count))); } +class GemmBatchGroupNamePrint { +public: + std::string operator()( + testing::TestParamInfo> params) const { + std::string base_name = LayoutDeviceNamePrint()( + { { std::get<0>(params.param), std::get<1>(params.param) }, 0 }); + std::string group_name = "GroupCount_" + std::to_string(std::get<2>(params.param)); + std::string info_name = base_name + "_" + group_name; + return info_name; + } +}; + INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), testing::Values(oneapi::mkl::layout::col_major, - oneapi::mkl::layout::row_major)), - ::LayoutDeviceNamePrint()); + oneapi::mkl::layout::row_major), + testing::Values(1, 5)), + ::GemmBatchGroupNamePrint()); } // anonymous namespace diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index 5d607991e..8974d39c6 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -120,7 +120,10 @@ struct ref_type_info { // Random initialization. template static fp rand_scalar() { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wimplicit-const-int-float-conversion" return fp(std::rand()) / fp(RAND_MAX) - fp(0.5); +#pragma clang diagnostic pop } template static std::complex rand_complex_scalar() {