From 1ae9e76375cdffa3592d65af48f775b091d8a7e5 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 30 Nov 2023 20:36:21 +0100 Subject: [PATCH] Implement full reader writer for socket --- coroio/all.hpp | 2 + coroio/corochain.hpp | 107 +++++++++++++++++++++++++++++++++++++++++++ coroio/sockutils.hpp | 59 ++++++++++++++++++++++++ tests/tests.cpp | 39 ++++++++++++++++ 4 files changed, 207 insertions(+) create mode 100644 coroio/corochain.hpp create mode 100644 coroio/sockutils.hpp diff --git a/coroio/all.hpp b/coroio/all.hpp index 16731f7..a4de377 100644 --- a/coroio/all.hpp +++ b/coroio/all.hpp @@ -28,3 +28,5 @@ #include "loop.hpp" #include "promises.hpp" #include "socket.hpp" +#include "corochain.hpp" +#include "sockutils.hpp" \ No newline at end of file diff --git a/coroio/corochain.hpp b/coroio/corochain.hpp new file mode 100644 index 0000000..a877f75 --- /dev/null +++ b/coroio/corochain.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include + +namespace NNet { + +template struct TFinalSuspendContinuation; + +template struct TValueTask; + +template +struct TValuePromiseBase { + std::suspend_never initial_suspend() { return {}; } + TFinalSuspendContinuation final_suspend() noexcept; + std::coroutine_handle<> Caller = std::noop_coroutine(); +}; + +template +struct TValuePromise: public TValuePromiseBase { + TValueTask get_return_object(); + + void return_value(const T& t) { + ErrorOr = t; + } + + void unhandled_exception() { + ErrorOr = std::current_exception(); + } + + std::optional> ErrorOr; +}; + +template +struct TValueTaskBase : std::coroutine_handle> { + ~TValueTaskBase() { this->destroy(); } + + bool await_ready() { + return !!this->promise().ErrorOr; + } + + void await_suspend(std::coroutine_handle<> caller) { + this->promise().Caller = caller; + } + + using promise_type = TValuePromise; +}; + +template +struct TValueTask : public TValueTaskBase { + T await_resume() { + auto& errorOr = *this->promise()->ErrorOr; + if (auto* res = std::get_if(&errorOr)) { + return *res; + } else { + std::rethrow_exception(std::get(errorOr)); + } + } +}; + +template<> struct TValueTask; + +template<> +struct TValuePromise: public TValuePromiseBase { + TValueTask get_return_object(); + + void return_void() { + ErrorOr = nullptr; + } + + void unhandled_exception() { + ErrorOr = std::current_exception(); + } + + std::optional ErrorOr; +}; + +template<> +struct TValueTask : public TValueTaskBase { + void await_resume() { + auto& errorOr = *this->promise().ErrorOr; + if (errorOr) { + std::rethrow_exception(errorOr); + } + } +}; + +template +struct TFinalSuspendContinuation { + bool await_ready() noexcept { return false; } + std::coroutine_handle<> await_suspend(std::coroutine_handle> h) noexcept { + return h.promise().Caller; + } + void await_resume() noexcept { } +}; + +inline TValueTask TValuePromise::get_return_object() { return { TValueTask::from_promise(*this) }; } +template +TValueTask TValuePromise::get_return_object() { return { TValueTask::from_promise(*this) }; } + + +template +TFinalSuspendContinuation TValuePromiseBase::final_suspend() noexcept { return {}; } + +} // namespace NNet diff --git a/coroio/sockutils.hpp b/coroio/sockutils.hpp new file mode 100644 index 0000000..9fb7771 --- /dev/null +++ b/coroio/sockutils.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include "corochain.hpp" + +namespace NNet { + +template +struct TReader { + TReader(TSocket& socket) + : Socket(socket) + { } + + TValueTask Read(void* data, size_t size) { + char* p = static_cast(data); + while (size != 0) { + auto read_size = co_await Socket.ReadSome(p, size); + if (read_size == 0) { + throw std::runtime_error("Connection closed"); + } + if (read_size < 0) { + continue; // retry + } + p += read_size; + size -= read_size; + } + co_return; + } + +private: + TSocket& Socket; +}; + +template +struct TWriter { + TWriter(TSocket& socket) + : Socket(socket) + { } + + TValueTask Write(const void* data, size_t size) { + const char* p = static_cast(data); + while (size != 0) { + auto read_size = co_await Socket.WriteSome(const_cast(p) /* TODO: cast */, size); + if (read_size == 0) { + throw std::runtime_error("Connection closed"); + } + if (read_size < 0) { + continue; // retry + } + p += read_size; + size -= read_size; + } + co_return; + } + +private: + TSocket& Socket; +}; + +} // namespace NNet { diff --git a/tests/tests.cpp b/tests/tests.cpp index 54fa147..73cc092 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -406,6 +406,45 @@ void test_timeout(void**) { assert_true(next >= now + timeout); } +void test_read_write_full(void**) { + std::vector data(1024*1024); + int cur = 0; + for (auto& ch : data) { + ch = cur + 'a'; + cur = (cur + 1) % ('z' - 'a' + 1); + } + + NNet::TLoop loop; + NNet::TSocket socket(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller()); + socket.Bind(); + socket.Listen(); + + NNet::TSocket client(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller()); + + NNet::TTestTask h1 = [](NNet::TSocket& client, const std::vector& data) -> NNet::TTestTask + { + co_await client.Connect(); + co_await TWriter(client).Write(data.data(), data.size()); + co_return; + }(client, data); + + std::vector received(1024*1024); + NNet::TTestTask h2 = [](NNet::TSocket& server, std::vector& received) -> NNet::TTestTask + { + auto client = std::move(co_await server.Accept()); + co_await TReader(client).Read(received.data(), received.size()); + co_return; + }(socket, received); + + while (!(h1.done() && h2.done())) { + loop.Step(); + } + + assert_memory_equal(data.data(), received.data(), data.size()); + + h1.destroy(); h2.destroy(); +} + #ifdef __linux__ void test_uring_create(void**) { TUring uring(256);