Skip to content

Commit

Permalink
[RAND] rocrand/curand enqueue_native_command impls (#579)
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <[email protected]>
  • Loading branch information
JackAKirk authored Oct 10, 2024
1 parent c5ac41f commit d19d454
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
16 changes: 16 additions & 0 deletions src/rng/backends/curand/curand_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,39 @@ static inline void host_task_internal(H &cgh, E e, F f) {
#else
template <typename H, typename A, typename E, typename F>
static inline void host_task_internal(H &cgh, A acc, E e, F f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){
#else
cgh.host_task([=](sycl::interop_handle ih) {
#endif
curandStatus_t status;
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
CURAND_CALL(curandSetStream, status, e, stream);
auto r_ptr = reinterpret_cast<typename A::value_type *>(
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(acc));
f(r_ptr);
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUresult err;
CUDA_ERROR_FUNC(cuStreamSynchronize, err, stream);
#endif
});
}

template <typename H, typename E, typename F>
static inline void host_task_internal(H &cgh, E e, F f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){
#else
cgh.host_task([=](sycl::interop_handle ih) {
#endif
curandStatus_t status;
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
CURAND_CALL(curandSetStream, status, e, stream);
f(ih);
#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUresult err;
CUDA_ERROR_FUNC(cuStreamSynchronize, err, stream);
#endif
});
}
#endif
Expand Down
14 changes: 12 additions & 2 deletions src/rng/backends/rocrand/rocrand_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,39 @@ static inline void host_task_internal(H &cgh, E e, F f) {
#else
template <typename H, typename A, typename E, typename F>
static inline void host_task_internal(H &cgh, A acc, E e, F f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){
#else
cgh.host_task([=](sycl::interop_handle ih) {
#endif
rocrand_status status;
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_hip>();
ROCRAND_CALL(rocrand_set_stream, status, e, stream);
auto r_ptr = reinterpret_cast<typename A::value_type *>(
ih.get_native_mem<sycl::backend::ext_oneapi_hip>(acc));
f(r_ptr);

#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
hipError_t err;
HIP_ERROR_FUNC(hipStreamSynchronize, err, stream);
#endif
});
}

template <typename H, typename E, typename F>
static inline void host_task_internal(H &cgh, E e, F f) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){
#else
cgh.host_task([=](sycl::interop_handle ih) {
#endif
rocrand_status status;
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_hip>();
ROCRAND_CALL(rocrand_set_stream, status, e, stream);
f(ih);

#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
hipError_t err;
HIP_ERROR_FUNC(hipStreamSynchronize, err, stream);
#endif
});
}
#endif
Expand Down

0 comments on commit d19d454

Please sign in to comment.