Skip to content

Commit

Permalink
Add base dns resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Dec 17, 2023
1 parent 1f25ad1 commit 810f865
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 4 deletions.
1 change: 1 addition & 0 deletions coroio/all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "corochain.hpp"
#include "sockutils.hpp"
#include "ssl.hpp"
#include "resolver.hpp"

namespace NNet {
#if defined(__APPLE__) || defined(__FreeBSD__)
Expand Down
2 changes: 1 addition & 1 deletion coroio/promises.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ inline auto SelfId() {
}

auto await_resume() noexcept {
return H.address();
return H;
}

std::coroutine_handle<> H;
Expand Down
190 changes: 189 additions & 1 deletion coroio/resolver.cpp
Original file line number Diff line number Diff line change
@@ -1 +1,189 @@
#include "resolver.hpp"
#include "resolver.hpp"
#include "socket.hpp"
#include "promises.hpp"
#include "uring.hpp"

#include <string_view>
#include <utility>

namespace NNet {

namespace {

// Based on https://w3.cs.jmu.edu/kirkpams/OpenCSF/Books/csf/html/UDPSockets.html

struct TDnsHeader {
uint16_t xid = 0; /* Randomly chosen identifier */
uint16_t flags = 0; /* Bit-mask to indicate request/response */
uint16_t qdcount = 0; /* Number of questions */
uint16_t ancount = 0; /* Number of answers */
uint16_t nscount = 0; /* Number of authority records */
uint16_t arcount = 0; /* Number of additional records */
} __attribute__((__packed__));

struct TDnsQuestion {
char* name; /* Pointer to the domain name in memory */
uint16_t dnstype; /* The QTYPE (1 = A) */
uint16_t dnsclass; /* The QCLASS (1 = IN) */
} __attribute__((__packed__));

struct TDnsRecordA {
uint16_t compression;
uint16_t type;
uint16_t clazz;
uint32_t ttl;
uint16_t length;
in_addr addr;
} __attribute__((packed));

} // namespace

template<typename TPoller>
TResolver<TPoller>::TResolver(TAddress dnsAddr, TPoller& poller)
: Socket(std::move(dnsAddr), poller, SOCK_DGRAM)
, Poller(poller)
{
// Start tasks after fields initialization
Sender = SenderTask();
Receiver = ReceiverTask();
}

template<typename TPoller>
TResolver<TPoller>::~TResolver()
{
Sender.destroy();
Receiver.destroy();
}

template<typename TPoller>
void TResolver<TPoller>::CreatePacket(const std::string& name, char* packet, int* size)
{
TDnsHeader header = {
.xid = htons(Xid ++),
.flags = htons(0x0100), /* Q=0, RD=1 */
.qdcount = htons(1) /* Sending 1 question */
};

std::string query; query.resize(name.size() + 2);
TDnsQuestion question = {
.name = &query[0],
.dnstype = htons (1), /* QTYPE 1=A */
.dnsclass = htons (1), /* QCLASS 1=IN */
};

memcpy (question.name + 1, &name[0], name.size());
uint8_t* prev = (uint8_t*) question.name;
uint8_t count = 0; /* Used to count the bytes in a field */

/* Traverse through the name, looking for the . locations */
for (size_t i = 0; i < name.size(); i++)
{
/* A . indicates the end of a field */
if (name[i] == '.') {
/* Copy the length to the byte before this field, then
update prev to the location of the . */
*prev = count;
prev = (uint8_t*)question.name + i + 1;
count = 0;
}
else {
count++;
}
}
*prev = count;

size_t packetlen = sizeof (header) + name.size() + 2 +
sizeof (question.dnstype) + sizeof (question.dnsclass);
assert(packetlen <= 4096);
*size = packetlen;

uint8_t *p = (uint8_t *)packet;

/* Copy the header first */
memcpy (p, &header, sizeof (header));
p += sizeof (header);

/* Copy the question name, QTYPE, and QCLASS fields */
memcpy(p, question.name, name.size() + 1);
p += name.size() + 2; /* includes 0 octet for end */
memcpy(p, &question.dnstype, sizeof (question.dnstype));
p += sizeof (question.dnstype);
memcpy(p, &question.dnsclass, sizeof (question.dnsclass));
}

template<typename TPoller>
TVoidSuspendedTask TResolver<TPoller>::SenderTask() {
co_await Socket.Connect();
char buf[4096];
while (true) {
while (AddResolveQueue.empty()) {
SenderSuspended = co_await SelfId();
co_await std::suspend_always{};
}
SenderSuspended = {};
auto hostname = AddResolveQueue.front(); AddResolveQueue.pop();
int len;
CreatePacket(hostname, buf, &len);
auto size = co_await Socket.WriteSome(buf, len);
assert(size == len);
}
co_return;
}

template<typename TPoller>
TVoidSuspendedTask TResolver<TPoller>::ReceiverTask() {
char buf[4096];
while (true) {
auto size = co_await Socket.ReadSome(buf, sizeof(buf));
assert(size > sizeof(TDnsHeader));

TDnsHeader* header = (TDnsHeader*)(&buf[0]);
assert ((ntohs (header->flags) & 0xf) == 0);
uint8_t* startOfName = (uint8_t*)(&buf[0] + sizeof (TDnsHeader));
uint8_t total = 0;
uint8_t* fieldLength = startOfName;

// TODO: Check size
while (*fieldLength != 0)
{
/* Restore the dot in the name and advance to next length */
total += *fieldLength + 1;
*fieldLength = '.';
fieldLength = startOfName + total;
}

TDnsRecordA* records = (TDnsRecordA*) (fieldLength + 5);
std::vector<TAddress> addresses;
for (int i = 0; i < ntohs (header->ancount); i++)
{
addresses.emplace_back(TAddress{inet_ntoa (records[i].addr), 0});
}

std::string name((char*)startOfName+1);
Addresses[name] = std::move(addresses);
auto maybeWaiting = WaitingAddrs.find(name);
if (maybeWaiting != WaitingAddrs.end()) {
for (auto h : maybeWaiting->second) {
h.resume();
}
}
}
co_return;
}

template<typename TPoller>
TValueTask<std::vector<TAddress>> TResolver<TPoller>::Resolve(const std::string& hostname) {
auto handle = co_await SelfId();
WaitingAddrs[hostname].emplace_back(handle);
AddResolveQueue.emplace(hostname);
if (SenderSuspended) {
SenderSuspended.resume();
}
co_await std::suspend_always{};
co_return Addresses[hostname];
}

template class TResolver<TPollerBase>;
template class TResolver<TUring>;

} // namespace NNet
13 changes: 12 additions & 1 deletion coroio/resolver.hpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
#pragma once

#include <unordered_map>

#include "promises.hpp"
#include "socket.hpp"
#include "corochain.hpp"

namespace NNet {

template<typename TPoller>
class TResolver {
public:
TResolver(TAddress dnsAddr);
TResolver(TAddress dnsAddr, TPoller& poller);
~TResolver();

TValueTask<std::vector<TAddress>> Resolve(const std::string& hostname);

private:
TVoidSuspendedTask SenderTask();
TVoidSuspendedTask ReceiverTask();
void CreatePacket(const std::string& name, char* buf, int* size);

TSocket Socket;
TPoller& Poller;

std::coroutine_handle<> Sender;
std::coroutine_handle<> SenderSuspended;
std::coroutine_handle<> Receiver;
std::queue<std::string> AddResolveQueue;

std::unordered_map<std::string, std::vector<TAddress>> Addresses;
std::unordered_map<std::string, std::vector<std::coroutine_handle<>>> WaitingAddrs;

uint16_t Xid = 0;
};

} // namespace NNet
18 changes: 18 additions & 0 deletions coroio/socket.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "socket.hpp"
#include <sys/socket.h>

namespace NNet {

Expand Down Expand Up @@ -64,6 +65,23 @@ bool TAddress::operator == (const TAddress& other) const {
return memcmp(&Addr_, &other.Addr_, sizeof(Addr_)) == 0;
}

std::string TAddress::ToString() const {
char buf[1024];
if (const auto* val = std::get_if<sockaddr_in>(&Addr_)) {
auto* r = inet_ntop(AF_INET, &val->sin_addr, buf, sizeof(buf));
if (r) {
return std::string(r) + ":" + std::to_string(val->sin_port);
}
} else if (const auto* val = std::get_if<sockaddr_in6>(&Addr_)) {
auto* r = inet_ntop(AF_INET6, &val->sin6_addr, buf, sizeof(buf));
if (r) {
return std::string(r) + ":" + std::to_string(val->sin6_port);
}
}

return "";
}

TSocketOps::TSocketOps(TPollerBase& poller, int domain, int type)
: Poller_(&poller)
, Fd_(Create(domain, type))
Expand Down
2 changes: 2 additions & 0 deletions coroio/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class TAddress {
bool operator == (const TAddress& other) const;
int Domain() const;

std::string ToString() const;

private:
std::variant<sockaddr_in, sockaddr_in6> Addr_ = {};
};
Expand Down
30 changes: 29 additions & 1 deletion tests/tests.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "coroio/socket.hpp"
#include <chrono>
#include <array>
#include <exception>
Expand Down Expand Up @@ -627,14 +628,40 @@ void test_zero_copy_line_splitter(void**) {
void test_self_id(void**) {
void* id;
NNet::TVoidSuspendedTask h = [](void** id) -> NNet::TVoidSuspendedTask {
*id = co_await SelfId();
*id = (co_await SelfId()).address();
co_return;
}(&id);

assert_ptr_equal(id, h.address());
h.destroy();
}

template<typename TPoller>
void test_resolver(void**) {
using TLoop = TLoop<TPoller>;
using TSocket = typename TPoller::TSocket;

TLoop loop;
TResolver<TPollerBase> resolver({"8.8.8.8", 53}, loop.Poller());

std::vector<TAddress> addresses;
NNet::TVoidSuspendedTask h1 = [](auto& resolver, std::vector<TAddress>& addresses) -> NNet::TVoidSuspendedTask {
addresses = co_await resolver.Resolve("www.google.com");
//for (auto& addr : addresses) {
// std::cout << addr.ToString() << "\n";
//}
co_return;
}(resolver, addresses);

while (!(h1.done())) {
loop.Step();
}

assert_int_equal(addresses.size(), 1);

h1.destroy();
}

template<typename TPoller>
void test_read_write_full_ssl(void**) {
using TLoop = TLoop<TPoller>;
Expand Down Expand Up @@ -871,6 +898,7 @@ int main() {
my_unit_poller(test_read_write_struct),
my_unit_poller(test_read_write_lines),
my_unit_test2(test_read_write_full_ssl, TSelect, TPoll),
my_unit_test2(test_resolver, TSelect, TPoll),
#ifdef __linux__
cmocka_unit_test(test_uring_create),
cmocka_unit_test(test_uring_write),
Expand Down

0 comments on commit 810f865

Please sign in to comment.