Skip to content

Commit

Permalink
Merge pull request #250 from rapidsai/branch-24.08
Browse files Browse the repository at this point in the history
Forward-merge branch-24.08 into branch-24.10
  • Loading branch information
GPUtester authored Jul 26, 2024
2 parents ac2b898 + b442756 commit 0cde780
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
9 changes: 5 additions & 4 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
#include "./knn_utils.cuh"

#include <raft/core/bitmap.cuh>
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
Expand All @@ -46,6 +46,7 @@
#include <raft/sparse/matrix/select_k.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/popc.cuh>

#include <rmm/cuda_device.hpp>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -591,10 +592,10 @@ void brute_force_search_filtered(
auto nnz_view = raft::make_device_scalar_view<IdxT>(nnz.data());
auto filter_view =
raft::make_device_vector_view<const BitmapT, IdxT>(filter.data(), filter.n_elements());
IdxT size_h = n_queries * n_dataset;
auto size_view = raft::make_host_scalar_view<IdxT>(&size_h);

// TODO(rhdong): Need to switch to the public API,
// with the issue: https://github.com/rapidsai/cuvs/issues/158
raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view);
raft::popc(res, filter_view, size_view, nnz_view);
raft::copy(&nnz_h, nnz.data(), 1, stream);

raft::resource::sync_stream(res, stream);
Expand Down
9 changes: 5 additions & 4 deletions cpp/test/neighbors/brute_force_prefiltered.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/brute_force.hpp>

#include <raft/core/detail/popc.cuh>
#include <raft/core/host_mdspan.hpp>
#include <raft/matrix/copy.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/random/rmat_rectangular_generator.cuh>
#include <raft/random/rng.cuh>
#include <raft/random/rng_state.hpp>
#include <raft/util/popc.cuh>

#include <gtest/gtest.h>

Expand Down Expand Up @@ -192,12 +193,12 @@ class PrefilteredBruteForceTest
auto nnz_view = raft::make_device_scalar_view<index_t>(nnz.data());
auto filter_view =
raft::make_device_vector_view<const uint32_t, index_t>(filter_d.data(), filter_d.size());
index_t size_h = m * n;
auto size_view = raft::make_host_scalar_view<index_t>(&size_h);

set_bitmap(src, dst, bitmap, n_edges, n, stream);

// TODO(rhdong): Need to switch to the public API,
// with the issue: https://github.com/rapidsai/cuvs/issues/158
raft::detail::popc(handle, filter_view, m * n, nnz_view);
raft::popc(handle, filter_view, size_view, nnz_view);
raft::copy(&nnz_h, nnz.data(), 1, stream);

raft::resource::sync_stream(handle, stream);
Expand Down

0 comments on commit 0cde780

Please sign in to comment.