From 3fc8718b3573fab3285a69913ab2985aaeb9bd4c Mon Sep 17 00:00:00 2001 From: Dylan Lim Date: Tue, 15 Oct 2024 19:22:55 -0700 Subject: [PATCH] Issue #1435, tests for managed stream and handle --- lib/kernels/src/managed_ff_stream.cc | 19 +++++++---- .../src/managed_per_device_ff_handle.cc | 33 +++++++++++------- .../test/src/test_managed_ff_stream.cc | 29 ++++++++++++++++ .../src/test_managed_per_device_ff_handle.cc | 34 +++++++++++++++++++ 4 files changed, 97 insertions(+), 18 deletions(-) create mode 100644 lib/kernels/test/src/test_managed_ff_stream.cc create mode 100644 lib/kernels/test/src/test_managed_per_device_ff_handle.cc diff --git a/lib/kernels/src/managed_ff_stream.cc b/lib/kernels/src/managed_ff_stream.cc index 7385b6cc3e..a8b44dc1d3 100644 --- a/lib/kernels/src/managed_ff_stream.cc +++ b/lib/kernels/src/managed_ff_stream.cc @@ -1,28 +1,35 @@ #include "kernels/managed_ff_stream.h" +#include "utils/exception.h" namespace FlexFlow { ManagedFFStream::ManagedFFStream() : stream(new ffStream_t) { - checkCUDA(cudaStreamCreate(stream)); + checkCUDA(cudaStreamCreate(this->stream)); } ManagedFFStream::ManagedFFStream(ManagedFFStream &&other) noexcept : stream(std::exchange(other.stream, nullptr)) {} ManagedFFStream &ManagedFFStream::operator=(ManagedFFStream &&other) noexcept { - std::swap(this->stream, other.stream); + if (this != &other) { + if (this->stream != nullptr) { + checkCUDA(cudaStreamDestroy(*this->stream)); + delete stream; + } + this->stream = std::exchange(other.stream, nullptr); + } return *this; } ManagedFFStream::~ManagedFFStream() { - if (stream != nullptr) { - checkCUDA(cudaStreamDestroy(*stream)); - delete stream; + if (this->stream != nullptr) { + checkCUDA(cudaStreamDestroy(*this->stream)); + delete this->stream; } } ffStream_t const &ManagedFFStream::raw_stream() const { - return *stream; + return *this->stream; } } // namespace FlexFlow diff --git a/lib/kernels/src/managed_per_device_ff_handle.cc b/lib/kernels/src/managed_per_device_ff_handle.cc index c050e887b6..ca105f9bc9 100644 --- a/lib/kernels/src/managed_per_device_ff_handle.cc +++ b/lib/kernels/src/managed_per_device_ff_handle.cc @@ -4,13 +4,13 @@ namespace FlexFlow { ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle() { - handle = new PerDeviceFFHandle; - handle->workSpaceSize = 1024 * 1024; - handle->allowTensorOpMathConversion = true; + this->handle = new PerDeviceFFHandle; + this->handle->workSpaceSize = 1024 * 1024; + this->handle->allowTensorOpMathConversion = true; - checkCUDNN(cudnnCreate(&handle->dnn)); - checkCUBLAS(cublasCreate(&handle->blas)); - checkCUDA(cudaMalloc(&handle->workSpace, handle->workSpaceSize)); + checkCUDNN(cudnnCreate(&this->handle->dnn)); + checkCUBLAS(cublasCreate(&this->handle->blas)); + checkCUDA(cudaMalloc(&this->handle->workSpace, this->handle->workSpaceSize)); } ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle( @@ -19,16 +19,25 @@ ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle( ManagedPerDeviceFFHandle &ManagedPerDeviceFFHandle::operator=( ManagedPerDeviceFFHandle &&other) noexcept { - std::swap(this->handle, other.handle); + if (this != &other) { + if (this->handle != nullptr) { + checkCUDNN(cudnnDestroy(this->handle->dnn)); + checkCUBLAS(cublasDestroy(this->handle->blas)); + checkCUDA(cudaFree(this->handle->workSpace)); + delete this->handle; + } + this->handle = std::exchange(other.handle, nullptr); + } return *this; } ManagedPerDeviceFFHandle::~ManagedPerDeviceFFHandle() { - if (handle != nullptr) { - checkCUDNN(cudnnDestroy(handle->dnn)); - checkCUBLAS(cublasDestroy(handle->blas)); - checkCUDA(cudaFree(handle->workSpace)); - delete handle; + if (this->handle != nullptr) { + checkCUDNN(cudnnDestroy(this->handle->dnn)); + checkCUBLAS(cublasDestroy(this->handle->blas)); + checkCUDA(cudaFree(this->handle->workSpace)); + delete this->handle; + this->handle = nullptr; } } diff --git a/lib/kernels/test/src/test_managed_ff_stream.cc b/lib/kernels/test/src/test_managed_ff_stream.cc new file mode 100644 index 0000000000..1dc40f0a92 --- /dev/null +++ b/lib/kernels/test/src/test_managed_ff_stream.cc @@ -0,0 +1,29 @@ +#include "doctest/doctest.h" +#include "kernels/managed_ff_stream.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Managed FF Stream") { + ManagedFFStream base_stream{}; + + SUBCASE("Test ManagedFFStream Move Constructor") { + ffStream_t const *base_stream_ptr = &base_stream.raw_stream(); + + ManagedFFStream new_stream(std::move(base_stream)); + + CHECK(&base_stream.raw_stream() == nullptr); + CHECK(&new_stream.raw_stream() == base_stream_ptr); + } + + SUBCASE("Test ManagedFFStream Assignment Operator") { + ffStream_t const *base_stream_ptr = &base_stream.raw_stream(); + + ManagedFFStream new_stream{}; + new_stream = std::move(base_stream); + + CHECK(&base_stream.raw_stream() == nullptr); + CHECK(&new_stream.raw_stream() == base_stream_ptr); + } + } +} diff --git a/lib/kernels/test/src/test_managed_per_device_ff_handle.cc b/lib/kernels/test/src/test_managed_per_device_ff_handle.cc new file mode 100644 index 0000000000..d99d375a7c --- /dev/null +++ b/lib/kernels/test/src/test_managed_per_device_ff_handle.cc @@ -0,0 +1,34 @@ +#include "doctest/doctest.h" +#include "kernels/managed_per_device_ff_handle.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test Managed Per Device FF Handle") { + ManagedPerDeviceFFHandle base_handle{}; + + SUBCASE("Test ManagedPerDeviceFFHandle Constructor") { + CHECK(base_handle.raw_handle().workSpaceSize == 1024 * 1024); + CHECK(base_handle.raw_handle().allowTensorOpMathConversion == true); + } + + SUBCASE("Test ManagedPerDeviceFFHandle Move Constructor") { + PerDeviceFFHandle const *base_handle_ptr = &base_handle.raw_handle(); + + ManagedPerDeviceFFHandle new_handle(std::move(base_handle)); + + CHECK(&base_handle.raw_handle() == nullptr); + CHECK(&new_handle.raw_handle() == base_handle_ptr); + } + + SUBCASE("Test ManagedPerDeviceFFHandle Assignment Operator") { + PerDeviceFFHandle const *base_handle_ptr = &base_handle.raw_handle(); + + ManagedPerDeviceFFHandle new_handle{}; + new_handle = std::move(base_handle); + + CHECK(&base_handle.raw_handle() == nullptr); + CHECK(&new_handle.raw_handle() == base_handle_ptr); + } + } +}