Skip to content

Commit

Permalink
Keep the host dataset alive in case if CAGRA index is not-owning
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Dec 11, 2024
1 parent ce56d93 commit 430424b
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,13 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
(const DataT*)database.data(), ps.n_rows, ps.dim);

{
std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_, index_params.metric);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host->data_handle(), database.data(), database.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
(const DataT*)database_host->data_handle(), ps.n_rows, ps.dim);

index = cagra::build(handle_, index_params, database_host_view);
} else {
Expand Down Expand Up @@ -567,13 +568,16 @@ class AnnCagraAddNodesTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto initial_database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), initial_database_size, ps.dim);

std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(
database_host.data_handle(), database.data(), initial_database_view.size(), stream_);
database_host->data_handle(), database.data(), initial_database_view.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), initial_database_size, ps.dim);
(const DataT*)database_host->data_handle(), initial_database_size, ps.dim);
// NB: database_host must live no less than the index, because the index _may_be_
// non-onwning
index = cagra::build(handle_, index_params, database_host_view);
} else {
index = cagra::build(handle_, index_params, initial_database_view);
Expand Down Expand Up @@ -763,12 +767,13 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);

std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host->data_handle(), database.data(), database.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
(const DataT*)database_host->data_handle(), ps.n_rows, ps.dim);
index = cagra::build(handle_, index_params, database_host_view);
} else {
index = cagra::build(handle_, index_params, database_view);
Expand Down

0 comments on commit 430424b

Please sign in to comment.