Skip to content

Commit

Permalink
Add co_thread (from Metre)
Browse files Browse the repository at this point in the history
  • Loading branch information
dwd committed Jul 23, 2024
1 parent 65e4ff2 commit 1e78eaa
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 19 deletions.
30 changes: 11 additions & 19 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,36 @@ FetchContent_Declare(
FetchContent_MakeAvailable(googletest)

enable_testing()
link_libraries(GTest::gtest_main)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(sigslot-test
test/sigslot.cc
test/coroutine.cc
sigslot/sigslot.h
sigslot/tasklet.h
sigslot/resume.h
)
target_link_libraries(sigslot-test
GTest::gtest_main
)
target_include_directories(sigslot-test
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
)
add_executable(sigslot-test-resume
sigslot/sigslot.h
sigslot/tasklet.h
test/resume.cc
sigslot/resume.h
)
target_link_libraries(sigslot-test-resume
GTest::gtest_main
)
target_include_directories(sigslot-test-resume
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
add_executable(sigslot-test-cothread
sigslot/sigslot.h
sigslot/tasklet.h
test/cothread.cc
sigslot/resume.h
sigslot/cothread.h
)
include(GoogleTest)
gtest_discover_tests(sigslot-test)
gtest_discover_tests(sigslot-test-resume)
gtest_discover_tests(sigslot-test-cothread)

if (UNIX)
target_compile_options(sigslot-test PUBLIC -fcoroutines)
target_link_options(sigslot-test PUBLIC -fcoroutines)
target_compile_options(sigslot-test-resume PUBLIC -fcoroutines)
target_link_options(sigslot-test-resume PUBLIC -fcoroutines)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcoroutines")
endif ()
if (WIN32)
target_compile_options(sigslot-test-resume PUBLIC /await)
target_compile_options(sigslot-test PUBLIC /await)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /await")
endif()
101 changes: 101 additions & 0 deletions sigslot/cothread.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//
// Created by dwd on 21/12/2021.
//

#ifndef SIGSLOT_COTHREAD_H
#define SIGSLOT_COTHREAD_H

#include <thread>
#include "sigslot/sigslot.h"
#include "sigslot/tasklet.h"

namespace sigslot {
template<typename Result, class... Args>
class co_thread {
public:
// Single argument version uses a bare T
struct awaitable {
std::coroutine_handle<> awaiting = nullptr;
co_thread & wait_for;

explicit awaitable(co_thread & t) : wait_for(t) {}

bool await_ready() {
return wait_for.has_payload();
}

void await_suspend(std::coroutine_handle<> h) {
// The awaiting coroutine is already suspended.
awaiting = h;
wait_for.await(this);
}

auto await_resume() {
return wait_for.payload();
}

void resolve() {
std::coroutine_handle<> a = nullptr;
std::swap(a, awaiting);
if (a) sigslot::resume_switch(a);
}
};
private:
std::function<Result(Args...)> m_fn;
std::optional<std::jthread> m_thread;
std::optional<Result> m_payload;
std::recursive_mutex m_mutex;
awaitable * m_awaitable = nullptr;

public:

explicit co_thread(std::function<Result(Args...)> && fn) : m_fn(std::move(fn)) {}

co_thread & run(Args&&... args) {
auto wrapped_fn = [this](Args... a) {
auto result = m_fn(a...);
{
std::lock_guard l_(m_mutex);
m_payload.emplace(result);
if (m_awaitable) {
m_awaitable->resolve();
}
}
};
m_thread.emplace(wrapped_fn, args...);
return *this;
}
auto & operator() (Args&&... args) {
return this->run(args...);
}

awaitable operator co_await() {
if (!m_thread.has_value()) throw std::logic_error("No thread started");
return awaitable(*this);
}

bool has_payload() {
if (!m_thread.has_value()) throw std::logic_error("No thread started");
std::lock_guard l_(m_mutex);
return m_payload.has_value();
}

auto payload() {
if (!m_thread.has_value()) throw std::logic_error("No thread started");
m_thread->join();
m_thread.reset();
return *m_payload;
}

void await(awaitable * a) {
if (!m_thread.has_value()) throw std::logic_error("No thread started");
std::lock_guard l_(m_mutex);
m_awaitable = a;
if (m_payload.has_value()) {
a->resolve();
}
}
};
}

#endif
174 changes: 174 additions & 0 deletions test/cothread.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#include "gtest/gtest.h"
#include <iostream>
#include <regex>
#include <coroutine>
#include <list>
#include <sigslot/resume.h>
// Tiny event loop. co_thread won't work properly without,
// since it's got to (essentially) block while the thread runs.

std::mutex lock_me;
std::vector<std::coroutine_handle<>> resume_me;


namespace sigslot {
void resume(std::coroutine_handle<> coro) {
std::lock_guard l(lock_me);
resume_me.push_back(coro);
}
}

#include <sigslot/sigslot.h>
#include <sigslot/tasklet.h>
#include <sigslot/cothread.h>

template<typename R>
void run_until_complete_low(sigslot::tasklet<R> & coro) {
if (!coro.started()) coro.start();
while (coro.running()) {
std::vector<std::coroutine_handle<>> current;
{
std::lock_guard l(lock_me);
current.swap(resume_me);
}
std::cout << "Resuming " << current.size() << " coroutines." << std::endl;
for (auto coro : current) {
coro.resume();
}
current.clear();
sleep(1);
std::cout << "... tick" << std::endl;
}
}
template<typename R>
R run_until_complete(sigslot::tasklet<R> & coro) {
run_until_complete_low(coro);
return coro.get();
}
template<>
void run_until_complete<void>(sigslot::tasklet<void> & coro) {
run_until_complete_low(coro);
coro.get();
}


sigslot::tasklet<bool> inner(std::string const & s) {
std::cout << "Here!" << std::endl;
sigslot::co_thread<bool, std::string const &> thread1([](std::string const &s) {
std::cout << "There 1! " << s << std::endl;
return true;
});
sigslot::co_thread<bool> thread2([]() {
std::cout << "+ Launch" << std::endl;
sleep(1);
std::cout << "+ There 2!" << std::endl;
sleep(1);
std::cout << "+ End" << std::endl;
return true;
});
std::cout << "Still here!" << std::endl;
thread2();
auto result1 = co_await thread1(s);
std::cout << "Got result1:" << result1 << std::endl;
auto result2 = co_await thread2;
std::cout << "Got result2:" << result2 << std::endl;
co_return true;
}

sigslot::tasklet<void> start() {
std::string s = "Hello world!";
auto result = co_await inner(s);
std::cout << "Completed test with result " << result << std::endl;
}

namespace {
sigslot::tasklet<int> trivial_task(int i) {
co_return i;
}

sigslot::tasklet<int> basic_task(sigslot::signal<int> &signal) {
co_return co_await signal;
}

sigslot::tasklet<int> signal_thread_task() {
sigslot::signal<int> signal;
sigslot::co_thread<int> thread([&signal]() {
sleep(1);
signal(42);
sleep(1);
return 42;
});
thread();
auto result = co_await signal;
co_await thread;
co_return result;
}

sigslot::tasklet<int> nested_task(int i) {
co_return co_await trivial_task(i);
}

sigslot::tasklet<int> exception_task(int i) {
if (i == 42) {
// Have to do this conditionally with a co_return otherwise it's not a coroutine.
throw std::runtime_error("Help");
}
co_return i;
}
}

TEST(CoThreadTest, CheckLoop) {
auto coro = trivial_task(42);
auto result = run_until_complete(coro);
EXPECT_EQ(result, 42);
}

TEST(CoThreadTest, CheckLoop2) {
sigslot::signal<int> signal;
auto coro = basic_task(signal);
coro.start();
int i = 0;
while (coro.running()) {
std::vector<std::coroutine_handle<>> current;
{
std::lock_guard l(lock_me);
current.swap(resume_me);
}
std::cout << "Resuming " << current.size() << " coroutines." << std::endl;
for (auto coro : current) {
coro.resume();
}
current.clear();
sleep(1);
if (i == 2) {
std::cout << "Signalling" << std::endl;
signal(42);
}
++i;
std::cout << "... tick" << std::endl;
}
auto result = coro.get();
std::cout << "Result: " << result << std::endl;
}

TEST(CoThreadTest, Tests) {
std::cout << "Start" << std::endl;
auto coro = start();
coro.start();
while (coro.running()) {
std::vector<std::coroutine_handle<>> current;
{
std::lock_guard l(lock_me);
current.swap(resume_me);
}
std::cout << "Resuming " << current.size() << " coroutines." << std::endl;
for (auto coro : current) {
coro.resume();
}
current.clear();
sleep(1);
std::cout << "... tick" << std::endl;
}
coro.get();
std::cout << "*** END ***" << std::endl;
}

0 comments on commit 1e78eaa

Please sign in to comment.