NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which every `<...>` span starting with a letter had been stripped. Standard
header names, template arguments (e.g. Transport, Context, std::mutex) and the
indentation of context lines are reconstructed from how the code uses them;
confirm them against the tree before applying (use
`git apply --ignore-whitespace` if only indentation differs). The `index` blob
ids of the two new files predate the edits made here and are stale.
gloo/common/logging.h is assumed to be the header that provides GLOO_THROW.

diff --git a/gloo/common/CMakeLists.txt b/gloo/common/CMakeLists.txt
index a8d47449..96793c52 100644
--- a/gloo/common/CMakeLists.txt
+++ b/gloo/common/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(GLOO_COMMON_SRCS
   "${CMAKE_CURRENT_SOURCE_DIR}/logging.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/utils.cc"
+  "${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
   )
 
 set(GLOO_COMMON_HDRS
diff --git a/gloo/common/error.cc b/gloo/common/error.cc
new file mode 100644
index 00000000..c14f38b4
--- /dev/null
+++ b/gloo/common/error.cc
@@ -0,0 +1,72 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <atomic>
+#include <condition_variable>
+#include <list>
+#include <mutex>
+
+#include "gloo/common/error.h"
+#include "gloo/common/logging.h"
+
+namespace gloo {
+
+namespace {
+
+// Condition variables that abort() must wake. The list does not own them:
+// every pointer added via _register_cv() has to be removed again with
+// _deregister_cv() before the condition variable is destroyed.
+std::list<std::condition_variable*> _cvs;
+
+// Guards every access to _cvs.
+std::mutex _cvs_mutex;
+
+// Set by abort(). Nothing in this file ever clears it again.
+std::atomic_bool _is_aborted_flag(false);
+
+} // namespace
+
+// Returns true once abort() has been called.
+bool _is_aborted() {
+  return _is_aborted_flag.load();
+}
+
+// Flags the abort, wakes every registered condition variable and then
+// always throws through GLOO_THROW("GLOO ABORTED"); it never returns normally.
+//
+// NOTE(review): the waiters' own mutexes are not held while notifying, so a
+// waiter that has just evaluated its predicate but is not blocked yet can miss
+// this wake-up and only notice the abort when its wait_for() times out.
+// Confirm whether that window is acceptable.
+void abort() {
+  // Publish the flag first so that every woken waiter observes it.
+  _is_aborted_flag.store(true);
+  {
+    std::lock_guard<std::mutex> guard(_cvs_mutex);
+    for (auto* cv : _cvs) {
+      if (cv != nullptr) {
+        cv->notify_all();
+      }
+    }
+  }
+  GLOO_THROW("GLOO ABORTED");
+}
+
+// Adds cv to the set of condition variables notified by abort().
+void _register_cv(std::condition_variable *cv) {
+  std::lock_guard<std::mutex> guard(_cvs_mutex);
+  _cvs.push_back(cv);
+}
+
+// Removes every entry equal to cv, so abort() no longer notifies it.
+void _deregister_cv(std::condition_variable *cv) {
+  std::lock_guard<std::mutex> guard(_cvs_mutex);
+  _cvs.remove(cv);
+}
+
+} // namespace gloo
diff --git a/gloo/common/error.h b/gloo/common/error.h
index 4eac45ec..c7e98fa4 100644
--- a/gloo/common/error.h
+++ b/gloo/common/error.h
@@ -10,6 +10,7 @@
 
 #include <chrono>
+#include <condition_variable>
 #include <exception>
 
 #include "gloo/common/string.h"
 
@@ -20,6 +21,17 @@ namespace gloo {
 
 const std::chrono::milliseconds kNoTimeout = std::chrono::milliseconds::zero();
 
+// Returns true once abort() has been called.
+bool _is_aborted();
+
+// Wakes every condition variable registered below, then always throws.
+void abort();
+
+// Adds / removes a condition variable that abort() should notify. A registered
+// condition variable must be deregistered before it is destroyed.
+void _register_cv(std::condition_variable *cv);
+void _deregister_cv(std::condition_variable *cv);
+
 // A base class for all gloo runtime errors
 struct Exception : public std::runtime_error {
   Exception() = delete;
diff --git a/gloo/test/CMakeLists.txt b/gloo/test/CMakeLists.txt
index f7eb97a0..422d4251 100644
--- a/gloo/test/CMakeLists.txt
+++ b/gloo/test/CMakeLists.txt
@@ -1,6 +1,7 @@
 find_package(OpenSSL 1.1 REQUIRED EXACT)
 
 set(GLOO_TEST_SRCS
+  "${CMAKE_CURRENT_SOURCE_DIR}/abort_test.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/allgatherv_test.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc"
diff --git a/gloo/test/abort_test.cc b/gloo/test/abort_test.cc
new file mode 100644
index 00000000..da8c4ac6
--- /dev/null
+++ b/gloo/test/abort_test.cc
@@ -0,0 +1,82 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <chrono>
+#include <cstring>
+#include <memory>
+#include <tuple>
+
+#include "gloo/barrier_all_to_all.h"
+#include "gloo/barrier_all_to_one.h"
+#include "gloo/broadcast.h"
+#include "gloo/test/base_test.h"
+
+namespace gloo {
+namespace test {
+namespace {
+
+// Synchronized version of std::chrono::clock::now().
+// All processes participating in the specified context will
+// see the same value.
+template <typename clock>
+std::chrono::time_point<clock> syncNow(std::shared_ptr<Context> context) {
+  const typename clock::time_point now = clock::now();
+  typename clock::duration::rep count = now.time_since_epoch().count();
+  BroadcastOptions opts(context);
+  opts.setRoot(0);
+  opts.setOutput(&count, 1);
+  broadcast(opts);
+  return typename clock::time_point(typename clock::duration(count));
+}
+
+using NewParam = std::tuple<Transport, int>;
+
+class AbortBarrierTest : public BaseTest,
+                         public ::testing::WithParamInterface<NewParam> {};
+
+TEST_P(AbortBarrierTest, Default) {
+  const auto transport = std::get<0>(GetParam());
+  const auto contextSize = std::get<1>(GetParam());
+
+  spawn(transport, contextSize, [&](std::shared_ptr<Context> context) {
+    BarrierOptions opts(context);
+
+    // Run barrier to synchronize processes after starting.
+    barrier(opts);
+
+    auto timeout = std::chrono::milliseconds(context->getTimeout());
+    const auto start = syncNow<std::chrono::high_resolution_clock>(context);
+    // Run barrier on all ranks but 0 so it hangs
+    if (context->rank != 0) {
+      barrier(opts);
+    }
+
+    // Abort should unhang the barrier
+    try {
+      gloo::abort();
+    } catch (const Exception &e) {
+      EXPECT_TRUE(std::strstr(e.what(), "GLOO ABORTED") != nullptr);
+    }
+
+    // Every rank must finish in at most a quarter of the timeout, which is only
+    // possible if abort() cut the hanging barrier short.
+    auto stop = std::chrono::high_resolution_clock::now();
+    auto delta =
+        std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
+    ASSERT_LE(delta.count(), timeout.count() / 4);
+  });
+}
+
+INSTANTIATE_TEST_CASE_P(
+    AbortBarrier, AbortBarrierTest,
+    ::testing::Combine(::testing::ValuesIn(kTransportsForFunctionAlgorithms),
+                       ::testing::Values(1, 2, 4, 7)));
+
+} // namespace
+} // namespace test
+} // namespace gloo
diff --git a/gloo/transport/tcp/unbound_buffer.cc b/gloo/transport/tcp/unbound_buffer.cc
index b8cac467..5210eb48 100644
--- a/gloo/transport/tcp/unbound_buffer.cc
+++ b/gloo/transport/tcp/unbound_buffer.cc
@@ -26,9 +26,17 @@ UnboundBuffer::UnboundBuffer(
       recvRank_(-1),
       sendCompletions_(0),
       sendRank_(-1),
-      shareableNonOwningPtr_(this) {}
+      shareableNonOwningPtr_(this) {
+  // Let gloo::abort() wake threads blocked on either condition variable.
+  gloo::_register_cv(&recvCv_);
+  gloo::_register_cv(&sendCv_);
+}
 
-UnboundBuffer::~UnboundBuffer() {}
+UnboundBuffer::~UnboundBuffer() {
+  // Must happen before the condition variables are destroyed.
+  gloo::_deregister_cv(&recvCv_);
+  gloo::_deregister_cv(&sendCv_);
+}
 
 void UnboundBuffer::handleRecvCompletion(int rank) {
   std::lock_guard<std::mutex> lock(m_);
@@ -58,6 +66,10 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
   if (recvCompletions_ == 0) {
     auto done = recvCv_.wait_for(lock, timeout, [&] {
       throwIfException();
+      // Fold a process-wide abort into this buffer's own abort flag.
+      if (gloo::_is_aborted()) {
+        abortWaitRecv_ = true;
+      }
       return abortWaitRecv_ || recvCompletions_ > 0;
     });
     if (!done) {
@@ -109,6 +121,10 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
   if (sendCompletions_ == 0) {
     auto done = sendCv_.wait_for(lock, timeout, [&] {
       throwIfException();
+      // Fold a process-wide abort into this buffer's own abort flag.
+      if (gloo::_is_aborted()) {
+        abortWaitSend_ = true;
+      }
       return abortWaitSend_ || sendCompletions_ > 0;
     });
     if (!done) {
diff --git a/gloo/transport/uv/unbound_buffer.cc b/gloo/transport/uv/unbound_buffer.cc
index 593020b3..b237cf15 100644
--- a/gloo/transport/uv/unbound_buffer.cc
+++ b/gloo/transport/uv/unbound_buffer.cc
@@ -26,9 +26,17 @@ UnboundBuffer::UnboundBuffer(
       recvRank_(-1),
       sendCompletions_(0),
       sendRank_(-1),
-      shareableNonOwningPtr_(this) {}
+      shareableNonOwningPtr_(this) {
+  // Let gloo::abort() wake threads blocked on either condition variable.
+  gloo::_register_cv(&recvCv_);
+  gloo::_register_cv(&sendCv_);
+}
 
-UnboundBuffer::~UnboundBuffer() {}
+UnboundBuffer::~UnboundBuffer() {
+  // Must happen before the condition variables are destroyed.
+  gloo::_deregister_cv(&recvCv_);
+  gloo::_deregister_cv(&sendCv_);
+}
 
 void UnboundBuffer::handleRecvCompletion(int rank) {
   std::lock_guard<std::mutex> lock(mutex_);
@@ -56,8 +64,13 @@ bool UnboundBuffer::waitRecv(int* rank, std::chrono::milliseconds timeout) {
   }
 
   if (recvCompletions_ == 0) {
-    auto done = recvCv_.wait_for(
-        lock, timeout, [&] { return abortWaitRecv_ || recvCompletions_ > 0; });
+    auto done = recvCv_.wait_for(lock, timeout, [&] {
+      // Fold a process-wide abort into this buffer's own abort flag.
+      if (gloo::_is_aborted()) {
+        abortWaitRecv_ = true;
+      }
+      return abortWaitRecv_ || recvCompletions_ > 0;
+    });
     if (!done) {
       throw ::gloo::IoException(GLOO_ERROR_MSG(
           "Timed out waiting ",
@@ -92,8 +105,13 @@ bool UnboundBuffer::waitSend(int* rank, std::chrono::milliseconds timeout) {
   }
 
   if (sendCompletions_ == 0) {
-    auto done = sendCv_.wait_for(
-        lock, timeout, [&] { return abortWaitSend_ || sendCompletions_ > 0; });
+    auto done = sendCv_.wait_for(lock, timeout, [&] {
+      // Fold a process-wide abort into this buffer's own abort flag.
+      if (gloo::_is_aborted()) {
+        abortWaitSend_ = true;
+      }
+      return abortWaitSend_ || sendCompletions_ > 0;
+    });
     if (!done) {
       throw ::gloo::IoException(GLOO_ERROR_MSG(
           "Timed out waiting ",