Skip to content

Commit

Permalink
Fix dataset deserialize not owning the data
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Dec 11, 2024
1 parent 430424b commit 3f924de
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
78 changes: 75 additions & 3 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,77 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t
return std::make_unique<out_owning_type>(std::move(out_array), out_layout);
}

/**
* @brief Contstruct a strided matrix from any mdarray.
*
* This function constructs an owning device matrix and copies the data.
* When the data is copied, padding elements are filled with zeroes.
*
* @tparam DataT
* @tparam IdxT
* @tparam LayoutPolicy
* @tparam ContainerPolicy
*
* @param[in] res raft resources handle
* @param[in] src the source mdarray or mdspan
* @param[in] required_stride the leading dimension (in elements)
* @return owning current-device-accessible strided matrix
*/
template <typename DataT, typename IdxT, typename LayoutPolicy, typename ContainerPolicy>
auto make_strided_dataset(
const raft::resources& res,
raft::mdarray<DataT, raft::matrix_extent<IdxT>, LayoutPolicy, ContainerPolicy>&& src,
uint32_t required_stride) -> std::unique_ptr<strided_dataset<DataT, IdxT>>
{
using value_type = DataT;
using index_type = IdxT;
using layout_type = LayoutPolicy;
using container_policy_type = ContainerPolicy;
static_assert(std::is_same_v<layout_type, raft::layout_right> ||
std::is_same_v<layout_type, raft::layout_right_padded<value_type>> ||
std::is_same_v<layout_type, raft::layout_stride>,
"The input must be row-major");
RAFT_EXPECTS(src.extent(1) <= required_stride,
"The input row length must be not larger than the desired stride.");
const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1);
const bool stride_matches = required_stride == src_stride;

auto out_layout =
raft::make_strided_layout(src.extents(), std::array<index_type, 2>{required_stride, 1});

using out_mdarray_type = raft::device_matrix<value_type, index_type>;
using out_layout_type = typename out_mdarray_type::layout_type;
using out_container_policy_type = typename out_mdarray_type::container_policy_type;
using out_owning_type =
owning_dataset<value_type, index_type, out_layout_type, out_container_policy_type>;

if constexpr (std::is_same_v<layout_type, out_layout_type> &&
std::is_same_v<container_policy_type, out_container_policy_type>) {
if (stride_matches) {
// Everything matches, we can own the mdarray
return std::make_unique<out_owning_type>(std::move(src), out_layout);
}
}
// Something is wrong: have to make a copy and produce an owning dataset
auto out_array =
raft::make_device_matrix<value_type, index_type>(res, src.extent(0), required_stride);

RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(),
0,
out_array.size() * sizeof(value_type),
raft::resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(),
sizeof(value_type) * required_stride,
src.data_handle(),
sizeof(value_type) * src_stride,
sizeof(value_type) * src.extent(1),
src.extent(0),
cudaMemcpyDefault,
raft::resource::get_cuda_stream(res)));

return std::make_unique<out_owning_type>(std::move(out_array), out_layout);
}

/**
* @brief Contstruct a strided matrix from any mdarray or mdspan.
*
Expand All @@ -278,14 +349,15 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t
* @return maybe owning current-device-accessible strided matrix
*/
template <typename SrcT>
auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16)
auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_bytes = 16)
-> std::unique_ptr<strided_dataset<typename SrcT::value_type, typename SrcT::index_type>>
{
using value_type = typename SrcT::value_type;
using source_type = std::remove_cv_t<std::remove_reference_t<SrcT>>;
using value_type = typename source_type::value_type;
constexpr size_t kSize = sizeof(value_type);
uint32_t required_stride =
raft::round_up_safe<size_t>(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize;
return make_strided_dataset(res, src, required_stride);
return make_strided_dataset(res, std::forward<SrcT>(src), required_stride);
}
/**
* @brief VPQ compressed dataset.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/dataset_serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is)
auto stride = raft::deserialize_scalar<uint32_t>(res, is);
auto host_array = raft::make_host_matrix<DataT, IdxT>(n_rows, dim);
raft::deserialize_mdspan(res, is, host_array.view());
return make_strided_dataset(res, host_array, stride);
return make_strided_dataset(res, std::move(host_array), stride);
}

template <typename MathT, typename IdxT>
Expand Down

0 comments on commit 3f924de

Please sign in to comment.