diff --git a/coroio/resolver.cpp b/coroio/resolver.cpp index ed29d36..2a8deb0 100644 --- a/coroio/resolver.cpp +++ b/coroio/resolver.cpp @@ -37,10 +37,10 @@ struct TDnsRecordA { in_addr addr; } __attribute__((packed)); -void CreatePacket(const std::string& name, char* packet, int* size, uint16_t* xid) +void CreatePacket(const std::string& name, char* packet, int* size, uint16_t xid) { TDnsHeader header = { - .xid = htons((*xid) ++), + .xid = htons(xid), .flags = htons(0x0100), /* Q=0, RD=1 */ .qdcount = htons(1) /* Sending 1 question */ }; @@ -92,10 +92,12 @@ void CreatePacket(const std::string& name, char* packet, int* size, uint16_t* xi memcpy(p, &question.dnsclass, sizeof (question.dnsclass)); } -void ParsePacket(std::vector& addresses, std::string& name, char* buf, ssize_t size) { - if (size < sizeof(TDnsHeader)) { throw std::runtime_error("Not enough data"); } +void ParsePacket(uint16_t* xid, std::vector& addresses, std::string& name, char* buf, ssize_t size) { TDnsHeader* header = (TDnsHeader*)(&buf[0]); - assert ((ntohs (header->flags) & 0xf) == 0); + *xid = ntohs(header->xid); + if ((ntohs (header->flags) & 0xf) != 0) { + throw std::runtime_error("Resolver Error"); + } uint8_t* startOfName = (uint8_t*)(&buf[0] + sizeof (TDnsHeader)); uint8_t fragmentSize = 0; uint8_t* p = startOfName; size -= p - (uint8_t*)buf; if (size <= 0) { throw std::runtime_error("Not enough data"); } @@ -200,7 +202,9 @@ TVoidSuspendedTask TResolver::SenderTask() { auto hostname = AddResolveQueue.front(); AddResolveQueue.pop(); int len; memset(buf, 0, sizeof(buf)); - CreatePacket(hostname, buf, &len, &Xid); + Inflight[Xid] = hostname; + CreatePacket(hostname, buf, &len, Xid); + Xid = 1 + (Xid + 1) % 65535; auto size = co_await Socket.WriteSome(buf, len); assert(size == len); } @@ -215,16 +219,21 @@ TVoidSuspendedTask TResolver::ReceiverTask() { if (size < 0) { continue; } + if (size < sizeof(TDnsHeader)) { + continue; + } std::vector addresses; std::string name; std::exception_ptr exception; + uint16_t xid; try { - ParsePacket(addresses, name, buf, size); + ParsePacket(&xid, addresses, name, buf, size); } catch (const std::exception& ex) { exception = std::current_exception(); } + name = Inflight[xid]; Results[name] = TResolveResult { .Addresses = std::move(addresses), .Exception = exception diff --git a/coroio/resolver.hpp b/coroio/resolver.hpp index 2b8076d..7171f08 100644 --- a/coroio/resolver.hpp +++ b/coroio/resolver.hpp @@ -52,8 +52,9 @@ class TResolver { std::unordered_map Results; std::unordered_map>> WaitingAddrs; + std::unordered_map Inflight; - uint16_t Xid = 0; + uint16_t Xid = 1; }; } // namespace NNet diff --git a/tests/tests.cpp b/tests/tests.cpp index e13c908..c64c686 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -678,6 +678,32 @@ void test_resolver(void**) { h1.destroy(); } +template +void test_resolve_bad_name(void**) { + using TLoop = TLoop; + using TSocket = typename TPoller::TSocket; + + TLoop loop; + TResolver resolver(loop.Poller()); + + std::exception_ptr ex; + NNet::TVoidSuspendedTask h1 = [](auto& resolver, auto& ex) -> NNet::TVoidSuspendedTask { + try { + co_await resolver.Resolve("bad.host.name.wtf123"); + } catch (const std::exception& ) { + ex = std::current_exception(); + } + }(resolver, ex); + + while (!(h1.done())) { + loop.Step(); + } + + h1.destroy(); + + assert_true(!!ex); +} + template void test_read_write_full_ssl(void**) { using TLoop = TLoop; @@ -916,6 +942,7 @@ int main() { my_unit_poller(test_read_write_lines), my_unit_test2(test_read_write_full_ssl, TSelect, TPoll), my_unit_test2(test_resolver, TSelect, TPoll), + my_unit_test2(test_resolve_bad_name, TSelect, TPoll), #ifdef __linux__ cmocka_unit_test(test_uring_create), cmocka_unit_test(test_uring_write),