diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index f05bebf3f..559d33cc2 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -42,7 +42,7 @@ #include #include #include -#include +#include #include #include #include @@ -636,36 +636,13 @@ void brute_force_search_filtered( rows.data(), compressed_csr_view.get_nnz(), stream); - if (n_queries > 10) { - auto csr_view = raft::make_device_csr_matrix_view( - csr.get_elements().data(), compressed_csr_view); - - // create dataset view - auto dataset_view = raft::make_device_matrix_view( - idx.dataset().data_handle(), dim, n_dataset); - - // calc dot - T alpha = static_cast(1.0f); - T beta = static_cast(0.0f); - raft::sparse::linalg::sddmm(res, - queries, - dataset_view, - csr_view, - raft::linalg::Operation::NON_TRANSPOSE, - raft::linalg::Operation::NON_TRANSPOSE, - raft::make_host_scalar_view(&alpha), - raft::make_host_scalar_view(&beta)); - } else { - raft::sparse::distance::detail::faster_dot_on_csr(res, - csr.get_elements().data(), - compressed_csr_view.get_nnz(), - compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_indices().data(), - queries.data_handle(), - idx.dataset().data_handle(), - compressed_csr_view.get_n_rows(), - dim); - } + auto dataset_view = raft::make_device_matrix_view( + idx.dataset().data_handle(), n_dataset, dim); + + auto csr_view = raft::make_device_csr_matrix_view( + csr.get_elements().data(), compressed_csr_view); + + raft::sparse::linalg::masked_matmul(res, queries, dataset_view, filter, csr_view); // post process std::optional> query_norms_;