Skip to content

Commit

Permalink
Throw exception on bad resolve
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Dec 22, 2023
1 parent d16986f commit c684127
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
23 changes: 16 additions & 7 deletions coroio/resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ struct TDnsRecordA {
in_addr addr;
} __attribute__((packed));

void CreatePacket(const std::string& name, char* packet, int* size, uint16_t* xid)
void CreatePacket(const std::string& name, char* packet, int* size, uint16_t xid)
{
TDnsHeader header = {
.xid = htons((*xid) ++),
.xid = htons(xid),
.flags = htons(0x0100), /* Q=0, RD=1 */
.qdcount = htons(1) /* Sending 1 question */
};
Expand Down Expand Up @@ -92,10 +92,12 @@ void CreatePacket(const std::string& name, char* packet, int* size, uint16_t* xi
memcpy(p, &question.dnsclass, sizeof (question.dnsclass));
}

void ParsePacket(std::vector<TAddress>& addresses, std::string& name, char* buf, ssize_t size) {
if (size < sizeof(TDnsHeader)) { throw std::runtime_error("Not enough data"); }
void ParsePacket(uint16_t* xid, std::vector<TAddress>& addresses, std::string& name, char* buf, ssize_t size) {
TDnsHeader* header = (TDnsHeader*)(&buf[0]);
assert ((ntohs (header->flags) & 0xf) == 0);
*xid = ntohs(header->xid);
if ((ntohs (header->flags) & 0xf) != 0) {
throw std::runtime_error("Resolver Error");
}
uint8_t* startOfName = (uint8_t*)(&buf[0] + sizeof (TDnsHeader));
uint8_t fragmentSize = 0;
uint8_t* p = startOfName; size -= p - (uint8_t*)buf; if (size <= 0) { throw std::runtime_error("Not enough data"); }
Expand Down Expand Up @@ -200,7 +202,9 @@ TVoidSuspendedTask TResolver<TPoller>::SenderTask() {
auto hostname = AddResolveQueue.front(); AddResolveQueue.pop();
int len;
memset(buf, 0, sizeof(buf));
CreatePacket(hostname, buf, &len, &Xid);
Inflight[Xid] = hostname;
CreatePacket(hostname, buf, &len, Xid);
Xid = 1 + (Xid + 1) % 65535;
auto size = co_await Socket.WriteSome(buf, len);
assert(size == len);
}
Expand All @@ -215,16 +219,21 @@ TVoidSuspendedTask TResolver<TPoller>::ReceiverTask() {
if (size < 0) {
continue;
}
if (size < sizeof(TDnsHeader)) {
continue;
}

std::vector<TAddress> addresses;
std::string name;
std::exception_ptr exception;
uint16_t xid;
try {
ParsePacket(addresses, name, buf, size);
ParsePacket(&xid, addresses, name, buf, size);
} catch (const std::exception& ex) {
exception = std::current_exception();
}

name = Inflight[xid];
Results[name] = TResolveResult {
.Addresses = std::move(addresses),
.Exception = exception
Expand Down
3 changes: 2 additions & 1 deletion coroio/resolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ class TResolver {

std::unordered_map<std::string, TResolveResult> Results;
std::unordered_map<std::string, std::vector<std::coroutine_handle<>>> WaitingAddrs;
std::unordered_map<uint64_t, std::string> Inflight;

uint16_t Xid = 0;
uint16_t Xid = 1;
};

} // namespace NNet
27 changes: 27 additions & 0 deletions tests/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,32 @@ void test_resolver(void**) {
h1.destroy();
}

template<typename TPoller>
void test_resolve_bad_name(void**) {
using TLoop = TLoop<TPoller>;
using TSocket = typename TPoller::TSocket;

TLoop loop;
TResolver<TPollerBase> resolver(loop.Poller());

std::exception_ptr ex;
NNet::TVoidSuspendedTask h1 = [](auto& resolver, auto& ex) -> NNet::TVoidSuspendedTask {
try {
co_await resolver.Resolve("bad.host.name.wtf123");
} catch (const std::exception& ) {
ex = std::current_exception();
}
}(resolver, ex);

while (!(h1.done())) {
loop.Step();
}

h1.destroy();

assert_true(!!ex);
}

template<typename TPoller>
void test_read_write_full_ssl(void**) {
using TLoop = TLoop<TPoller>;
Expand Down Expand Up @@ -916,6 +942,7 @@ int main() {
my_unit_poller(test_read_write_lines),
my_unit_test2(test_read_write_full_ssl, TSelect, TPoll),
my_unit_test2(test_resolver, TSelect, TPoll),
my_unit_test2(test_resolve_bad_name, TSelect, TPoll),
#ifdef __linux__
cmocka_unit_test(test_uring_create),
cmocka_unit_test(test_uring_write),
Expand Down

0 comments on commit c684127

Please sign in to comment.