Skip to content

Commit

Permalink
Implement full reader writer for socket
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 30, 2023
1 parent 6c5a53a commit 1ae9e76
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 0 deletions.
2 changes: 2 additions & 0 deletions coroio/all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@
#include "loop.hpp"
#include "promises.hpp"
#include "socket.hpp"
#include "corochain.hpp"
#include "sockutils.hpp"
107 changes: 107 additions & 0 deletions coroio/corochain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#pragma once

#include <coroutine>
#include <optional>
#include <variant>
#include <memory>

namespace NNet {

template<typename T> struct TFinalSuspendContinuation;

template<typename T> struct TValueTask;

template<typename T>
struct TValuePromiseBase {
std::suspend_never initial_suspend() { return {}; }
TFinalSuspendContinuation<T> final_suspend() noexcept;
std::coroutine_handle<> Caller = std::noop_coroutine();
};

template<typename T>
struct TValuePromise: public TValuePromiseBase<T> {
TValueTask<T> get_return_object();

void return_value(const T& t) {
ErrorOr = t;
}

void unhandled_exception() {
ErrorOr = std::current_exception();
}

std::optional<std::variant<T, std::exception_ptr>> ErrorOr;
};

template<typename T>
struct TValueTaskBase : std::coroutine_handle<TValuePromise<T>> {
~TValueTaskBase() { this->destroy(); }

bool await_ready() {
return !!this->promise().ErrorOr;
}

void await_suspend(std::coroutine_handle<> caller) {
this->promise().Caller = caller;
}

using promise_type = TValuePromise<T>;
};

template<typename T>
struct TValueTask : public TValueTaskBase<T> {
T await_resume() {
auto& errorOr = *this->promise()->ErrorOr;
if (auto* res = std::get_if(&errorOr)) {
return *res;
} else {
std::rethrow_exception(std::get<std::exception_ptr>(errorOr));
}
}
};

template<> struct TValueTask<void>;

template<>
struct TValuePromise<void>: public TValuePromiseBase<void> {
TValueTask<void> get_return_object();

void return_void() {
ErrorOr = nullptr;
}

void unhandled_exception() {
ErrorOr = std::current_exception();
}

std::optional<std::exception_ptr> ErrorOr;
};

template<>
struct TValueTask<void> : public TValueTaskBase<void> {
void await_resume() {
auto& errorOr = *this->promise().ErrorOr;
if (errorOr) {
std::rethrow_exception(errorOr);
}
}
};

template<typename T>
struct TFinalSuspendContinuation {
bool await_ready() noexcept { return false; }
std::coroutine_handle<> await_suspend(std::coroutine_handle<TValuePromise<T>> h) noexcept {
return h.promise().Caller;
}
void await_resume() noexcept { }
};

inline TValueTask<void> TValuePromise<void>::get_return_object() { return { TValueTask<void>::from_promise(*this) }; }
template<typename T>
TValueTask<T> TValuePromise<T>::get_return_object() { return { TValueTask<T>::from_promise(*this) }; }


template<typename T>
TFinalSuspendContinuation<T> TValuePromiseBase<T>::final_suspend() noexcept { return {}; }

} // namespace NNet
59 changes: 59 additions & 0 deletions coroio/sockutils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#pragma once

#include "corochain.hpp"

namespace NNet {

template<typename TSocket>
struct TReader {
TReader(TSocket& socket)
: Socket(socket)
{ }

TValueTask<void> Read(void* data, size_t size) {
char* p = static_cast<char*>(data);
while (size != 0) {
auto read_size = co_await Socket.ReadSome(p, size);
if (read_size == 0) {
throw std::runtime_error("Connection closed");
}
if (read_size < 0) {
continue; // retry
}
p += read_size;
size -= read_size;
}
co_return;
}

private:
TSocket& Socket;
};

template<typename TSocket>
struct TWriter {
TWriter(TSocket& socket)
: Socket(socket)
{ }

TValueTask<void> Write(const void* data, size_t size) {
const char* p = static_cast<const char*>(data);
while (size != 0) {
auto read_size = co_await Socket.WriteSome(const_cast<char*>(p) /* TODO: cast */, size);
if (read_size == 0) {
throw std::runtime_error("Connection closed");
}
if (read_size < 0) {
continue; // retry
}
p += read_size;
size -= read_size;
}
co_return;
}

private:
TSocket& Socket;
};

} // namespace NNet {
39 changes: 39 additions & 0 deletions tests/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,45 @@ void test_timeout(void**) {
assert_true(next >= now + timeout);
}

void test_read_write_full(void**) {
std::vector<char> data(1024*1024);
int cur = 0;
for (auto& ch : data) {
ch = cur + 'a';
cur = (cur + 1) % ('z' - 'a' + 1);
}

NNet::TLoop<NNet::TPoll> loop;
NNet::TSocket socket(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());
socket.Bind();
socket.Listen();

NNet::TSocket client(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());

NNet::TTestTask h1 = [](NNet::TSocket& client, const std::vector<char>& data) -> NNet::TTestTask
{
co_await client.Connect();
co_await TWriter(client).Write(data.data(), data.size());
co_return;
}(client, data);

std::vector<char> received(1024*1024);
NNet::TTestTask h2 = [](NNet::TSocket& server, std::vector<char>& received) -> NNet::TTestTask
{
auto client = std::move(co_await server.Accept());
co_await TReader(client).Read(received.data(), received.size());
co_return;
}(socket, received);

while (!(h1.done() && h2.done())) {
loop.Step();
}

assert_memory_equal(data.data(), received.data(), data.size());

h1.destroy(); h2.destroy();
}

#ifdef __linux__
void test_uring_create(void**) {
TUring uring(256);
Expand Down

0 comments on commit 1ae9e76

Please sign in to comment.