diff --git a/src/rng/backends/curand/curand_task.hpp b/src/rng/backends/curand/curand_task.hpp index adc08b840..240ced805 100644 --- a/src/rng/backends/curand/curand_task.hpp +++ b/src/rng/backends/curand/curand_task.hpp @@ -36,23 +36,39 @@ static inline void host_task_internal(H &cgh, E e, F f) { #else template static inline void host_task_internal(H &cgh, A acc, E e, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ +#else cgh.host_task([=](sycl::interop_handle ih) { +#endif curandStatus_t status; auto stream = ih.get_native_queue(); CURAND_CALL(curandSetStream, status, e, stream); auto r_ptr = reinterpret_cast( ih.get_native_mem(acc)); f(r_ptr); +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUresult err; + CUDA_ERROR_FUNC(cuStreamSynchronize, err, stream); +#endif }); } template static inline void host_task_internal(H &cgh, E e, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ +#else cgh.host_task([=](sycl::interop_handle ih) { +#endif curandStatus_t status; auto stream = ih.get_native_queue(); CURAND_CALL(curandSetStream, status, e, stream); f(ih); +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + CUresult err; + CUDA_ERROR_FUNC(cuStreamSynchronize, err, stream); +#endif }); } #endif diff --git a/src/rng/backends/rocrand/rocrand_task.hpp b/src/rng/backends/rocrand/rocrand_task.hpp index 2588dc901..bad40a9e5 100644 --- a/src/rng/backends/rocrand/rocrand_task.hpp +++ b/src/rng/backends/rocrand/rocrand_task.hpp @@ -36,29 +36,39 @@ static inline void host_task_internal(H &cgh, E e, F f) { #else template static inline void host_task_internal(H &cgh, A acc, E e, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ +#else cgh.host_task([=](sycl::interop_handle ih) { +#endif rocrand_status status; auto stream = ih.get_native_queue(); ROCRAND_CALL(rocrand_set_stream, status, e, stream); auto r_ptr = reinterpret_cast( ih.get_native_mem(acc)); f(r_ptr); - +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND hipError_t err; HIP_ERROR_FUNC(hipStreamSynchronize, err, stream); +#endif }); } template static inline void host_task_internal(H &cgh, E e, F f) { +#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND + cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih){ +#else cgh.host_task([=](sycl::interop_handle ih) { +#endif rocrand_status status; auto stream = ih.get_native_queue(); ROCRAND_CALL(rocrand_set_stream, status, e, stream); f(ih); - +#ifndef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND hipError_t err; HIP_ERROR_FUNC(hipStreamSynchronize, err, stream); +#endif }); } #endif