Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manually destroy cuBLAS and cuDNN handles before threads exit #1201

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/ctranslate2/devices.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ namespace ctranslate2 {
void synchronize_device(Device device, int index);
void synchronize_stream(Device device);

void destroy_context(Device device);

class ScopedDeviceSetter {
public:
ScopedDeviceSetter(Device device, int index)
Expand Down
2 changes: 2 additions & 0 deletions include/ctranslate2/replica_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ namespace ctranslate2 {

void finalize() override {
_replica.reset();

destroy_context(_device);
}

private:
Expand Down
40 changes: 33 additions & 7 deletions src/cuda/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <stdexcept>
#include <vector>

#include <spdlog/spdlog.h>

#include "ctranslate2/utils.h"

#include "env.h"
Expand Down Expand Up @@ -81,7 +83,11 @@ namespace ctranslate2 {
}
~CublasHandle() {
ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
cublasDestroy(_handle);
cublasStatus_t status = cublasDestroy(_handle);

if (status != CUBLAS_STATUS_SUCCESS)
spdlog::error("cublasDestroy failed with status "
+ std::string(cuda::cublasGetStatusName(status)));
}
cublasHandle_t get() const {
return _handle;
Expand All @@ -92,16 +98,20 @@ namespace ctranslate2 {
};

// We create one cuBLAS/cuDNN handle per host thread. The handle is destroyed
// when the thread exits.
// when the thread exits or when destroy_handles is called.

cudaStream_t get_cuda_stream() {
static thread_local CudaStream cuda_stream;
return cuda_stream.get();
}

static thread_local std::unique_ptr<CublasHandle> cublas_handle;

cublasHandle_t get_cublas_handle() {
static thread_local CublasHandle cublas_handle;
return cublas_handle.get();
if (!cublas_handle)
cublas_handle = std::make_unique<CublasHandle>();

return cublas_handle->get();
}

#ifdef CT2_WITH_CUDNN
Expand All @@ -114,7 +124,11 @@ namespace ctranslate2 {
}
~CudnnHandle() {
ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
cudnnDestroy(_handle);
cudnnStatus_t status = cudnnDestroy(_handle);

if (status != CUDNN_STATUS_SUCCESS)
spdlog::error("cudnnDestroy failed with status "
+ std::string(cudnnGetErrorString(status)));
}
cudnnHandle_t get() const {
return _handle;
Expand All @@ -124,9 +138,13 @@ namespace ctranslate2 {
cudnnHandle_t _handle;
};

static thread_local std::unique_ptr<CudnnHandle> cudnn_handle;

cudnnHandle_t get_cudnn_handle() {
static thread_local CudnnHandle cudnn_handle;
return cudnn_handle.get();
if (!cudnn_handle)
cudnn_handle = std::make_unique<CudnnHandle>();

return cudnn_handle->get();
}

cudnnDataType_t get_cudnn_data_type(DataType dtype) {
Expand All @@ -145,6 +163,14 @@ namespace ctranslate2 {
}
#endif

void destroy_handles() {
#ifdef CT2_WITH_CUDNN
cudnn_handle.reset();
#endif

cublas_handle.reset();
}

int get_gpu_count() {
int gpu_count = 0;
cudaError_t status = cudaGetDeviceCount(&gpu_count);
Expand Down
3 changes: 3 additions & 0 deletions src/cuda/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ namespace ctranslate2 {
cudnnDataType_t get_cudnn_data_type(DataType dtype);
#endif

// Destroy cuBLAS and cuDNN handles for the current thread.
void destroy_handles();

int get_gpu_count();
bool has_gpu();
const cudaDeviceProp& get_device_properties(int device = -1);
Expand Down
10 changes: 10 additions & 0 deletions src/devices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,14 @@ namespace ctranslate2 {
#endif
}

void destroy_context(Device device) {
#ifdef CT2_WITH_CUDA
if (device == Device::CUDA) {
cuda::destroy_handles();
}
#else
(void)device;
#endif
}

}