Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rework recursion support #48

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 189 additions & 49 deletions internal/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
// Pick random nameserver for domain
tmpNameserver := randomNameserverAddress(nameservers)
// Query the random domain nameserver we picked above
resp, err := queryServer(r, tmpNameserver.String())
resp, err := doQuery(r, tmpNameserver.String(), true)
if err != nil {
// Send failure response
m.SetRcode(r, dns.RcodeServerFailure)
Expand Down Expand Up @@ -112,30 +112,6 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
return
}

// Query fallback servers if recursion is enabled
if cfg.Dns.RecursionEnabled {
// Pick random fallback server
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
// Query chosen server
fallbackResp, err := queryServer(r, fallbackServer)
if err != nil {
// Send failure response
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write response: %s", err)
}
logger.Errorf("failed to query fallback server: %s", err)
return
} else {
copyResponse(r, fallbackResp, m)
// Send response
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write response: %s", err)
}
return
}
}

// Return NXDOMAIN if we have no information about the requested domain or any of its parents
m.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(m); err != nil {
Expand All @@ -159,30 +135,64 @@ func copyResponse(req *dns.Msg, srcResp *dns.Msg, destResp *dns.Msg) {
}
}

func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
m := new(dns.Msg)
m.Id = dns.Id()
m.RecursionDesired = req.RecursionDesired
m.Question = append(m.Question, req.Question...)
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
return in, err
}

func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
// Put all namserver addresses in single list
tmpNameservers := []net.IP{}
for _, addresses := range nameservers {
tmpNameservers = append(tmpNameservers, addresses...)
}
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
return tmpNameserver
if len(tmpNameservers) > 0 {
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
return tmpNameserver
}
return nil
}

func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
func doQuery(msg *dns.Msg, address string, recursive bool) (*dns.Msg, error) {
logger := logging.GetLogger()
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
// Default to a random fallback server if no address is specified
if address == "" {
address = randomFallbackServer()
}
// Add default port to address if there is none
if !strings.Contains(address, ":") {
address = address + `:53`
}
logger.Debugf("querying %s: %s", address, formatMessageQuestionSection(msg.Question))
resp, err := dns.Exchange(msg, address)
return resp, err
if err != nil {
return nil, err
}
logger.Debugf("response: rcode=%s, authoritative=%v, authority=%s, answer=%s, extra=%s", dns.RcodeToString[resp.Rcode], resp.Authoritative, formatMessageAnswerSection(resp.Ns), formatMessageAnswerSection(resp.Answer), formatMessageAnswerSection(resp.Extra))
// Immediately return authoritative response
if resp.Authoritative {
return resp, nil
}
if recursive {
if len(resp.Ns) > 0 {
nameservers := getNameserversFromResponse(resp)
randNsName, randNsAddress := randomNameserver(nameservers)
if randNsAddress == "" {
m := createQuery(randNsName, dns.TypeA)
// XXX: should this query the fallback servers or the server that gave us the NS response?
resp, err := doQuery(m, "", false)
if err != nil {
return nil, err
}
randNsAddress = getAddressForNameFromResponse(resp, randNsName)
if randNsAddress == "" {
// Return the current response if we couldn't get an address for the nameserver
return resp, nil
}
}
// Perform recursive query
return doQuery(msg, randNsAddress, true)
} else {
// Return the current response if there is no authority information
return resp, nil
}
}
return resp, nil
}

func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
Expand Down Expand Up @@ -211,14 +221,11 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
// Query fallback servers, if configured
if len(cfg.Dns.FallbackServers) > 0 {
// Pick random fallback server
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
fallbackServer := randomFallbackServer()
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
m := new(dns.Msg)
m.SetQuestion(lookupDomainName, dns.TypeNS)
m.RecursionDesired = false
in, err := doQuery(m, serverWithPort)
m := createQuery(lookupDomainName, dns.TypeNS)
in, err := doQuery(m, fallbackServer, false)
if err != nil {
return "", nil, err
}
Expand All @@ -231,10 +238,8 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e
ns := v.Ns
ret[ns] = make([]net.IP, 0)
// Query for matching A/AAAA records
m2 := new(dns.Msg)
m2.SetQuestion(ns, dns.TypeA)
m2.RecursionDesired = false
in2, err := doQuery(m2, serverWithPort)
m2 := createQuery(ns, dns.TypeA)
in2, err := doQuery(m2, fallbackServer, false)
if err != nil {
return "", nil, err
}
Expand All @@ -258,3 +263,138 @@ func findNameserversForDomain(recordName string) (string, map[string][]net.IP, e

return "", nil, nil
}

func getNameserversFromResponse(msg *dns.Msg) map[string][]net.IP {
if len(msg.Ns) == 0 {
return nil
}
ret := map[string][]net.IP{}
for _, ns := range msg.Ns {
// TODO: handle SOA
switch v := ns.(type) {
case *dns.NS:
nsName := v.Ns
ret[nsName] = []net.IP{}
for _, extra := range msg.Extra {
if extra.Header().Name != nsName {
continue
}
switch v := extra.(type) {
case *dns.A:
ret[nsName] = append(
ret[nsName],
v.A,
)
case *dns.AAAA:
ret[nsName] = append(
ret[nsName],
v.AAAA,
)
}
}
}
}
return ret
}

func getAddressForNameFromResponse(msg *dns.Msg, recordName string) string {
var retRR dns.RR
for _, answer := range msg.Answer {
if answer.Header().Name == recordName {
retRR = answer
break
}
}
if retRR == nil {
for _, extra := range msg.Extra {
if extra.Header().Name == recordName {
retRR = extra
break
}
}
}
if retRR == nil {
return ""
}
switch v := retRR.(type) {
case *dns.A:
return v.A.String()
case *dns.AAAA:
return v.AAAA.String()
}
return ""
}

func randomNameserver(nameservers map[string][]net.IP) (string, string) {
mapKeys := []string{}
for k := range nameservers {
mapKeys = append(mapKeys, k)
}
if len(mapKeys) > 0 {
randNsName := mapKeys[rand.Intn(len(mapKeys))]
randNsAddresses := nameservers[randNsName]
randNsAddress := randNsAddresses[rand.Intn(len(randNsAddresses))].String()
return randNsName, randNsAddress
}
return "", ""
}

func createQuery(recordName string, recordType uint16) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(recordName, recordType)
m.RecursionDesired = false
return m
}

func randomFallbackServer() string {
cfg := config.GetConfig()
return cfg.Dns.FallbackServers[rand.Intn(
len(cfg.Dns.FallbackServers),
)]
}

func formatMessageAnswerSection(section []dns.RR) string {
ret := "[ "
for idx, rr := range section {
ret += fmt.Sprintf(
"< %s >",
strings.ReplaceAll(
strings.TrimPrefix(
rr.String(),
";",
),
"\t",
" ",
),
)
if idx != len(section)-1 {
ret += `,`
}
ret += ` `
}
ret += "]"
return ret
}

func formatMessageQuestionSection(section []dns.Question) string {
ret := "[ "
for idx, question := range section {
ret += fmt.Sprintf(
"< %s >",
strings.ReplaceAll(
strings.TrimPrefix(
question.String(),
";",
),
"\t",
" ",
),
)
if idx != len(section)-1 {
ret += `,`
}
ret += ` `
}
ret += "]"
return ret
}
Loading