masked_matmul supports bitset mask
rhdong committed Dec 9, 2024
1 parent 5e259e9 commit ab6d71e
Showing 4 changed files with 276 additions and 64 deletions.
109 changes: 76 additions & 33 deletions cpp/bench/prims/linalg/masked_matmul.cu
@@ -49,11 +49,14 @@ inline auto operator<<(std::ostream& os, const MaskedMatmulBenchParams<value_t>&
{
os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n
<< "\tsparsity=" << params.sparsity;
if (params.sparsity == 1.0) { os << "<-inner product for comparison"; }
if (params.sparsity == 0.0) { os << "<-inner product for comparison"; }
return os;
}

template <typename value_t, typename index_t = int64_t, typename bitmap_t = uint32_t>
template <typename value_t,
bool bitmap_or_bitset = true,
typename index_t = int64_t,
typename bits_t = uint32_t>
struct MaskedMatmulBench : public fixture {
MaskedMatmulBench(const MaskedMatmulBenchParams<value_t>& p)
: fixture(true),
@@ -64,15 +67,15 @@ struct MaskedMatmulBench : public fixture {
c_indptr_d(0, stream),
c_indices_d(0, stream),
c_data_d(0, stream),
bitmap_d(0, stream),
bits_d(0, stream),
c_dense_data_d(0, stream)
{
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bitmap_t) * 8));
std::vector<bitmap_t> bitmap_h(element);
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bits_t) * 8));
std::vector<bits_t> bits_h(element);

a_data_d.resize(params.m * params.k, stream);
b_data_d.resize(params.k * params.n, stream);
bitmap_d.resize(element, stream);
bits_d.resize(element, stream);

raft::random::RngState rng(2024ULL);
raft::random::uniform(
@@ -82,7 +85,13 @@

std::vector<bool> c_dense_data_h(params.m * params.n);

c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h);
if constexpr (bitmap_or_bitset) {
c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bits_h);
} else {
c_true_nnz = create_sparse_matrix(1, params.n, params.sparsity, bits_h);
repeat_cpu_bitset_inplace(bits_h, params.n, params.m - 1);
c_true_nnz *= params.m;
}

std::vector<value_t> values(c_true_nnz);
std::vector<index_t> indices(c_true_nnz);
@@ -93,24 +102,49 @@
c_indices_d.resize(c_true_nnz, stream);
c_dense_data_d.resize(params.m * params.n, stream);

cpu_convert_to_csr(bitmap_h, params.m, params.n, indices, indptr);
cpu_convert_to_csr(bits_h, params.m, params.n, indices, indptr);
RAFT_EXPECTS(c_true_nnz == c_indices_d.size(),
"Something wrong. The c_true_nnz != c_indices_d.size()!");

update_device(c_data_d.data(), values.data(), c_true_nnz, stream);
update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream);
update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream);
update_device(bitmap_d.data(), bitmap_h.data(), element, stream);
update_device(bits_d.data(), bits_h.data(), element, stream);
}

void repeat_cpu_bitset_inplace(std::vector<bits_t>& inout, size_t input_bits, size_t repeat)
{
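// Replicates the first `input_bits` bits of `inout` `repeat` more times, starting at bit
// position `input_bits`, so the single-row bitset pattern covers every row of the matrix.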
size_t output_bit_index = input_bits;

for (size_t r = 0; r < repeat; ++r) {
for (size_t i = 0; i < input_bits; ++i) {
size_t input_unit_index = i / (sizeof(bits_t) * 8);
size_t input_bit_offset = i % (sizeof(bits_t) * 8);
bool bit = (inout[input_unit_index] >> input_bit_offset) & 1;

size_t output_unit_index = output_bit_index / (sizeof(bits_t) * 8);
size_t output_bit_offset = output_bit_index % (sizeof(bits_t) * 8);

inout[output_unit_index] |= (static_cast<bits_t>(bit) << output_bit_offset);

++output_bit_index;
}
}
}

index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitmap_t>& bitmap)
index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bits_t>& bits)
{
index_t total = static_cast<index_t>(m * n);
index_t num_ones = static_cast<index_t>((total * 1.0f) * sparsity);
index_t num_ones = static_cast<index_t>((total * 1.0f) * (1.0f - sparsity));
index_t res = num_ones;

for (auto& item : bitmap) {
item = static_cast<bitmap_t>(0);
if (sparsity == 0.0f) {
std::fill(bits.begin(), bits.end(), 0xffffffff);
return num_ones;
}

for (auto& item : bits) {
item = static_cast<bits_t>(0);
}

std::random_device rd;
@@ -120,8 +154,8 @@ struct MaskedMatmulBench : public fixture {
while (num_ones > 0) {
index_t index = dis(gen);

bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))];
index_t bit_position = index % (8 * sizeof(bitmap_t));
bits_t& element = bits[index / (8 * sizeof(bits_t))];
index_t bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1) == 0) {
element |= (static_cast<index_t>(1) << bit_position);
@@ -131,7 +165,7 @@ struct MaskedMatmulBench : public fixture {
return res;
}

void cpu_convert_to_csr(std::vector<bitmap_t>& bitmap,
void cpu_convert_to_csr(std::vector<bits_t>& bits,
index_t rows,
index_t cols,
std::vector<index_t>& indices,
@@ -142,14 +176,14 @@
indptr[offset_indptr++] = 0;

index_t index = 0;
bitmap_t element = 0;
bits_t element = 0;
index_t bit_position = 0;

for (index_t i = 0; i < rows; ++i) {
for (index_t j = 0; j < cols; ++j) {
index = i * cols + j;
element = bitmap[index / (8 * sizeof(bitmap_t))];
bit_position = index % (8 * sizeof(bitmap_t));
element = bits[index / (8 * sizeof(bits_t))];
bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1)) {
indices[offset_values] = static_cast<index_t>(j);
@@ -181,13 +215,17 @@ struct MaskedMatmulBench : public fixture {
params.n,
static_cast<index_t>(c_indices_d.size()));

auto mask =
raft::core::bitmap_view<const bitmap_t, index_t>(bitmap_d.data(), params.m, params.n);

auto c = raft::make_device_csr_matrix_view<value_t>(c_data_d.data(), c_structure);

if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
if (params.sparsity > 0.0) {
if constexpr (bitmap_or_bitset) {
auto mask =
raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
}
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
@@ -201,12 +239,16 @@ struct MaskedMatmulBench : public fixture {
}
resource::sync_stream(handle);

raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
resource::sync_stream(handle);

loop_on_state(state, [this, &a, &b, &mask, &c]() {
if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
loop_on_state(state, [this, &a, &b, &c]() {
if (params.sparsity > 0.0) {
if constexpr (bitmap_or_bitset) {
auto mask =
raft::core::bitmap_view<const bits_t, index_t>(bits_d.data(), params.m, params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
auto mask = raft::core::bitset_view<const bits_t, index_t>(bits_d.data(), params.n);
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
}
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
@@ -228,7 +270,7 @@ struct MaskedMatmulBench : public fixture {

rmm::device_uvector<value_t> a_data_d;
rmm::device_uvector<value_t> b_data_d;
rmm::device_uvector<bitmap_t> bitmap_d;
rmm::device_uvector<bits_t> bits_d;

rmm::device_uvector<value_t> c_dense_data_d;

@@ -253,7 +295,7 @@ static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
raft::util::itertools::product<TestParams>({size_t(10), size_t(1024)},
{size_t(128), size_t(1024)},
{size_t(1024 * 1024)},
{0.01f, 0.1f, 0.2f, 0.5f, 1.0f});
{0.99f, 0.9f, 0.8f, 0.5f, 0.0f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
@@ -263,6 +305,7 @@ static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
return param_vec;
}

RAFT_BENCH_REGISTER((MaskedMatmulBench<float>), "", getInputs<float>());
RAFT_BENCH_REGISTER((MaskedMatmulBench<float, true>), "", getInputs<float>());
RAFT_BENCH_REGISTER((MaskedMatmulBench<float, false>), "", getInputs<float>());

} // namespace raft::bench::linalg
64 changes: 64 additions & 0 deletions cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh
@@ -16,6 +16,7 @@
#pragma once

#include <raft/core/bitmap.cuh>
#include <raft/core/bitset.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
@@ -100,6 +101,69 @@ void masked_matmul(raft::resources const& handle,
}
}

template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitset_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major>& A,
raft::device_matrix_view<const value_t, index_t, raft::row_major>& B,
raft::core::bitset_view<const bitset_t, index_t>& mask,
raft::device_csr_matrix_view<output_t, index_t, index_t, nnz_t>& C,
std::optional<raft::host_scalar_view<output_t>> alpha,
std::optional<raft::host_scalar_view<output_t>> beta)
{
index_t m = A.extent(0);
index_t n = B.extent(0);
index_t dim = A.extent(1);

auto compressed_C_view = C.structure_view();

RAFT_EXPECTS(A.extent(1) == B.extent(1), "The dim of A must be equal to the dim of B.");
RAFT_EXPECTS(A.extent(0) == compressed_C_view.get_n_rows(),
"Number of rows in C must match the number of rows in A.");
RAFT_EXPECTS(B.extent(0) == compressed_C_view.get_n_cols(),
"Number of columns in C must match the number of columns in B.");

auto stream = raft::resource::get_cuda_stream(handle);

auto C_matrix = raft::make_device_csr_matrix<output_t, index_t>(handle, compressed_C_view);

// fill C
raft::sparse::convert::bitset_to_csr(handle, mask, C_matrix);

if (m > 10 || alpha.has_value() || beta.has_value()) {
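// More than 10 rows, or explicit alpha/beta scaling requested: evaluate the masked
// product with SDDMM on the sparsity pattern of C.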
auto C_view = raft::make_device_csr_matrix_view<output_t, index_t, index_t, index_t>(
C.get_elements().data(), compressed_C_view);

// create B col_major view
auto B_col_major = raft::make_device_matrix_view<const value_t, index_t, raft::col_major>(
B.data_handle(), dim, n);

output_t default_alpha = static_cast<output_t>(1.0f);
output_t default_beta = static_cast<output_t>(0.0f);

if (!alpha.has_value()) { alpha = raft::make_host_scalar_view<output_t>(&default_alpha); }
if (!beta.has_value()) { beta = raft::make_host_scalar_view<output_t>(&default_beta); }

raft::sparse::linalg::sddmm(handle,
A,
B_col_major,
C_view,
raft::linalg::Operation::NON_TRANSPOSE,
raft::linalg::Operation::NON_TRANSPOSE,
*alpha,
*beta);
} else {
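// 10 or fewer rows and no alpha/beta scaling: compute each non-zero of C as a direct
// dot product over the CSR pattern.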
raft::sparse::distance::detail::faster_dot_on_csr(handle,
C.get_elements().data(),
compressed_C_view.get_nnz(),
compressed_C_view.get_indptr().data(),
compressed_C_view.get_indices().data(),
A.data_handle(),
B.data_handle(),
compressed_C_view.get_n_rows(),
dim);
}
}

} // namespace detail
} // namespace linalg
} // namespace sparse
45 changes: 45 additions & 0 deletions cpp/include/raft/sparse/linalg/masked_matmul.hpp
@@ -65,6 +65,51 @@ void masked_matmul(raft::resources const& handle,
detail::masked_matmul(handle, A, B, mask, C, alpha, beta);
}

/**
* @brief Computes a sparse matrix product with a masked sparsity pattern and scaling.
*
* This function computes the result of:
* C = alpha * ((A * B) ∘ spy(mask)) + beta * C
* where:
* - A and B are dense input matrices.
* - "mask" defines the sparsity pattern for element-wise multiplication.
* - The result is scaled by alpha and added to beta times the original C.
*
* **Special behavior of the mask**:
* - The `bitset` mask represents a single row of data, with its bits indicating whether
* each corresponding element in (A * B) is included (1) or masked out (0).
* - If the output CSR matrix `C` has multiple rows, the `bitset` is logically repeated
* across all rows of `C`: the same n-bit pattern determines which entries of every
* row are computed.
*
* @tparam value_t Data type of input matrix elements (e.g., half, float, double).
* @tparam output_t Data type of output matrix elements (e.g., float, double).
* @tparam index_t Type for matrix indices.
* @tparam nnz_t Type for non-zero entries in CSR format.
* @tparam bitmap_t Underlying storage type of the bitset mask.
*
* @param[in] handle RAFT handle for managing resources.
* @param[in] A Dense input matrix [m, k] (row-major).
* @param[in] B Dense input matrix [n, k] (row-major).
* @param[in] mask Bitset view representing a single row [1, n], where each bit
* indicates if the corresponding element in (A * B) is included (1)
* or masked out (0). The pattern is repeated for all rows of `C`.
* @param[inout] C Output sparse matrix in CSR format [m, n].
* @param[in] alpha Scalar multiplier for (A * B) (default: 1.0 if std::nullopt).
* @param[in] beta Scalar multiplier for the initial C (default: 0 if std::nullopt).
*/
template <typename value_t, typename output_t, typename index_t, typename nnz_t, typename bitmap_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major> A,
raft::device_matrix_view<const value_t, index_t, raft::row_major> B,
raft::core::bitset_view<const bitmap_t, index_t> mask,
raft::device_csr_matrix_view<output_t, index_t, index_t, nnz_t> C,
std::optional<raft::host_scalar_view<output_t>> alpha = std::nullopt,
std::optional<raft::host_scalar_view<output_t>> beta = std::nullopt)
{
detail::masked_matmul(handle, A, B, mask, C, alpha, beta);
}
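
Below is an illustrative usage sketch for this bitset overload (editor's example, not part of the header). The function name, buffer names, and problem sizes are hypothetical; the device buffers are assumed to be allocated and filled elsewhere, and the CSR structure of C is assumed to already describe the single-row mask repeated across all m rows (e.g., produced with raft::sparse::convert::bitset_to_csr).

#include <raft/core/bitset.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/sparse/linalg/masked_matmul.hpp>

void masked_matmul_bitset_example()
{
  raft::device_resources handle;

  // Hypothetical problem sizes and pre-filled device buffers (allocation omitted).
  int64_t m = 1024, k = 128, n = 1 << 20;
  int64_t nnz            = 0;        // total non-zeros in C: m * (set bits in the mask)
  const float* a_d       = nullptr;  // [m, k], row-major
  const float* b_d       = nullptr;  // [n, k], row-major
  const uint32_t* bits_d = nullptr;  // ceil(n / 32) words, one bit per column of C
  int64_t* c_indptr_d    = nullptr;  // [m + 1]
  int64_t* c_indices_d   = nullptr;  // [nnz]
  float* c_data_d        = nullptr;  // [nnz]

  auto A = raft::make_device_matrix_view<const float, int64_t, raft::row_major>(a_d, m, k);
  auto B = raft::make_device_matrix_view<const float, int64_t, raft::row_major>(b_d, n, k);

  // Single-row bitset mask: the same n-bit pattern is applied to every row of C.
  auto mask = raft::core::bitset_view<const uint32_t, int64_t>(bits_d, n);

  // CSR output [m, n] whose sparsity pattern matches the repeated mask.
  auto c_structure = raft::make_device_compressed_structure_view<int64_t, int64_t, int64_t>(
    c_indptr_d, c_indices_d, m, n, nnz);
  auto C = raft::make_device_csr_matrix_view<float>(c_data_d, c_structure);

  // C = (A * B) ∘ spy(mask); optional alpha/beta arguments scale and accumulate.
  raft::sparse::linalg::masked_matmul(handle, A, B, mask, C);
}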

/** @} */ // end of masked_matmul

} // end namespace linalg