Skip to content

Commit

Permalink
feat: rework recursion support
Browse files Browse the repository at this point in the history
Fixes #34
  • Loading branch information
agaffney committed Oct 10, 2023
1 parent 3283762 commit 952e840
Showing 1 changed file with 189 additions and 49 deletions.
238 changes: 189 additions & 49 deletions internal/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
// Pick random nameserver for domain
tmpNameserver := randomNameserverAddress(nameservers)
// Query the random domain nameserver we picked above
resp, err := queryServer(r, tmpNameserver.String())
resp, err := doQuery(r, tmpNameserver.String(), true)
if err != nil {
// Send failure response
m.SetRcode(r, dns.RcodeServerFailure)
Expand Down Expand Up @@ -112,30 +112,6 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
return
}

// Query fallback servers if recursion is enabled
if cfg.Dns.RecursionEnabled {
// Pick random fallback server
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
// Query chosen server
fallbackResp, err := queryServer(r, fallbackServer)
if err != nil {
// Send failure response
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write response: %s", err)
}
logger.Errorf("failed to query fallback server: %s", err)
return
} else {
copyResponse(r, fallbackResp, m)
// Send response
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write response: %s", err)
}
return
}
}

// Return NXDOMAIN if we have no information about the requested domain or any of its parents
m.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(m); err != nil {
Expand All @@ -159,30 +135,64 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
}
}

func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
m := new(dns.Msg)
m.Id = dns.Id()
m.RecursionDesired = req.RecursionDesired
m.Question = append(m.Question, req.Question...)
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
return in, err
}

func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
// Put all namserver addresses in single list
tmpNameservers := []net.IP{}
for _, addresses := range nameservers {
tmpNameservers = append(tmpNameservers, addresses...)
}
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
return tmpNameserver
if len(tmpNameservers) > 0 {
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
return tmpNameserver
}
return nil
}

func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
logger := logging.GetLogger()
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
// Default to a random fallback server if no address is specified
if address == "" {
address = randomFallbackServer()
}
// Add default port to address if there is none
if !strings.Contains(address, ":") {
address = address + `:53`
}
logger.Debugf("querying %s: %s", address, formatMessageQuestionSection(msg.Question))
resp, err := dns.Exchange(msg, address)
return resp, err
if err != nil {
return nil, err
}
logger.Debugf("response: rcode=%s, authoritative=%v, authority=%s, answer=%s, extra=%s", dns.RcodeToString[resp.Rcode], resp.Authoritative, formatMessageAnswerSection(resp.Ns), formatMessageAnswerSection(resp.Answer), formatMessageAnswerSection(resp.Extra))
// Immediately return authoritative response
if resp.Authoritative {
return resp, nil
}
if recursive {
if len(resp.Ns) > 0 {
nameservers := getNameserversFromResponse(resp)
randNsName, randNsAddress := randomNameserver(nameservers)
if randNsAddress == "" {
m := createQuery(randNsName, dns.TypeA)
// XXX: should this query the fallback servers or the server that gave us the NS response?
resp, err := doQuery(m, "", false)
if err != nil {
return nil, err
}
randNsAddress = getAddressForNameFromResponse(resp, randNsName)
if randNsAddress == "" {
// Return the current response if we couldn't get an address for the nameserver
return resp, nil
}
}
// Perform recursive query
return doQuery(msg, randNsAddress, true)
} else {
// Return the current response if there is no authority information
return resp, nil
}
}
return resp, nil
}

func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
Expand Down Expand Up @@ -211,14 +221,11 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
// Query fallback servers, if configured
if len(cfg.Dns.FallbackServers) > 0 {
// Pick random fallback server
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
fallbackServer := randomFallbackServer()
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
m := new(dns.Msg)
m.SetQuestion(lookupDomainName, dns.TypeNS)
m.RecursionDesired = false
in, err := doQuery(m, serverWithPort)
m := createQuery(lookupDomainName, dns.TypeNS)
in, err := doQuery(m, fallbackServer, false)
if err != nil {
return "", nil, err
}
Expand All @@ -231,10 +238,8 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
ns := v.Ns
ret[ns] = make([]net.IP, 0)
// Query for matching A/AAAA records
m2 := new(dns.Msg)
m2.SetQuestion(ns, dns.TypeA)
m2.RecursionDesired = false
in2, err := doQuery(m2, serverWithPort)
m2 := createQuery(ns, dns.TypeA)
in2, err := doQuery(m2, fallbackServer, false)
if err != nil {
return "", nil, err
}
Expand All @@ -258,3 +263,138 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e

return "", nil, nil
}

func getNameserversFromResponse(msg *dns.Msg) map[string][]net.IP {
if len(msg.Ns) == 0 {
return nil
}
ret := map[string][]net.IP{}
for _, ns := range msg.Ns {
// TODO: handle SOA
switch v := ns.(type) {
case *dns.NS:
nsName := v.Ns
ret[nsName] = []net.IP{}
for _, extra := range msg.Extra {
if extra.Header().Name != nsName {
continue
}
switch v := extra.(type) {
case *dns.A:
ret[nsName] = append(
ret[nsName],
v.A,
)
case *dns.AAAA:
ret[nsName] = append(
ret[nsName],
v.AAAA,
)
}
}
}
}
return ret
}

func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string {
var retRR dns.RR
for _, answer := range msg.Answer {
if answer.Header().Name == recordName {
retRR = answer
break
}
}
if retRR == nil {
for _, extra := range msg.Extra {
if extra.Header().Name == recordName {
retRR = extra
break
}
}
}
if retRR == nil {
return ""
}
switch v := retRR.(type) {
case *dns.A:
return v.A.String()
case *dns.AAAA:
return v.AAAA.String()
}
return ""
}

func randomNameserver(nameservers map[string][]net.IP) (string, string) {
mapKeys := []string{}
for k := range nameservers {
mapKeys = append(mapKeys, k)
}
if len(mapKeys) > 0 {
randNsName := mapKeys[rand.Intn(len(mapKeys))]
randNsAddresses := nameservers[randNsName]
randNsAddress := randNsAddresses[rand.Intn(len(randNsAddresses))].String()
return randNsName, randNsAddress
}
return "", ""
}

func createQuery(recordName string, recordType uint16) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(recordName, recordType)
m.RecursionDesired = false
return m
}

func randomFallbackServer() string {
cfg := config.GetConfig()
return cfg.Dns.FallbackServers[rand.Intn(
len(cfg.Dns.FallbackServers),
)]
}

func formatMessageAnswerSection(section []dns.RR) string {
ret := "[ "
for idx, rr := range section {
ret += fmt.Sprintf(
"< %s >",
strings.ReplaceAll(
strings.TrimPrefix(
rr.String(),
";",
),
"\t",
" ",
),
)
if idx != len(section)-1 {
ret += `,`
}
ret += ` `
}
ret += "]"
return ret
}

func formatMessageQuestionSection(section []dns.Question) string {
ret := "[ "
for idx, question := range section {
ret += fmt.Sprintf(
"< %s >",
strings.ReplaceAll(
strings.TrimPrefix(
question.String(),
";",
),
"\t",
" ",
),
)
if idx != len(section)-1 {
ret += `,`
}
ret += ` `
}
ret += "]"
return ret
}

0 comments on commit 952e840

Please sign in to comment.