From 952e84008da2c4742fa368f0cfcec8ffc289d070 Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Fri, 6 Oct 2023 18:33:07 -0500 Subject: [PATCH] feat: rework recursion support Fixes #34 --- internal/dns/dns.go | 238 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 189 insertions(+), 49 deletions(-) diff --git a/internal/dns/dns.go b/internal/dns/dns.go index 3798f43..98b51a0 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -59,7 +59,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { // Pick random nameserver for domain tmpNameserver := randomNameserverAddress(nameservers) // Query the random domain nameserver we picked above - resp, err := queryServer(r, tmpNameserver.String()) + resp, err := doQuery(r, tmpNameserver.String(), true) if err != nil { // Send failure response m.SetRcode(r, dns.RcodeServerFailure) @@ -112,30 +112,6 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) { return } - // Query fallback servers if recursion is enabled - if cfg.Dns.RecursionEnabled { - // Pick random fallback server - fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))] - // Query chosen server - fallbackResp, err := queryServer(r, fallbackServer) - if err != nil { - // Send failure response - m.SetRcode(r, dns.RcodeServerFailure) - if err := w.WriteMsg(m); err != nil { - logger.Errorf("failed to write response: %s", err) - } - logger.Errorf("failed to query fallback server: %s", err) - return - } else { - copyResponse(r, fallbackResp, m) - // Send response - if err := w.WriteMsg(m); err != nil { - logger.Errorf("failed to write response: %s", err) - } - return - } - } - // Return NXDOMAIN if we have no information about the requested domain or any of its parents m.SetRcode(r, dns.RcodeNameError) if err := w.WriteMsg(m); err != nil { @@ -159,30 +135,64 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) { } } -func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) { - m := new(dns.Msg) - m.Id = dns.Id() - m.RecursionDesired = req.RecursionDesired - m.Question = append(m.Question, req.Question...) - in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver)) - return in, err -} - func randomNameserverAddress(nameservers map[string][]net.IP) net.IP { // Put all namserver addresses in single list tmpNameservers := []net.IP{} for _, addresses := range nameservers { tmpNameservers = append(tmpNameservers, addresses...) } - tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))] - return tmpNameserver + if len(tmpNameservers) > 0 { + tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))] + return tmpNameserver + } + return nil } -func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) { +func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) { logger := logging.GetLogger() - logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name) + // Default to a random fallback server if no address is specified + if address == "" { + address = randomFallbackServer() + } + // Add default port to address if there is none + if !strings.Contains(address, ":") { + address = address + `:53` + } + logger.Debugf("querying %s: %s", address, formatMessageQuestionSection(msg.Question)) resp, err := dns.Exchange(msg, address) - return resp, err + if err != nil { + return nil, err + } + logger.Debugf("response: rcode=%s, authoritative=%v, authority=%s, answer=%s, extra=%s", dns.RcodeToString[resp.Rcode], resp.Authoritative, formatMessageAnswerSection(resp.Ns), formatMessageAnswerSection(resp.Answer), formatMessageAnswerSection(resp.Extra)) + // Immediately return authoritative response + if resp.Authoritative { + return resp, nil + } + if recursive { + if len(resp.Ns) > 0 { + nameservers := getNameserversFromResponse(resp) + randNsName, randNsAddress := randomNameserver(nameservers) + if randNsAddress == "" { + m := createQuery(randNsName, dns.TypeA) + // XXX: should this query the fallback servers or the server that gave us the NS response? + resp, err := doQuery(m, "", false) + if err != nil { + return nil, err + } + randNsAddress = getAddressForNameFromResponse(resp, randNsName) + if randNsAddress == "" { + // Return the current response if we couldn't get an address for the nameserver + return resp, nil + } + } + // Perform recursive query + return doQuery(msg, randNsAddress, true) + } else { + // Return the current response if there is no authority information + return resp, nil + } + } + return resp, nil } func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) { @@ -211,14 +221,11 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e // Query fallback servers, if configured if len(cfg.Dns.FallbackServers) > 0 { // Pick random fallback server - fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))] - serverWithPort := fmt.Sprintf("%s:53", fallbackServer) + fallbackServer := randomFallbackServer() for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ { lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], ".")) - m := new(dns.Msg) - m.SetQuestion(lookupDomainName, dns.TypeNS) - m.RecursionDesired = false - in, err := doQuery(m, serverWithPort) + m := createQuery(lookupDomainName, dns.TypeNS) + in, err := doQuery(m, fallbackServer, false) if err != nil { return "", nil, err } @@ -231,10 +238,8 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e ns := v.Ns ret[ns] = make([]net.IP, 0) // Query for matching A/AAAA records - m2 := new(dns.Msg) - m2.SetQuestion(ns, dns.TypeA) - m2.RecursionDesired = false - in2, err := doQuery(m2, serverWithPort) + m2 := createQuery(ns, dns.TypeA) + in2, err := doQuery(m2, fallbackServer, false) if err != nil { return "", nil, err } @@ -258,3 +263,138 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e return "", nil, nil } + +func getNameserversFromResponse(msg *dns.Msg) map[string][]net.IP { + if len(msg.Ns) == 0 { + return nil + } + ret := map[string][]net.IP{} + for _, ns := range msg.Ns { + // TODO: handle SOA + switch v := ns.(type) { + case *dns.NS: + nsName := v.Ns + ret[nsName] = []net.IP{} + for _, extra := range msg.Extra { + if extra.Header().Name != nsName { + continue + } + switch v := extra.(type) { + case *dns.A: + ret[nsName] = append( + ret[nsName], + v.A, + ) + case *dns.AAAA: + ret[nsName] = append( + ret[nsName], + v.AAAA, + ) + } + } + } + } + return ret +} + +func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string { + var retRR dns.RR + for _, answer := range msg.Answer { + if answer.Header().Name == recordName { + retRR = answer + break + } + } + if retRR == nil { + for _, extra := range msg.Extra { + if extra.Header().Name == recordName { + retRR = extra + break + } + } + } + if retRR == nil { + return "" + } + switch v := retRR.(type) { + case *dns.A: + return v.A.String() + case *dns.AAAA: + return v.AAAA.String() + } + return "" +} + +func randomNameserver(nameservers map[string][]net.IP) (string, string) { + mapKeys := []string{} + for k := range nameservers { + mapKeys = append(mapKeys, k) + } + if len(mapKeys) > 0 { + randNsName := mapKeys[rand.Intn(len(mapKeys))] + randNsAddresses := nameservers[randNsName] + randNsAddress := randNsAddresses[rand.Intn(len(randNsAddresses))].String() + return randNsName, randNsAddress + } + return "", "" +} + +func createQuery(recordName string, recordType uint16) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(recordName, recordType) + m.RecursionDesired = false + return m +} + +func randomFallbackServer() string { + cfg := config.GetConfig() + return cfg.Dns.FallbackServers[rand.Intn( + len(cfg.Dns.FallbackServers), + )] +} + +func formatMessageAnswerSection(section []dns.RR) string { + ret := "[ " + for idx, rr := range section { + ret += fmt.Sprintf( + "< %s >", + strings.ReplaceAll( + strings.TrimPrefix( + rr.String(), + ";", + ), + "\t", + " ", + ), + ) + if idx != len(section)-1 { + ret += `,` + } + ret += ` ` + } + ret += "]" + return ret +} + +func formatMessageQuestionSection(section []dns.Question) string { + ret := "[ " + for idx, question := range section { + ret += fmt.Sprintf( + "< %s >", + strings.ReplaceAll( + strings.TrimPrefix( + question.String(), + ";", + ), + "\t", + " ", + ), + ) + if idx != len(section)-1 { + ret += `,` + } + ret += ` ` + } + ret += "]" + return ret +}