forked from WanderPig/SigSlot
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
286 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// | ||
// Created by dwd on 21/12/2021. | ||
// | ||
|
||
#ifndef SIGSLOT_COTHREAD_H | ||
#define SIGSLOT_COTHREAD_H | ||
|
||
#include <thread> | ||
#include "sigslot/sigslot.h" | ||
#include "sigslot/tasklet.h" | ||
|
||
namespace sigslot { | ||
template<typename Result, class... Args> | ||
class co_thread { | ||
public: | ||
// Single argument version uses a bare T | ||
struct awaitable { | ||
std::coroutine_handle<> awaiting = nullptr; | ||
co_thread & wait_for; | ||
|
||
explicit awaitable(co_thread & t) : wait_for(t) {} | ||
|
||
bool await_ready() { | ||
return wait_for.has_payload(); | ||
} | ||
|
||
void await_suspend(std::coroutine_handle<> h) { | ||
// The awaiting coroutine is already suspended. | ||
awaiting = h; | ||
wait_for.await(this); | ||
} | ||
|
||
auto await_resume() { | ||
return wait_for.payload(); | ||
} | ||
|
||
void resolve() { | ||
std::coroutine_handle<> a = nullptr; | ||
std::swap(a, awaiting); | ||
if (a) sigslot::resume_switch(a); | ||
} | ||
}; | ||
private: | ||
std::function<Result(Args...)> m_fn; | ||
std::optional<std::jthread> m_thread; | ||
std::optional<Result> m_payload; | ||
std::recursive_mutex m_mutex; | ||
awaitable * m_awaitable = nullptr; | ||
|
||
public: | ||
|
||
explicit co_thread(std::function<Result(Args...)> && fn) : m_fn(std::move(fn)) {} | ||
|
||
co_thread & run(Args&&... args) { | ||
auto wrapped_fn = [this](Args... a) { | ||
auto result = m_fn(a...); | ||
{ | ||
std::lock_guard l_(m_mutex); | ||
m_payload.emplace(result); | ||
if (m_awaitable) { | ||
m_awaitable->resolve(); | ||
} | ||
} | ||
}; | ||
m_thread.emplace(wrapped_fn, args...); | ||
return *this; | ||
} | ||
auto & operator() (Args&&... args) { | ||
return this->run(args...); | ||
} | ||
|
||
awaitable operator co_await() { | ||
if (!m_thread.has_value()) throw std::logic_error("No thread started"); | ||
return awaitable(*this); | ||
} | ||
|
||
bool has_payload() { | ||
if (!m_thread.has_value()) throw std::logic_error("No thread started"); | ||
std::lock_guard l_(m_mutex); | ||
return m_payload.has_value(); | ||
} | ||
|
||
auto payload() { | ||
if (!m_thread.has_value()) throw std::logic_error("No thread started"); | ||
m_thread->join(); | ||
m_thread.reset(); | ||
return *m_payload; | ||
} | ||
|
||
void await(awaitable * a) { | ||
if (!m_thread.has_value()) throw std::logic_error("No thread started"); | ||
std::lock_guard l_(m_mutex); | ||
m_awaitable = a; | ||
if (m_payload.has_value()) { | ||
a->resolve(); | ||
} | ||
} | ||
}; | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
#include "gtest/gtest.h" | ||
#include <iostream> | ||
#include <regex> | ||
#include <coroutine> | ||
#include <list> | ||
#include <sigslot/resume.h> | ||
// Tiny event loop. co_thread won't work properly without, | ||
// since it's got to (essentially) block while the thread runs. | ||
|
||
std::mutex lock_me; | ||
std::vector<std::coroutine_handle<>> resume_me; | ||
|
||
|
||
namespace sigslot { | ||
void resume(std::coroutine_handle<> coro) { | ||
std::lock_guard l(lock_me); | ||
resume_me.push_back(coro); | ||
} | ||
} | ||
|
||
#include <sigslot/sigslot.h> | ||
#include <sigslot/tasklet.h> | ||
#include <sigslot/cothread.h> | ||
|
||
template<typename R> | ||
void run_until_complete_low(sigslot::tasklet<R> & coro) { | ||
if (!coro.started()) coro.start(); | ||
while (coro.running()) { | ||
std::vector<std::coroutine_handle<>> current; | ||
{ | ||
std::lock_guard l(lock_me); | ||
current.swap(resume_me); | ||
} | ||
std::cout << "Resuming " << current.size() << " coroutines." << std::endl; | ||
for (auto coro : current) { | ||
coro.resume(); | ||
} | ||
current.clear(); | ||
sleep(1); | ||
std::cout << "... tick" << std::endl; | ||
} | ||
} | ||
template<typename R> | ||
R run_until_complete(sigslot::tasklet<R> & coro) { | ||
run_until_complete_low(coro); | ||
return coro.get(); | ||
} | ||
template<> | ||
void run_until_complete<void>(sigslot::tasklet<void> & coro) { | ||
run_until_complete_low(coro); | ||
coro.get(); | ||
} | ||
|
||
|
||
sigslot::tasklet<bool> inner(std::string const & s) { | ||
std::cout << "Here!" << std::endl; | ||
sigslot::co_thread<bool, std::string const &> thread1([](std::string const &s) { | ||
std::cout << "There 1! " << s << std::endl; | ||
return true; | ||
}); | ||
sigslot::co_thread<bool> thread2([]() { | ||
std::cout << "+ Launch" << std::endl; | ||
sleep(1); | ||
std::cout << "+ There 2!" << std::endl; | ||
sleep(1); | ||
std::cout << "+ End" << std::endl; | ||
return true; | ||
}); | ||
std::cout << "Still here!" << std::endl; | ||
thread2(); | ||
auto result1 = co_await thread1(s); | ||
std::cout << "Got result1:" << result1 << std::endl; | ||
auto result2 = co_await thread2; | ||
std::cout << "Got result2:" << result2 << std::endl; | ||
co_return true; | ||
} | ||
|
||
sigslot::tasklet<void> start() { | ||
std::string s = "Hello world!"; | ||
auto result = co_await inner(s); | ||
std::cout << "Completed test with result " << result << std::endl; | ||
} | ||
|
||
namespace { | ||
sigslot::tasklet<int> trivial_task(int i) { | ||
co_return i; | ||
} | ||
|
||
sigslot::tasklet<int> basic_task(sigslot::signal<int> &signal) { | ||
co_return co_await signal; | ||
} | ||
|
||
sigslot::tasklet<int> signal_thread_task() { | ||
sigslot::signal<int> signal; | ||
sigslot::co_thread<int> thread([&signal]() { | ||
sleep(1); | ||
signal(42); | ||
sleep(1); | ||
return 42; | ||
}); | ||
thread(); | ||
auto result = co_await signal; | ||
co_await thread; | ||
co_return result; | ||
} | ||
|
||
sigslot::tasklet<int> nested_task(int i) { | ||
co_return co_await trivial_task(i); | ||
} | ||
|
||
sigslot::tasklet<int> exception_task(int i) { | ||
if (i == 42) { | ||
// Have to do this conditionally with a co_return otherwise it's not a coroutine. | ||
throw std::runtime_error("Help"); | ||
} | ||
co_return i; | ||
} | ||
} | ||
|
||
TEST(CoThreadTest, CheckLoop) { | ||
auto coro = trivial_task(42); | ||
auto result = run_until_complete(coro); | ||
EXPECT_EQ(result, 42); | ||
} | ||
|
||
TEST(CoThreadTest, CheckLoop2) { | ||
sigslot::signal<int> signal; | ||
auto coro = basic_task(signal); | ||
coro.start(); | ||
int i = 0; | ||
while (coro.running()) { | ||
std::vector<std::coroutine_handle<>> current; | ||
{ | ||
std::lock_guard l(lock_me); | ||
current.swap(resume_me); | ||
} | ||
std::cout << "Resuming " << current.size() << " coroutines." << std::endl; | ||
for (auto coro : current) { | ||
coro.resume(); | ||
} | ||
current.clear(); | ||
sleep(1); | ||
if (i == 2) { | ||
std::cout << "Signalling" << std::endl; | ||
signal(42); | ||
} | ||
++i; | ||
std::cout << "... tick" << std::endl; | ||
} | ||
auto result = coro.get(); | ||
std::cout << "Result: " << result << std::endl; | ||
} | ||
|
||
TEST(CoThreadTest, Tests) { | ||
std::cout << "Start" << std::endl; | ||
auto coro = start(); | ||
coro.start(); | ||
while (coro.running()) { | ||
std::vector<std::coroutine_handle<>> current; | ||
{ | ||
std::lock_guard l(lock_me); | ||
current.swap(resume_me); | ||
} | ||
std::cout << "Resuming " << current.size() << " coroutines." << std::endl; | ||
for (auto coro : current) { | ||
coro.resume(); | ||
} | ||
current.clear(); | ||
sleep(1); | ||
std::cout << "... tick" << std::endl; | ||
} | ||
coro.get(); | ||
std::cout << "*** END ***" << std::endl; | ||
} |