Skip to content

Commit

Permalink
Add ability to allocate with RMM to the c-api and rust api (#56)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ray Douglass (https://github.com/raydouglass)

URL: #56
  • Loading branch information
benfred authored Mar 19, 2024
1 parent e5d5e3a commit 6981a08
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 28 deletions.
1 change: 0 additions & 1 deletion ci/build_rust.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ rapids-mamba-retry install \
libcuvs \
libraft

export EXTRA_CMAKE_ARGS=""
bash ./build.sh rust
44 changes: 44 additions & 0 deletions cpp/include/cuvs/core/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,50 @@ cuvsError_t cuvsResourcesDestroy(cuvsResources_t res);
*/
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream);

/**
* @brief Get the cudaStream_t from a cuvsResources_t
*
* @param[in] res cuvsResources_t opaque C handle
* @param[out] stream cudaStream_t stream to queue CUDA kernels
* @return cuvsError_t
*/
cuvsError_t cuvsStreamGet(cuvsResources_t res, cudaStream_t* stream);

/**
* @brief Syncs the current CUDA stream on the resources object
*
* @param[in] res cuvsResources_t opaque C handle
* @return cuvsError_t
*/
cuvsError_t cuvsStreamSync(cuvsResources_t res);
/** @} */

/**
* @defgroup memory_c cuVS Memory Allocation
* @{
*/

/**
* @brief Allocates device memory using RMM
*
*
* @param[in] res cuvsResources_t opaque C handle
* @param[out] ptr Pointer to allocated device memory
* @param[in] bytes Size in bytes to allocate
* @return cuvsError_t
*/
cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes);

/**
* @brief Deallocates device memory using RMM
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] ptr Pointer to allocated device memory to free
* @param[in] bytes Size in bytes of the allocation being freed (must match the size passed to cuvsRMMAlloc)
* @return cuvsError_t
*/
cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes);

/** @} */

#ifdef __cplusplus
Expand Down
35 changes: 35 additions & 0 deletions cpp/src/core/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <thread>

extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
Expand All @@ -47,6 +48,40 @@ extern "C" cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
});
}

extern "C" cuvsError_t cuvsStreamGet(cuvsResources_t res, cudaStream_t* stream)
{
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
*stream = raft::resource::get_cuda_stream(*res_ptr);
});
}

extern "C" cuvsError_t cuvsStreamSync(cuvsResources_t res)
{
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
raft::resource::sync_stream(*res_ptr);
});
}

extern "C" cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes)
{
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto mr = rmm::mr::get_current_device_resource();
*ptr = mr->allocate(bytes, raft::resource::get_cuda_stream(*res_ptr));
});
}

extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes)
{
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto mr = rmm::mr::get_current_device_resource();
mr->deallocate(ptr, bytes, raft::resource::get_cuda_stream(*res_ptr));
});
}

thread_local std::string last_error_text = "";

