diff --git a/coroio/all.hpp b/coroio/all.hpp index 732892d..e205f2c 100644 --- a/coroio/all.hpp +++ b/coroio/all.hpp @@ -31,6 +31,7 @@ #include "corochain.hpp" #include "sockutils.hpp" #include "ssl.hpp" +#include "resolver.hpp" namespace NNet { #if defined(__APPLE__) || defined(__FreeBSD__) diff --git a/coroio/promises.hpp b/coroio/promises.hpp index b5ba680..1f7e0c3 100644 --- a/coroio/promises.hpp +++ b/coroio/promises.hpp @@ -47,7 +47,7 @@ inline auto SelfId() { } auto await_resume() noexcept { - return H.address(); + return H; } std::coroutine_handle<> H; diff --git a/coroio/resolver.cpp b/coroio/resolver.cpp index 42c7303..1a0b2fb 100644 --- a/coroio/resolver.cpp +++ b/coroio/resolver.cpp @@ -1 +1,189 @@ -#include "resolver.hpp" \ No newline at end of file +#include "resolver.hpp" +#include "socket.hpp" +#include "promises.hpp" +#include "uring.hpp" + +#include +#include + +namespace NNet { + +namespace { + +// Based on https://w3.cs.jmu.edu/kirkpams/OpenCSF/Books/csf/html/UDPSockets.html + +struct TDnsHeader { + uint16_t xid = 0; /* Randomly chosen identifier */ + uint16_t flags = 0; /* Bit-mask to indicate request/response */ + uint16_t qdcount = 0; /* Number of questions */ + uint16_t ancount = 0; /* Number of answers */ + uint16_t nscount = 0; /* Number of authority records */ + uint16_t arcount = 0; /* Number of additional records */ +} __attribute__((__packed__)); + +struct TDnsQuestion { + char* name; /* Pointer to the domain name in memory */ + uint16_t dnstype; /* The QTYPE (1 = A) */ + uint16_t dnsclass; /* The QCLASS (1 = IN) */ +} __attribute__((__packed__)); + +struct TDnsRecordA { + uint16_t compression; + uint16_t type; + uint16_t clazz; + uint32_t ttl; + uint16_t length; + in_addr addr; +} __attribute__((packed)); + +} // namespace + +template +TResolver::TResolver(TAddress dnsAddr, TPoller& poller) + : Socket(std::move(dnsAddr), poller, SOCK_DGRAM) + , Poller(poller) +{ + // Start tasks after fields initialization + Sender = SenderTask(); + Receiver = ReceiverTask(); +} + +template +TResolver::~TResolver() +{ + Sender.destroy(); + Receiver.destroy(); +} + +template +void TResolver::CreatePacket(const std::string& name, char* packet, int* size) +{ + TDnsHeader header = { + .xid = htons(Xid ++), + .flags = htons(0x0100), /* Q=0, RD=1 */ + .qdcount = htons(1) /* Sending 1 question */ + }; + + std::string query; query.resize(name.size() + 2); + TDnsQuestion question = { + .name = &query[0], + .dnstype = htons (1), /* QTYPE 1=A */ + .dnsclass = htons (1), /* QCLASS 1=IN */ + }; + + memcpy (question.name + 1, &name[0], name.size()); + uint8_t* prev = (uint8_t*) question.name; + uint8_t count = 0; /* Used to count the bytes in a field */ + + /* Traverse through the name, looking for the . locations */ + for (size_t i = 0; i < name.size(); i++) + { + /* A . indicates the end of a field */ + if (name[i] == '.') { + /* Copy the length to the byte before this field, then + update prev to the location of the . */ + *prev = count; + prev = (uint8_t*)question.name + i + 1; + count = 0; + } + else { + count++; + } + } + *prev = count; + + size_t packetlen = sizeof (header) + name.size() + 2 + + sizeof (question.dnstype) + sizeof (question.dnsclass); + assert(packetlen <= 4096); + *size = packetlen; + + uint8_t *p = (uint8_t *)packet; + + /* Copy the header first */ + memcpy (p, &header, sizeof (header)); + p += sizeof (header); + + /* Copy the question name, QTYPE, and QCLASS fields */ + memcpy(p, question.name, name.size() + 1); + p += name.size() + 2; /* includes 0 octet for end */ + memcpy(p, &question.dnstype, sizeof (question.dnstype)); + p += sizeof (question.dnstype); + memcpy(p, &question.dnsclass, sizeof (question.dnsclass)); +} + +template +TVoidSuspendedTask TResolver::SenderTask() { + co_await Socket.Connect(); + char buf[4096]; + while (true) { + while (AddResolveQueue.empty()) { + SenderSuspended = co_await SelfId(); + co_await std::suspend_always{}; + } + SenderSuspended = {}; + auto hostname = AddResolveQueue.front(); AddResolveQueue.pop(); + int len; + CreatePacket(hostname, buf, &len); + auto size = co_await Socket.WriteSome(buf, len); + assert(size == len); + } + co_return; +} + +template +TVoidSuspendedTask TResolver::ReceiverTask() { + char buf[4096]; + while (true) { + auto size = co_await Socket.ReadSome(buf, sizeof(buf)); + assert(size > sizeof(TDnsHeader)); + + TDnsHeader* header = (TDnsHeader*)(&buf[0]); + assert ((ntohs (header->flags) & 0xf) == 0); + uint8_t* startOfName = (uint8_t*)(&buf[0] + sizeof (TDnsHeader)); + uint8_t total = 0; + uint8_t* fieldLength = startOfName; + + // TODO: Check size + while (*fieldLength != 0) + { + /* Restore the dot in the name and advance to next length */ + total += *fieldLength + 1; + *fieldLength = '.'; + fieldLength = startOfName + total; + } + + TDnsRecordA* records = (TDnsRecordA*) (fieldLength + 5); + std::vector addresses; + for (int i = 0; i < ntohs (header->ancount); i++) + { + addresses.emplace_back(TAddress{inet_ntoa (records[i].addr), 0}); + } + + std::string name((char*)startOfName+1); + Addresses[name] = std::move(addresses); + auto maybeWaiting = WaitingAddrs.find(name); + if (maybeWaiting != WaitingAddrs.end()) { + for (auto h : maybeWaiting->second) { + h.resume(); + } + } + } + co_return; +} + +template +TValueTask> TResolver::Resolve(const std::string& hostname) { + auto handle = co_await SelfId(); + WaitingAddrs[hostname].emplace_back(handle); + AddResolveQueue.emplace(hostname); + if (SenderSuspended) { + SenderSuspended.resume(); + } + co_await std::suspend_always{}; + co_return Addresses[hostname]; +} + +template class TResolver; +template class TResolver; + +} // namespace NNet diff --git a/coroio/resolver.hpp b/coroio/resolver.hpp index 0557922..16e677a 100644 --- a/coroio/resolver.hpp +++ b/coroio/resolver.hpp @@ -1,14 +1,17 @@ #pragma once +#include + #include "promises.hpp" #include "socket.hpp" #include "corochain.hpp" namespace NNet { +template class TResolver { public: - TResolver(TAddress dnsAddr); + TResolver(TAddress dnsAddr, TPoller& poller); ~TResolver(); TValueTask> Resolve(const std::string& hostname); @@ -16,12 +19,20 @@ class TResolver { private: TVoidSuspendedTask SenderTask(); TVoidSuspendedTask ReceiverTask(); + void CreatePacket(const std::string& name, char* buf, int* size); TSocket Socket; + TPoller& Poller; std::coroutine_handle<> Sender; + std::coroutine_handle<> SenderSuspended; std::coroutine_handle<> Receiver; std::queue AddResolveQueue; + + std::unordered_map> Addresses; + std::unordered_map>> WaitingAddrs; + + uint16_t Xid = 0; }; } // namespace NNet diff --git a/coroio/socket.cpp b/coroio/socket.cpp index 9b27299..c77196b 100644 --- a/coroio/socket.cpp +++ b/coroio/socket.cpp @@ -1,4 +1,5 @@ #include "socket.hpp" +#include namespace NNet { @@ -64,6 +65,23 @@ bool TAddress::operator == (const TAddress& other) const { return memcmp(&Addr_, &other.Addr_, sizeof(Addr_)) == 0; } +std::string TAddress::ToString() const { + char buf[1024]; + if (const auto* val = std::get_if(&Addr_)) { + auto* r = inet_ntop(AF_INET, &val->sin_addr, buf, sizeof(buf)); + if (r) { + return std::string(r) + ":" + std::to_string(val->sin_port); + } + } else if (const auto* val = std::get_if(&Addr_)) { + auto* r = inet_ntop(AF_INET6, &val->sin6_addr, buf, sizeof(buf)); + if (r) { + return std::string(r) + ":" + std::to_string(val->sin6_port); + } + } + + return ""; +} + TSocketOps::TSocketOps(TPollerBase& poller, int domain, int type) : Poller_(&poller) , Fd_(Create(domain, type)) diff --git a/coroio/socket.hpp b/coroio/socket.hpp index 515647c..4e0fe9f 100644 --- a/coroio/socket.hpp +++ b/coroio/socket.hpp @@ -27,6 +27,8 @@ class TAddress { bool operator == (const TAddress& other) const; int Domain() const; + std::string ToString() const; + private: std::variant Addr_ = {}; }; diff --git a/tests/tests.cpp b/tests/tests.cpp index 942bb3d..a8a3ea7 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -1,3 +1,4 @@ +#include "coroio/socket.hpp" #include #include #include @@ -627,7 +628,7 @@ void test_zero_copy_line_splitter(void**) { void test_self_id(void**) { void* id; NNet::TVoidSuspendedTask h = [](void** id) -> NNet::TVoidSuspendedTask { - *id = co_await SelfId(); + *id = (co_await SelfId()).address(); co_return; }(&id); @@ -635,6 +636,32 @@ void test_self_id(void**) { h.destroy(); } +template +void test_resolver(void**) { + using TLoop = TLoop; + using TSocket = typename TPoller::TSocket; + + TLoop loop; + TResolver resolver({"8.8.8.8", 53}, loop.Poller()); + + std::vector addresses; + NNet::TVoidSuspendedTask h1 = [](auto& resolver, std::vector& addresses) -> NNet::TVoidSuspendedTask { + addresses = co_await resolver.Resolve("www.google.com"); + //for (auto& addr : addresses) { + // std::cout << addr.ToString() << "\n"; + //} + co_return; + }(resolver, addresses); + + while (!(h1.done())) { + loop.Step(); + } + + assert_int_equal(addresses.size(), 1); + + h1.destroy(); +} + template void test_read_write_full_ssl(void**) { using TLoop = TLoop; @@ -871,6 +898,7 @@ int main() { my_unit_poller(test_read_write_struct), my_unit_poller(test_read_write_lines), my_unit_test2(test_read_write_full_ssl, TSelect, TPoll), + my_unit_test2(test_resolver, TSelect, TPoll), #ifdef __linux__ cmocka_unit_test(test_uring_create), cmocka_unit_test(test_uring_write),