Skip to content

Commit

Permalink
Add missing tests for serial Trsm
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Nov 18, 2024
1 parent f3b887e commit ed75ca4
Show file tree
Hide file tree
Showing 4 changed files with 841 additions and 169 deletions.
45 changes: 45 additions & 0 deletions batched/dense/unit_test/Test_Batched_DenseUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,51 @@ void banded_to_full(InViewType& in, OutViewType& out, int k = 1) {
Kokkos::deep_copy(out, h_out);
}

/// \brief Build a batched triangular matrix from a batched dense matrix.
/// Each entry of the input is copied to the output unless it lies on the
/// "wrong" side of the k-th diagonal, in which case it is replaced by zero.
/// Optionally, the main diagonal is overwritten with ones.
///
/// \tparam InViewType: Type of the input matrix, indexed here as a rank 3 view
/// \tparam OutViewType: Type of the output matrix, indexed here as a rank 3 view
/// \tparam UploType: KokkosBatched::Uplo::Upper selects the upper triangular
///                   part; any other type selects the lower triangular part
/// \tparam DiagType: KokkosBatched::Diag::Unit forces ones on the main
///                   diagonal; any other type leaves the diagonal untouched
///
/// \param in [in]: Input batched matrix, accessed as in(batch, row, col)
/// \param out [out]: Output batched matrix holding the triangular part of in
/// \param k [in]: Diagonal offset (default is 0). For Upper, entries with
///                col < row + k are zeroed; for Lower, entries with
///                col > row + k are zeroed.
///
/// \note Both the row and the column loop bounds are taken from
///       in.extent(1), i.e. the blocks are treated as square.
/// \note With a unit diagonal, the entry (row, row) is set to one
///       regardless of the value of k.
template <typename InViewType, typename OutViewType, typename UploType, typename DiagType>
void create_triangular_matrix(InViewType& in, OutViewType& out, int k = 0) {
  // Resolve the compile-time options once, up front.
  constexpr bool is_upper     = std::is_same_v<UploType, KokkosBatched::Uplo::Upper>;
  constexpr bool is_unit_diag = std::is_same_v<DiagType, KokkosBatched::Diag::Unit>;

  const int num_batches = static_cast<int>(in.extent(0));
  const int block_dim   = static_cast<int>(in.extent(1));

  // Work on mirror views and push the result back to out at the end.
  auto h_in  = Kokkos::create_mirror_view(in);
  auto h_out = Kokkos::create_mirror_view(out);

  Kokkos::deep_copy(h_in, in);
  Kokkos::deep_copy(h_out, 0.0);

  for (int batch = 0; batch < num_batches; ++batch) {
    for (int row = 0; row < block_dim; ++row) {
      // Column index of the k-th diagonal on this row.
      const int diag_col = row + k;

      for (int col = 0; col < block_dim; ++col) {
        // Upper: drop everything left of the k-th diagonal.
        // Lower: drop everything right of the k-th diagonal.
        const bool is_dropped = is_upper ? (col < diag_col) : (col > diag_col);
        h_out(batch, row, col) = is_dropped ? 0.0 : h_in(batch, row, col);
      }

      if constexpr (is_unit_diag) {
        // Unit diagonal: overwrite the main diagonal entry of this row.
        h_out(batch, row, row) = 1.0;
      }
    }
  }

  Kokkos::deep_copy(out, h_out);
}

} // namespace KokkosBatched

#endif // TEST_BATCHED_DENSE_HELPER_HPP
Loading

0 comments on commit ed75ca4

Please sign in to comment.