Skip to content

Commit

Permalink
Add struct reader, enable tests for readers
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Nov 30, 2023
1 parent 5146f9d commit fe842ee
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 11 deletions.
36 changes: 32 additions & 4 deletions coroio/sockutils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
namespace NNet {

template<typename TSocket>
struct TReader {
TReader(TSocket& socket)
struct TByteReader {
TByteReader(TSocket& socket)
: Socket(socket)
{ }

Expand All @@ -31,8 +31,8 @@ struct TReader {
};

template<typename TSocket>
struct TWriter {
TWriter(TSocket& socket)
struct TByteWriter {
TByteWriter(TSocket& socket)
: Socket(socket)
{ }

Expand All @@ -56,4 +56,32 @@ struct TWriter {
TSocket& Socket;
};

template<typename T, typename TSocket>
struct TStructReader {
TStructReader(TSocket& socket)
: Socket(socket)
{ }

TValueTask<T> Read() {
T res;
size_t size = sizeof(T);
char* p = reinterpret_cast<char*>(&res);
while (size != 0) {
auto read_size = co_await Socket.ReadSome(p, size);
if (read_size == 0) {
throw std::runtime_error("Connection closed");
}
if (read_size < 0) {
continue; // retry
}
p += read_size;
size -= read_size;
}
co_return res;
}

private:
TSocket& Socket;
};

} // namespace NNet {
68 changes: 61 additions & 7 deletions tests/tests.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <chrono>
#include <array>
#include <exception>
#include <stdarg.h>
#include <stddef.h>
Expand Down Expand Up @@ -406,33 +407,37 @@ void test_timeout(void**) {
assert_true(next >= now + timeout);
}

template<typename TPoller>
void test_read_write_full(void**) {
using TLoop = TLoop<TPoller>;
using TSocket = typename TPoller::TSocket;

std::vector<char> data(1024*1024);
int cur = 0;
for (auto& ch : data) {
ch = cur + 'a';
cur = (cur + 1) % ('z' - 'a' + 1);
}

NNet::TLoop<NNet::TPoll> loop;
NNet::TSocket socket(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());
TLoop loop;
TSocket socket(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());
socket.Bind();
socket.Listen();

NNet::TSocket client(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());
TSocket client(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());

NNet::TTestTask h1 = [](NNet::TSocket& client, const std::vector<char>& data) -> NNet::TTestTask
NNet::TTestTask h1 = [](TSocket& client, const std::vector<char>& data) -> NNet::TTestTask
{
co_await client.Connect();
co_await TWriter(client).Write(data.data(), data.size());
co_await TByteWriter(client).Write(data.data(), data.size());
co_return;
}(client, data);

std::vector<char> received(1024*1024);
NNet::TTestTask h2 = [](NNet::TSocket& server, std::vector<char>& received) -> NNet::TTestTask
NNet::TTestTask h2 = [](TSocket& server, std::vector<char>& received) -> NNet::TTestTask
{
auto client = std::move(co_await server.Accept());
co_await TReader(client).Read(received.data(), received.size());
co_await TByteReader(client).Read(received.data(), received.size());
co_return;
}(socket, received);

Expand All @@ -445,6 +450,53 @@ void test_read_write_full(void**) {
h1.destroy(); h2.destroy();
}

template<typename TPoller>
void test_read_write_struct(void**) {
using TLoop = TLoop<TPoller>;
using TSocket = typename TPoller::TSocket;

struct Test {
std::array<char, 1024> data;
};
Test data;

int cur = 0;
for (auto& ch : data.data) {
ch = cur + 'a';
cur = (cur + 1) % ('z' - 'a' + 1);
}

TLoop loop;
TSocket socket(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());
socket.Bind();
socket.Listen();

TSocket client(NNet::TAddress{"127.0.0.1", 8988}, loop.Poller());

NNet::TTestTask h1 = [](TSocket& client, auto& data) -> NNet::TTestTask
{
co_await client.Connect();
co_await TByteWriter(client).Write(&data, data.data.size());
co_return;
}(client, data);

Test received;
NNet::TTestTask h2 = [](TSocket& server, auto& received) -> NNet::TTestTask
{
auto client = std::move(co_await server.Accept());
received = co_await TStructReader<Test, TSocket>(client).Read();
co_return;
}(socket, received);

while (!(h1.done() && h2.done())) {
loop.Step();
}

assert_memory_equal(data.data.data(), received.data.data(), data.data.size());

h1.destroy(); h2.destroy();
}

#ifdef __linux__
void test_uring_create(void**) {
TUring uring(256);
Expand Down Expand Up @@ -605,6 +657,8 @@ int main() {
my_unit_poller(test_connection_refused_on_write),
my_unit_poller(test_connection_refused_on_read),
my_unit_poller(test_read_write_same_socket),
my_unit_poller(test_read_write_full),
my_unit_poller(test_read_write_struct),
#ifdef __linux__
cmocka_unit_test(test_uring_create),
cmocka_unit_test(test_uring_write),
Expand Down

0 comments on commit fe842ee

Please sign in to comment.