Skip to content

Commit

Permalink
Merge pull request #45 from blinklabs-io/feat/fallback-query-ns
Browse files Browse the repository at this point in the history
feat: explicitly query fallback servers for NS records
  • Loading branch information
agaffney authored Oct 5, 2023
2 parents 71668b8 + 6530134 commit c5b2bae
Showing 1 changed file with 116 additions and 36 deletions.
152 changes: 116 additions & 36 deletions internal/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,18 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
}
}

// Split query name into labels and lookup each domain and parent until we get a hit
queryLabels := dns.SplitDomainName(r.Question[0].Name)
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
nameServers, err := state.GetState().LookupDomain(lookupDomainName)
if err != nil {
logger.Errorf("failed to lookup domain: %s", err)
}
if nameServers == nil {
continue
}
nameserverDomain, nameservers, err := findNameserversForDomain(r.Question[0].Name)
if err != nil {
logger.Errorf("failed to lookup nameservers for %s: %s", r.Question[0].Name, err)
}
if nameservers != nil {
// Assemble response
m.SetReply(r)
if cfg.Dns.RecursionEnabled {
// Pick random nameserver for domain
tmpNameservers := []string{}
for nameserver := range nameServers {
tmpNameservers = append(tmpNameservers, nameserver)
}
tmpNameserver := nameServers[tmpNameservers[rand.Intn(len(tmpNameservers))]]
tmpNameserver := randomNameserverAddress(nameservers)
// Query the random domain nameserver we picked above
resp, err := queryServer(r, tmpNameserver)
resp, err := queryServer(r, tmpNameserver.String())
if err != nil {
// Send failure response
m.SetRcode(r, dns.RcodeServerFailure)
Expand All @@ -87,31 +77,30 @@ func handleQuery(w dns.ResponseWriter, r *dns.Msg) {
return
}
} else {
for nameserver, ipAddress := range nameServers {
// Add trailing dot to make everybody happy
nameserver = nameserver + `.`
for nameserver, addresses := range nameservers {
// NS record
ns := &dns.NS{
Hdr: dns.RR_Header{Name: (lookupDomainName + `.`), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
Hdr: dns.RR_Header{Name: (nameserverDomain), Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 999},
Ns: nameserver,
}
m.Ns = append(m.Ns, ns)
// A or AAAA record
ipAddr := net.ParseIP(ipAddress)
if ipAddr.To4() != nil {
// IPv4
a := &dns.A{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
A: ipAddr,
}
m.Extra = append(m.Extra, a)
} else {
// IPv6
aaaa := &dns.AAAA{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
AAAA: ipAddr,
for _, address := range addresses {
// A or AAAA record
if address.To4() != nil {
// IPv4
a := &dns.A{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 999},
A: address,
}
m.Extra = append(m.Extra, a)
} else {
// IPv6
aaaa := &dns.AAAA{
Hdr: dns.RR_Header{Name: nameserver, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 999},
AAAA: address,
}
m.Extra = append(m.Extra, aaaa)
}
m.Extra = append(m.Extra, aaaa)
}
}
}
Expand Down Expand Up @@ -178,3 +167,94 @@ func queryServer(req *dns.Msg, nameserver string) (*dns.Msg, error) {
in, err := dns.Exchange(m, fmt.Sprintf("%s:53", nameserver))
return in, err
}

func randomNameserverAddress(nameservers map[string][]net.IP) net.IP {
// Put all namserver addresses in single list
tmpNameservers := []net.IP{}
for _, addresses := range nameservers {
tmpNameservers = append(tmpNameservers, addresses...)
}
tmpNameserver := tmpNameservers[rand.Intn(len(tmpNameservers))]
return tmpNameserver
}

func doQuery(msg *dns.Msg, address string) (*dns.Msg, error) {
logger := logging.GetLogger()
logger.Debugf("querying %s: %s %s", address, dns.Type(msg.Question[0].Qtype).String(), msg.Question[0].Name)
resp, err := dns.Exchange(msg, address)
return resp, err
}

func findNameserversForDomain(recordName string) (string, map[string][]net.IP, error) {
cfg := config.GetConfig()

// Split record name into labels and lookup each domain and parent until we get a hit
queryLabels := dns.SplitDomainName(recordName)

// Check on-chain domains first
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
nameservers, err := state.GetState().LookupDomain(lookupDomainName)
if err != nil {
return "", nil, err
}
if nameservers != nil {
ret := map[string][]net.IP{}
for k, v := range nameservers {
k = k + `.`
ret[k] = append(ret[k], net.ParseIP(v))
}
return dns.Fqdn(lookupDomainName), ret, nil
}
}

// Query fallback servers, if configured
if len(cfg.Dns.FallbackServers) > 0 {
// Pick random fallback server
fallbackServer := cfg.Dns.FallbackServers[rand.Intn(len(cfg.Dns.FallbackServers))]
serverWithPort := fmt.Sprintf("%s:53", fallbackServer)
for startLabelIdx := 0; startLabelIdx < len(queryLabels); startLabelIdx++ {
lookupDomainName := dns.Fqdn(strings.Join(queryLabels[startLabelIdx:], "."))
m := new(dns.Msg)
m.SetQuestion(lookupDomainName, dns.TypeNS)
m.RecursionDesired = false
in, err := doQuery(m, serverWithPort)
if err != nil {
return "", nil, err
}
if in.Rcode == dns.RcodeSuccess {
if len(in.Answer) > 0 {
ret := map[string][]net.IP{}
for _, answer := range in.Answer {
switch v := answer.(type) {
case *dns.NS:
ns := v.Ns
ret[ns] = make([]net.IP, 0)
// Query for matching A/AAAA records
m2 := new(dns.Msg)
m2.SetQuestion(ns, dns.TypeA)
m2.RecursionDesired = false
in2, err := doQuery(m2, serverWithPort)
if err != nil {
return "", nil, err
}
for _, answer2 := range in2.Answer {
switch v := answer2.(type) {
case *dns.A:
ret[ns] = append(ret[ns], v.A)
case *dns.AAAA:
ret[ns] = append(ret[ns], v.AAAA)
}
}
}
}
if len(ret) > 0 {
return lookupDomainName, ret, nil
}
}
}
}
}

return "", nil, nil
}

0 comments on commit c5b2bae

Please sign in to comment.