extern "C" const char* cuvsGetLastErrorText()
Expand Down
5 changes: 2 additions & 3 deletions rust/cuvs-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,8 @@ fn main() {
.allowlist_type("(cuvs|cagra|DL).*")
.allowlist_function("(cuvs|cagra).*")
.rustified_enum("(cuvs|cagra|DL).*")
// also need some basic cuda mem functions
// (TODO: should we be adding in RMM support instead here?)
.allowlist_function("(cudaMalloc|cudaFree|cudaMemcpy)")
// also need some basic cuda mem functions for copying data
.allowlist_function("(cudaMemcpyAsync|cudaMemcpy)")
.rustified_enum("cudaError")
.generate()
.expect("Unable to generate cagra_c bindings")
Expand Down
2 changes: 1 addition & 1 deletion rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl Index {
/// Creates a new empty index
pub fn new() -> Result<Index> {
unsafe {
let mut index = core::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
let mut index = std::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
check_cuvs(ffi::cuvsCagraIndexCreate(index.as_mut_ptr()))?;
Ok(Index(index.assume_init()))
}
Expand Down
2 changes: 1 addition & 1 deletion rust/cuvs/src/cagra/index_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl IndexParams {
/// Returns a new IndexParams
pub fn new() -> Result<IndexParams> {
unsafe {
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
Ok(IndexParams(params.assume_init()))
}
Expand Down
2 changes: 1 addition & 1 deletion rust/cuvs/src/cagra/search_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl SearchParams {
/// Returns a new SearchParams object
pub fn new() -> Result<SearchParams> {
unsafe {
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
Ok(SearchParams(params.assume_init()))
}
Expand Down
44 changes: 23 additions & 21 deletions rust/cuvs/src/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use std::convert::From;

use crate::error::{check_cuda, Result};
use crate::error::{check_cuda, check_cuvs, Result};
use crate::resources::Resources;

/// ManagedTensor is a wrapper around a dlpack DLManagedTensor object.
Expand All @@ -33,36 +33,27 @@ impl ManagedTensor {
&self.0 as *const _ as *mut _
}

fn bytes(&self) -> usize {
// figure out how many bytes to allocate
let mut bytes: usize = 1;
for x in 0..self.0.dl_tensor.ndim {
bytes *= unsafe { (*self.0.dl_tensor.shape.add(x as usize)) as usize };
}
bytes *= (self.0.dl_tensor.dtype.bits / 8) as usize;
bytes
}

/// Creates a new ManagedTensor on the current GPU device, and copies
/// the data into it.
pub fn to_device(&self, _res: &Resources) -> Result<ManagedTensor> {
pub fn to_device(&self, res: &Resources) -> Result<ManagedTensor> {
unsafe {
let bytes = self.bytes();
let bytes = dl_tensor_bytes(&self.0.dl_tensor);
let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();

// allocate storage, copy over
check_cuda(ffi::cudaMalloc(&mut device_data as *mut _, bytes))?;
check_cuda(ffi::cudaMemcpy(
check_cuvs(ffi::cuvsRMMAlloc(res.0, &mut device_data as *mut _, bytes))?;

check_cuda(ffi::cudaMemcpyAsync(
device_data,
self.0.dl_tensor.data,
bytes,
ffi::cudaMemcpyKind_cudaMemcpyDefault,
res.get_cuda_stream()?,
))?;

let mut ret = self.0.clone();
ret.dl_tensor.data = device_data;
// call cudaFree automatically to clean up data
ret.deleter = Some(cuda_free_tensor);
ret.deleter = Some(rmm_free_tensor);
ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA;

Ok(ManagedTensor(ret))
Expand All @@ -80,21 +71,32 @@ impl ManagedTensor {
arr: &mut ndarray::ArrayBase<S, D>,
) -> Result<()> {
unsafe {
let bytes = self.bytes();
let bytes = dl_tensor_bytes(&self.0.dl_tensor);
check_cuda(ffi::cudaMemcpy(
arr.as_mut_ptr() as *mut std::ffi::c_void,
self.0.dl_tensor.data,
bytes,
ffi::cudaMemcpyKind_cudaMemcpyDefault,
))?;

Ok(())
}
}
}

unsafe extern "C" fn cuda_free_tensor(self_: *mut ffi::DLManagedTensor) {
let _ = ffi::cudaFree((*self_).dl_tensor.data);
/// Figures out how many bytes are in a DLTensor
fn dl_tensor_bytes(tensor: &ffi::DLTensor) -> usize {
let mut bytes: usize = 1;
for dim in 0..tensor.ndim {
bytes *= unsafe { (*tensor.shape.add(dim as usize)) as usize };
}
bytes *= (tensor.dtype.bits / 8) as usize;
bytes
}

unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) {
let bytes = dl_tensor_bytes(&(*self_).dl_tensor);
let res = Resources::new().unwrap();
let _ = ffi::cuvsRMMFree(res.0, (*self_).dl_tensor.data as *mut _, bytes);
}

/// Create a non-owning view of a Tensor from a ndarray
Expand Down
19 changes: 19 additions & 0 deletions rust/cuvs/src/resources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ impl Resources {
}
Ok(Resources(res))
}

/// Sets the current cuda stream
pub fn set_cuda_stream(&self, stream: ffi::cudaStream_t) -> Result<()> {
unsafe { check_cuvs(ffi::cuvsStreamSet(self.0, stream)) }
}

/// Gets the current cuda stream
pub fn get_cuda_stream(&self) -> Result<ffi::cudaStream_t> {
unsafe {
let mut stream = std::mem::MaybeUninit::<ffi::cudaStream_t>::uninit();
check_cuvs(ffi::cuvsStreamGet(self.0, stream.as_mut_ptr()))?;
Ok(stream.assume_init())
}
}

/// Syncs the current cuda stream
pub fn sync_stream(&self) -> Result<()> {
unsafe { check_cuvs(ffi::cuvsStreamSync(self.0)) }
}
}

impl Drop for Resources {
Expand Down

0 comments on commit 6981a08

Please sign in to comment.