-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement full reader writer for socket
- Loading branch information
Showing
4 changed files
with
207 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#pragma once | ||
|
||
#include <coroutine> | ||
#include <optional> | ||
#include <variant> | ||
#include <memory> | ||
|
||
namespace NNet { | ||
|
||
template<typename T> struct TFinalSuspendContinuation; | ||
|
||
template<typename T> struct TValueTask; | ||
|
||
template<typename T> | ||
struct TValuePromiseBase { | ||
std::suspend_never initial_suspend() { return {}; } | ||
TFinalSuspendContinuation<T> final_suspend() noexcept; | ||
std::coroutine_handle<> Caller = std::noop_coroutine(); | ||
}; | ||
|
||
template<typename T> | ||
struct TValuePromise: public TValuePromiseBase<T> { | ||
TValueTask<T> get_return_object(); | ||
|
||
void return_value(const T& t) { | ||
ErrorOr = t; | ||
} | ||
|
||
void unhandled_exception() { | ||
ErrorOr = std::current_exception(); | ||
} | ||
|
||
std::optional<std::variant<T, std::exception_ptr>> ErrorOr; | ||
}; | ||
|
||
template<typename T> | ||
struct TValueTaskBase : std::coroutine_handle<TValuePromise<T>> { | ||
~TValueTaskBase() { this->destroy(); } | ||
|
||
bool await_ready() { | ||
return !!this->promise().ErrorOr; | ||
} | ||
|
||
void await_suspend(std::coroutine_handle<> caller) { | ||
this->promise().Caller = caller; | ||
} | ||
|
||
using promise_type = TValuePromise<T>; | ||
}; | ||
|
||
template<typename T> | ||
struct TValueTask : public TValueTaskBase<T> { | ||
T await_resume() { | ||
auto& errorOr = *this->promise()->ErrorOr; | ||
if (auto* res = std::get_if(&errorOr)) { | ||
return *res; | ||
} else { | ||
std::rethrow_exception(std::get<std::exception_ptr>(errorOr)); | ||
} | ||
} | ||
}; | ||
|
||
template<> struct TValueTask<void>; | ||
|
||
template<> | ||
struct TValuePromise<void>: public TValuePromiseBase<void> { | ||
TValueTask<void> get_return_object(); | ||
|
||
void return_void() { | ||
ErrorOr = nullptr; | ||
} | ||
|
||
void unhandled_exception() { | ||
ErrorOr = std::current_exception(); | ||
} | ||
|
||
std::optional<std::exception_ptr> ErrorOr; | ||
}; | ||
|
||
template<> | ||
struct TValueTask<void> : public TValueTaskBase<void> { | ||
void await_resume() { | ||
auto& errorOr = *this->promise().ErrorOr; | ||
if (errorOr) { | ||
std::rethrow_exception(errorOr); | ||
} | ||
} | ||
}; | ||
|
||
template<typename T> | ||
struct TFinalSuspendContinuation { | ||
bool await_ready() noexcept { return false; } | ||
std::coroutine_handle<> await_suspend(std::coroutine_handle<TValuePromise<T>> h) noexcept { | ||
return h.promise().Caller; | ||
} | ||
void await_resume() noexcept { } | ||
}; | ||
|
||
inline TValueTask<void> TValuePromise<void>::get_return_object() { return { TValueTask<void>::from_promise(*this) }; } | ||
template<typename T> | ||
TValueTask<T> TValuePromise<T>::get_return_object() { return { TValueTask<T>::from_promise(*this) }; } | ||
|
||
|
||
template<typename T> | ||
TFinalSuspendContinuation<T> TValuePromiseBase<T>::final_suspend() noexcept { return {}; } | ||
|
||
} // namespace NNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#pragma once | ||
|
||
#include "corochain.hpp" | ||
|
||
namespace NNet { | ||
|
||
template<typename TSocket> | ||
struct TReader { | ||
TReader(TSocket& socket) | ||
: Socket(socket) | ||
{ } | ||
|
||
TValueTask<void> Read(void* data, size_t size) { | ||
char* p = static_cast<char*>(data); | ||
while (size != 0) { | ||
auto read_size = co_await Socket.ReadSome(p, size); | ||
if (read_size == 0) { | ||
throw std::runtime_error("Connection closed"); | ||
} | ||
if (read_size < 0) { | ||
continue; // retry | ||
} | ||
p += read_size; | ||
size -= read_size; | ||
} | ||
co_return; | ||
} | ||
|
||
private: | ||
TSocket& Socket; | ||
}; | ||
|
||
template<typename TSocket> | ||
struct TWriter { | ||
TWriter(TSocket& socket) | ||
: Socket(socket) | ||
{ } | ||
|
||
TValueTask<void> Write(const void* data, size_t size) { | ||
const char* p = static_cast<const char*>(data); | ||
while (size != 0) { | ||
auto read_size = co_await Socket.WriteSome(const_cast<char*>(p) /* TODO: cast */, size); | ||
if (read_size == 0) { | ||
throw std::runtime_error("Connection closed"); | ||
} | ||
if (read_size < 0) { | ||
continue; // retry | ||
} | ||
p += read_size; | ||
size -= read_size; | ||
} | ||
co_return; | ||
} | ||
|
||
private: | ||
TSocket& Socket; | ||
}; | ||
|
||
} // namespace NNet { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters