From 1e78eaa0852471a1291203d80ac4649485c8aff0 Mon Sep 17 00:00:00 2001 From: Dave Cridland Date: Tue, 23 Jul 2024 16:43:04 +0100 Subject: [PATCH] Add co_thread (from Metre) --- CMakeLists.txt | 30 +++----- sigslot/cothread.h | 101 ++++++++++++++++++++++++++ test/cothread.cc | 174 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 19 deletions(-) create mode 100644 sigslot/cothread.h create mode 100644 test/cothread.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index fbf34bc..e8227e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,8 @@ FetchContent_Declare( FetchContent_MakeAvailable(googletest) enable_testing() +link_libraries(GTest::gtest_main) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_executable(sigslot-test test/sigslot.cc test/coroutine.cc @@ -20,37 +22,27 @@ add_executable(sigslot-test sigslot/tasklet.h sigslot/resume.h ) -target_link_libraries(sigslot-test - GTest::gtest_main -) -target_include_directories(sigslot-test - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR} -) add_executable(sigslot-test-resume sigslot/sigslot.h sigslot/tasklet.h test/resume.cc sigslot/resume.h ) -target_link_libraries(sigslot-test-resume - GTest::gtest_main -) -target_include_directories(sigslot-test-resume - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR} +add_executable(sigslot-test-cothread + sigslot/sigslot.h + sigslot/tasklet.h + test/cothread.cc + sigslot/resume.h + sigslot/cothread.h ) include(GoogleTest) gtest_discover_tests(sigslot-test) gtest_discover_tests(sigslot-test-resume) +gtest_discover_tests(sigslot-test-cothread) if (UNIX) - target_compile_options(sigslot-test PUBLIC -fcoroutines) - target_link_options(sigslot-test PUBLIC -fcoroutines) - target_compile_options(sigslot-test-resume PUBLIC -fcoroutines) - target_link_options(sigslot-test-resume PUBLIC -fcoroutines) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcoroutines") endif () if (WIN32) - target_compile_options(sigslot-test-resume PUBLIC /await) - target_compile_options(sigslot-test PUBLIC /await) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /await") endif() diff --git a/sigslot/cothread.h b/sigslot/cothread.h new file mode 100644 index 0000000..5dc2a92 --- /dev/null +++ b/sigslot/cothread.h @@ -0,0 +1,101 @@ +// +// Created by dwd on 21/12/2021. +// + +#ifndef SIGSLOT_COTHREAD_H +#define SIGSLOT_COTHREAD_H + +#include +#include "sigslot/sigslot.h" +#include "sigslot/tasklet.h" + +namespace sigslot { + template + class co_thread { + public: + // Single argument version uses a bare T + struct awaitable { + std::coroutine_handle<> awaiting = nullptr; + co_thread & wait_for; + + explicit awaitable(co_thread & t) : wait_for(t) {} + + bool await_ready() { + return wait_for.has_payload(); + } + + void await_suspend(std::coroutine_handle<> h) { + // The awaiting coroutine is already suspended. + awaiting = h; + wait_for.await(this); + } + + auto await_resume() { + return wait_for.payload(); + } + + void resolve() { + std::coroutine_handle<> a = nullptr; + std::swap(a, awaiting); + if (a) sigslot::resume_switch(a); + } + }; + private: + std::function m_fn; + std::optional m_thread; + std::optional m_payload; + std::recursive_mutex m_mutex; + awaitable * m_awaitable = nullptr; + + public: + + explicit co_thread(std::function && fn) : m_fn(std::move(fn)) {} + + co_thread & run(Args&&... args) { + auto wrapped_fn = [this](Args... a) { + auto result = m_fn(a...); + { + std::lock_guard l_(m_mutex); + m_payload.emplace(result); + if (m_awaitable) { + m_awaitable->resolve(); + } + } + }; + m_thread.emplace(wrapped_fn, args...); + return *this; + } + auto & operator() (Args&&... args) { + return this->run(args...); + } + + awaitable operator co_await() { + if (!m_thread.has_value()) throw std::logic_error("No thread started"); + return awaitable(*this); + } + + bool has_payload() { + if (!m_thread.has_value()) throw std::logic_error("No thread started"); + std::lock_guard l_(m_mutex); + return m_payload.has_value(); + } + + auto payload() { + if (!m_thread.has_value()) throw std::logic_error("No thread started"); + m_thread->join(); + m_thread.reset(); + return *m_payload; + } + + void await(awaitable * a) { + if (!m_thread.has_value()) throw std::logic_error("No thread started"); + std::lock_guard l_(m_mutex); + m_awaitable = a; + if (m_payload.has_value()) { + a->resolve(); + } + } + }; +} + +#endif diff --git a/test/cothread.cc b/test/cothread.cc new file mode 100644 index 0000000..c9a14aa --- /dev/null +++ b/test/cothread.cc @@ -0,0 +1,174 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include +#include +// Tiny event loop. co_thread won't work properly without, +// since it's got to (essentially) block while the thread runs. + +std::mutex lock_me; +std::vector> resume_me; + + +namespace sigslot { + void resume(std::coroutine_handle<> coro) { + std::lock_guard l(lock_me); + resume_me.push_back(coro); + } +} + +#include +#include +#include + +template +void run_until_complete_low(sigslot::tasklet & coro) { + if (!coro.started()) coro.start(); + while (coro.running()) { + std::vector> current; + { + std::lock_guard l(lock_me); + current.swap(resume_me); + } + std::cout << "Resuming " << current.size() << " coroutines." << std::endl; + for (auto coro : current) { + coro.resume(); + } + current.clear(); + sleep(1); + std::cout << "... tick" << std::endl; + } +} +template +R run_until_complete(sigslot::tasklet & coro) { + run_until_complete_low(coro); + return coro.get(); +} +template<> +void run_until_complete(sigslot::tasklet & coro) { + run_until_complete_low(coro); + coro.get(); +} + + +sigslot::tasklet inner(std::string const & s) { + std::cout << "Here!" << std::endl; + sigslot::co_thread thread1([](std::string const &s) { + std::cout << "There 1! " << s << std::endl; + return true; + }); + sigslot::co_thread thread2([]() { + std::cout << "+ Launch" << std::endl; + sleep(1); + std::cout << "+ There 2!" << std::endl; + sleep(1); + std::cout << "+ End" << std::endl; + return true; + }); + std::cout << "Still here!" << std::endl; + thread2(); + auto result1 = co_await thread1(s); + std::cout << "Got result1:" << result1 << std::endl; + auto result2 = co_await thread2; + std::cout << "Got result2:" << result2 << std::endl; + co_return true; +} + +sigslot::tasklet start() { + std::string s = "Hello world!"; + auto result = co_await inner(s); + std::cout << "Completed test with result " << result << std::endl; +} + +namespace { + sigslot::tasklet trivial_task(int i) { + co_return i; + } + + sigslot::tasklet basic_task(sigslot::signal &signal) { + co_return co_await signal; + } + + sigslot::tasklet signal_thread_task() { + sigslot::signal signal; + sigslot::co_thread thread([&signal]() { + sleep(1); + signal(42); + sleep(1); + return 42; + }); + thread(); + auto result = co_await signal; + co_await thread; + co_return result; + } + + sigslot::tasklet nested_task(int i) { + co_return co_await trivial_task(i); + } + + sigslot::tasklet exception_task(int i) { + if (i == 42) { + // Have to do this conditionally with a co_return otherwise it's not a coroutine. + throw std::runtime_error("Help"); + } + co_return i; + } +} + +TEST(CoThreadTest, CheckLoop) { + auto coro = trivial_task(42); + auto result = run_until_complete(coro); + EXPECT_EQ(result, 42); +} + +TEST(CoThreadTest, CheckLoop2) { + sigslot::signal signal; + auto coro = basic_task(signal); + coro.start(); + int i = 0; + while (coro.running()) { + std::vector> current; + { + std::lock_guard l(lock_me); + current.swap(resume_me); + } + std::cout << "Resuming " << current.size() << " coroutines." << std::endl; + for (auto coro : current) { + coro.resume(); + } + current.clear(); + sleep(1); + if (i == 2) { + std::cout << "Signalling" << std::endl; + signal(42); + } + ++i; + std::cout << "... tick" << std::endl; + } + auto result = coro.get(); + std::cout << "Result: " << result << std::endl; +} + +TEST(CoThreadTest, Tests) { + std::cout << "Start" << std::endl; + auto coro = start(); + coro.start(); + while (coro.running()) { + std::vector> current; + { + std::lock_guard l(lock_me); + current.swap(resume_me); + } + std::cout << "Resuming " << current.size() << " coroutines." << std::endl; + for (auto coro : current) { + coro.resume(); + } + current.clear(); + sleep(1); + std::cout << "... tick" << std::endl; + } + coro.get(); + std::cout << "*** END ***" << std::endl; +}