Skip to content

Commit

Permalink
Disallow symmetric/hermitian conjtrans configurations for spmv
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Jun 7, 2024
1 parent 5cb4518 commit 7190c6a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
22 changes: 15 additions & 7 deletions src/sparse_blas/backends/mkl_common/mkl_spmv.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ sycl::event release_spmv_descr(sycl::queue &queue, oneapi::mkl::sparse::spmv_des
}

void check_valid_spmv(const std::string function_name, sycl::queue &queue,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::transpose opA, oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_vector_handle_t x_handle,
oneapi::mkl::sparse::dense_vector_handle_t y_handle, const void *alpha,
Expand All @@ -51,14 +51,22 @@ void check_valid_spmv(const std::string function_name, sycl::queue &queue,
}

if (A_view.type_view != oneapi::mkl::sparse::matrix_descr::triangular &&
A_view.diag_view != oneapi::mkl::diag::nonunit) {
A_view.diag_view == oneapi::mkl::diag::unit) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"`unit` diag_view can only be used with a triangular type_view.");
}

if ((A_view.type_view == oneapi::mkl::sparse::matrix_descr::symmetric ||
A_view.type_view == oneapi::mkl::sparse::matrix_descr::hermitian) &&
opA == oneapi::mkl::transpose::conjtrans) {
throw mkl::invalid_argument(
"sparse_blas", function_name,
"Symmetric or Hermitian matrix cannot be conjugated with `conjtrans`.");
}
}

void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose /*opA*/, const void *alpha,
void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
oneapi::mkl::sparse::matrix_view A_view,
oneapi::mkl::sparse::matrix_handle_t A_handle,
oneapi::mkl::sparse::dense_vector_handle_t x_handle, const void *beta,
Expand All @@ -67,7 +75,7 @@ void spmv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose /*opA*/, const
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
std::size_t &temp_buffer_size) {
// TODO: Add support for external workspace once the close-source oneMKL backend supports it.
check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta);
check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
temp_buffer_size = 0;
}

Expand All @@ -79,7 +87,7 @@ void spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/,
sycl::buffer<std::uint8_t, 1> /*workspace*/) {
check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta);
check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (!internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
Expand Down Expand Up @@ -113,7 +121,7 @@ sycl::event spmv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
oneapi::mkl::sparse::spmv_alg alg,
oneapi::mkl::sparse::spmv_descr_t /*spmv_descr*/, void * /*workspace*/,
const std::vector<sycl::event> &dependencies) {
check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta);
check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
auto internal_A_handle = detail::get_internal_handle(A_handle);
if (internal_A_handle->all_use_buffer()) {
detail::throw_incompatible_container(__FUNCTION__);
Expand Down Expand Up @@ -196,7 +204,7 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
oneapi::mkl::sparse::dense_vector_handle_t y_handle,
oneapi::mkl::sparse::spmv_alg alg, oneapi::mkl::sparse::spmv_descr_t spmv_descr,
const std::vector<sycl::event> &dependencies) {
check_valid_spmv(__FUNCTION__, queue, A_view, A_handle, x_handle, y_handle, alpha, beta);
check_valid_spmv(__FUNCTION__, queue, opA, A_view, A_handle, x_handle, y_handle, alpha, beta);
auto value_type = detail::get_internal_handle(A_handle)->get_value_type();
DISPATCH_MKL_OPERATION("spmv", value_type, internal_spmv, queue, opA, alpha, A_view, A_handle,
x_handle, beta, y_handle, alg, spmv_descr, dependencies);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ std::vector<fpType> sparse_to_dense(sparse_matrix_format_t format, const intType
const bool is_symmetric_or_hermitian_view =
type_view == oneapi::mkl::sparse::matrix_descr::symmetric ||
type_view == oneapi::mkl::sparse::matrix_descr::hermitian;
// Matrices are not conjugated if they are symmetric
const bool apply_conjugate =
!is_symmetric_or_hermitian_view && transpose_val == oneapi::mkl::transpose::conjtrans;
const bool apply_conjugate = transpose_val == oneapi::mkl::transpose::conjtrans;
std::vector<fpType> dense_a(a_nrows * a_ncols, fpType(0));

auto write_to_dense_if_needed = [&](std::size_t a_idx, std::size_t row, std::size_t col) {
Expand Down
34 changes: 18 additions & 16 deletions tests/unit_tests/sparse_blas/include/test_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,24 @@ void test_helper_with_format(
fp_one, fp_zero, default_alg, triangular_unit_A_view, no_properties,
no_reset_data),
num_passed, num_skipped);
// Lower symmetric or hermitian
oneapi::mkl::sparse::matrix_view symmetric_view(
complex_info<fpType>::is_complex ? oneapi::mkl::sparse::matrix_descr::hermitian
: oneapi::mkl::sparse::matrix_descr::symmetric);
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val,
fp_one, fp_zero, default_alg, symmetric_view, no_properties,
no_reset_data),
num_passed, num_skipped);
// Upper symmetric or hermitian
symmetric_view.uplo_view = oneapi::mkl::uplo::upper;
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val,
fp_one, fp_zero, default_alg, symmetric_view, no_properties,
no_reset_data),
num_passed, num_skipped);
if (transpose_val != oneapi::mkl::transpose::conjtrans) {
// Lower symmetric or hermitian
oneapi::mkl::sparse::matrix_view symmetric_view(
complex_info<fpType>::is_complex ? oneapi::mkl::sparse::matrix_descr::hermitian
: oneapi::mkl::sparse::matrix_descr::symmetric);
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero,
transpose_val, fp_one, fp_zero, default_alg, symmetric_view,
no_properties, no_reset_data),
num_passed, num_skipped);
// Upper symmetric or hermitian
symmetric_view.uplo_view = oneapi::mkl::uplo::upper;
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero,
transpose_val, fp_one, fp_zero, default_alg, symmetric_view,
no_properties, no_reset_data),
num_passed, num_skipped);
}
// Test other algorithms
for (auto alg : non_default_algorithms) {
EXPECT_TRUE_OR_FUTURE_SKIP(test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix,
Expand Down

0 comments on commit 7190c6a

Please sign in to comment.