From 3f924de082a87eb39761647b90e5e306f2982158 Mon Sep 17 00:00:00 2001 From: Artem Chirkin <9253178+achirkin@users.noreply.github.com> Date: Wed, 11 Dec 2024 07:21:31 -0800 Subject: [PATCH] Fix dataset deserialize not owning the data --- cpp/include/cuvs/neighbors/common.hpp | 78 ++++++++++++++++++- .../neighbors/detail/dataset_serialize.hpp | 2 +- 2 files changed, 76 insertions(+), 4 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 60b8cc122..bd9ea4834 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -264,6 +264,77 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t return std::make_unique(std::move(out_array), out_layout); } +/** + * @brief Contstruct a strided matrix from any mdarray. + * + * This function constructs an owning device matrix and copies the data. + * When the data is copied, padding elements are filled with zeroes. + * + * @tparam DataT + * @tparam IdxT + * @tparam LayoutPolicy + * @tparam ContainerPolicy + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] required_stride the leading dimension (in elements) + * @return owning current-device-accessible strided matrix + */ +template +auto make_strided_dataset( + const raft::resources& res, + raft::mdarray, LayoutPolicy, ContainerPolicy>&& src, + uint32_t required_stride) -> std::unique_ptr> +{ + using value_type = DataT; + using index_type = IdxT; + using layout_type = LayoutPolicy; + using container_policy_type = ContainerPolicy; + static_assert(std::is_same_v || + std::is_same_v> || + std::is_same_v, + "The input must be row-major"); + RAFT_EXPECTS(src.extent(1) <= required_stride, + "The input row length must be not larger than the desired stride."); + const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); + const bool stride_matches = required_stride == src_stride; + + auto out_layout = + raft::make_strided_layout(src.extents(), std::array{required_stride, 1}); + + using out_mdarray_type = raft::device_matrix; + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = + owning_dataset; + + if constexpr (std::is_same_v && + std::is_same_v) { + if (stride_matches) { + // Everything matches, we can own the mdarray + return std::make_unique(std::move(src), out_layout); + } + } + // Something is wrong: have to make a copy and produce an owning dataset + auto out_array = + raft::make_device_matrix(res, src.extent(0), required_stride); + + RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), + 0, + out_array.size() * sizeof(value_type), + raft::resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(value_type) * required_stride, + src.data_handle(), + sizeof(value_type) * src_stride, + sizeof(value_type) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + + return std::make_unique(std::move(out_array), out_layout); +} + /** * @brief Contstruct a strided matrix from any mdarray or mdspan. * @@ -278,14 +349,15 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t * @return maybe owning current-device-accessible strided matrix */ template -auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) +auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_bytes = 16) -> std::unique_ptr> { - using value_type = typename SrcT::value_type; + using source_type = std::remove_cv_t>; + using value_type = typename source_type::value_type; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; - return make_strided_dataset(res, src, required_stride); + return make_strided_dataset(res, std::forward(src), required_stride); } /** * @brief VPQ compressed dataset. diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 40d9df930..0ecc2cf5d 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -140,7 +140,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is) auto stride = raft::deserialize_scalar(res, is); auto host_array = raft::make_host_matrix(n_rows, dim); raft::deserialize_mdspan(res, is, host_array.view()); - return make_strided_dataset(res, host_array, stride); + return make_strided_dataset(res, std::move(host_array), stride); } template