Skip to content

Commit

Permalink
Issue flexflow#1435, tests for managed stream and handle
Browse files Browse the repository at this point in the history
  • Loading branch information
oOTigger committed Oct 16, 2024
1 parent de230cb commit 3fc8718
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 18 deletions.
19 changes: 13 additions & 6 deletions lib/kernels/src/managed_ff_stream.cc
Original file line number Diff line number Diff line change
@@ -1,28 +1,35 @@
#include "kernels/managed_ff_stream.h"
#include "utils/exception.h"

namespace FlexFlow {

ManagedFFStream::ManagedFFStream() : stream(new ffStream_t) {
checkCUDA(cudaStreamCreate(stream));
checkCUDA(cudaStreamCreate(this->stream));
}

ManagedFFStream::ManagedFFStream(ManagedFFStream &&other) noexcept
: stream(std::exchange(other.stream, nullptr)) {}

ManagedFFStream &ManagedFFStream::operator=(ManagedFFStream &&other) noexcept {
std::swap(this->stream, other.stream);
if (this != &other) {
if (this->stream != nullptr) {
checkCUDA(cudaStreamDestroy(*this->stream));
delete stream;
}
this->stream = std::exchange(other.stream, nullptr);
}
return *this;
}

ManagedFFStream::~ManagedFFStream() {
if (stream != nullptr) {
checkCUDA(cudaStreamDestroy(*stream));
delete stream;
if (this->stream != nullptr) {
checkCUDA(cudaStreamDestroy(*this->stream));
delete this->stream;
}
}

ffStream_t const &ManagedFFStream::raw_stream() const {
return *stream;
return *this->stream;
}

} // namespace FlexFlow
33 changes: 21 additions & 12 deletions lib/kernels/src/managed_per_device_ff_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
namespace FlexFlow {

ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle() {
handle = new PerDeviceFFHandle;
handle->workSpaceSize = 1024 * 1024;
handle->allowTensorOpMathConversion = true;
this->handle = new PerDeviceFFHandle;
this->handle->workSpaceSize = 1024 * 1024;
this->handle->allowTensorOpMathConversion = true;

checkCUDNN(cudnnCreate(&handle->dnn));
checkCUBLAS(cublasCreate(&handle->blas));
checkCUDA(cudaMalloc(&handle->workSpace, handle->workSpaceSize));
checkCUDNN(cudnnCreate(&this->handle->dnn));
checkCUBLAS(cublasCreate(&this->handle->blas));
checkCUDA(cudaMalloc(&this->handle->workSpace, this->handle->workSpaceSize));
}

ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle(
Expand All @@ -19,16 +19,25 @@ ManagedPerDeviceFFHandle::ManagedPerDeviceFFHandle(

ManagedPerDeviceFFHandle &ManagedPerDeviceFFHandle::operator=(
ManagedPerDeviceFFHandle &&other) noexcept {
std::swap(this->handle, other.handle);
if (this != &other) {
if (this->handle != nullptr) {
checkCUDNN(cudnnDestroy(this->handle->dnn));
checkCUBLAS(cublasDestroy(this->handle->blas));
checkCUDA(cudaFree(this->handle->workSpace));
delete this->handle;
}
this->handle = std::exchange(other.handle, nullptr);
}
return *this;
}

ManagedPerDeviceFFHandle::~ManagedPerDeviceFFHandle() {
if (handle != nullptr) {
checkCUDNN(cudnnDestroy(handle->dnn));
checkCUBLAS(cublasDestroy(handle->blas));
checkCUDA(cudaFree(handle->workSpace));
delete handle;
if (this->handle != nullptr) {
checkCUDNN(cudnnDestroy(this->handle->dnn));
checkCUBLAS(cublasDestroy(this->handle->blas));
checkCUDA(cudaFree(this->handle->workSpace));
delete this->handle;
this->handle = nullptr;
}
}

Expand Down
29 changes: 29 additions & 0 deletions lib/kernels/test/src/test_managed_ff_stream.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "doctest/doctest.h"
#include "kernels/managed_ff_stream.h"

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("Test Managed FF Stream") {
ManagedFFStream base_stream{};

SUBCASE("Test ManagedFFStream Move Constructor") {
ffStream_t const *base_stream_ptr = &base_stream.raw_stream();

ManagedFFStream new_stream(std::move(base_stream));

CHECK(&base_stream.raw_stream() == nullptr);
CHECK(&new_stream.raw_stream() == base_stream_ptr);
}

SUBCASE("Test ManagedFFStream Assignment Operator") {
ffStream_t const *base_stream_ptr = &base_stream.raw_stream();

ManagedFFStream new_stream{};
new_stream = std::move(base_stream);

CHECK(&base_stream.raw_stream() == nullptr);
CHECK(&new_stream.raw_stream() == base_stream_ptr);
}
}
}
34 changes: 34 additions & 0 deletions lib/kernels/test/src/test_managed_per_device_ff_handle.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "doctest/doctest.h"
#include "kernels/managed_per_device_ff_handle.h"

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("Test Managed Per Device FF Handle") {
ManagedPerDeviceFFHandle base_handle{};

SUBCASE("Test ManagedPerDeviceFFHandle Constructor") {
CHECK(base_handle.raw_handle().workSpaceSize == 1024 * 1024);
CHECK(base_handle.raw_handle().allowTensorOpMathConversion == true);
}

SUBCASE("Test ManagedPerDeviceFFHandle Move Constructor") {
PerDeviceFFHandle const *base_handle_ptr = &base_handle.raw_handle();

ManagedPerDeviceFFHandle new_handle(std::move(base_handle));

CHECK(&base_handle.raw_handle() == nullptr);
CHECK(&new_handle.raw_handle() == base_handle_ptr);
}

SUBCASE("Test ManagedPerDeviceFFHandle Assignment Operator") {
PerDeviceFFHandle const *base_handle_ptr = &base_handle.raw_handle();

ManagedPerDeviceFFHandle new_handle{};
new_handle = std::move(base_handle);

CHECK(&base_handle.raw_handle() == nullptr);
CHECK(&new_handle.raw_handle() == base_handle_ptr);
}
}
}

0 comments on commit 3fc8718

Please sign in to comment.