diff --git a/mock.go b/mock.go new file mode 100644 index 0000000..3a054f8 --- /dev/null +++ b/mock.go @@ -0,0 +1,31 @@ +package madns + +import ( + "context" + "net" +) + +type MockResolver struct { + IP map[string][]net.IPAddr + TXT map[string][]string +} + +var _ BasicResolver = (*MockResolver)(nil) + +func (r *MockResolver) LookupIPAddr(ctx context.Context, name string) ([]net.IPAddr, error) { + results, ok := r.IP[name] + if ok { + return results, nil + } else { + return []net.IPAddr{}, nil + } +} + +func (r *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + results, ok := r.TXT[name] + if ok { + return results, nil + } else { + return []string{}, nil + } +} diff --git a/resolve.go b/resolve.go index fd8d5c2..705fed7 100644 --- a/resolve.go +++ b/resolve.go @@ -9,59 +9,86 @@ import ( ) var ResolvableProtocols = []ma.Protocol{DnsaddrProtocol, Dns4Protocol, Dns6Protocol, DnsProtocol} -var DefaultResolver = &Resolver{Backend: net.DefaultResolver} +var DefaultResolver = &Resolver{def: net.DefaultResolver} const dnsaddrTXTPrefix = "dnsaddr=" -type Backend interface { +// BasicResolver is a low level interface for DNS resolution +type BasicResolver interface { LookupIPAddr(context.Context, string) ([]net.IPAddr, error) LookupTXT(context.Context, string) ([]string, error) } +// Resolver is an object capable of resolving dns multiaddrs by using one or more BasicResolvers; +// it supports custom per domain/TLD resolvers. +// It also implements the BasicResolver interface so that it can act as a custom per domain/TLD +// resolver. type Resolver struct { - Backend Backend + def BasicResolver + custom map[string]BasicResolver } -var _ Backend = (*MockBackend)(nil) +var _ BasicResolver = (*Resolver)(nil) -type MockBackend struct { - IP map[string][]net.IPAddr - TXT map[string][]string +// NewResolver creates a new Resolver instance with the specified options +func NewResolver(opts ...Option) (*Resolver, error) { + r := &Resolver{def: net.DefaultResolver} + for _, opt := range opts { + err := opt(r) + if err != nil { + return nil, err + } + } + + return r, nil } -func (r *MockBackend) LookupIPAddr(ctx context.Context, name string) ([]net.IPAddr, error) { - results, ok := r.IP[name] - if ok { - return results, nil - } else { - return []net.IPAddr{}, nil +type Option func(*Resolver) error + +// WithDefaultResolver is an option that specifies the default basic resolver, +// which resolves any TLD that doesn't have a custom resolver. +// Defaults to net.DefaultResolver +func WithDefaultResolver(def BasicResolver) Option { + return func(r *Resolver) error { + r.def = def + return nil } } -func (r *MockBackend) LookupTXT(ctx context.Context, name string) ([]string, error) { - results, ok := r.TXT[name] - if ok { - return results, nil - } else { - return []string{}, nil +// WithDomainResolver specifies a custom resolver for a domain/TLD. +// Custom resolver selection matches domains left to right, with more specific resolvers +// superseding generic ones. +func WithDomainResolver(domain string, rslv BasicResolver) Option { + return func(r *Resolver) error { + if r.custom == nil { + r.custom = make(map[string]BasicResolver) + } + r.custom[domain] = rslv + return nil } } -func Matches(maddr ma.Multiaddr) (matches bool) { - ma.ForEach(maddr, func(c ma.Component) bool { - switch c.Protocol().Code { - case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: - matches = true +func (r *Resolver) getResolver(domain string) BasicResolver { + // we match left-to-right, with more specific resolvers superseding generic ones. + // So for a domain a.b.c, we will try a.b,c, b.c, c, and fallback to the default if + // there is no match + rslv, ok := r.custom[domain] + if ok { + return rslv + } + + for i := strings.Index(domain, "."); i != -1; i = strings.Index(domain, ".") { + domain = domain[i+1:] + rslv, ok = r.custom[domain] + if ok { + return rslv } - return !matches - }) - return matches -} + } -func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - return DefaultResolver.Resolve(ctx, maddr) + return r.def } +// Resolve resolves a DNS multiaddr. func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { var results []ma.Multiaddr for i := 0; maddr != nil; i++ { @@ -99,6 +126,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia proto := resolve.Protocol() value := resolve.Value() + rslv := r.getResolver(value) // resolve the dns component var resolved []ma.Multiaddr @@ -114,7 +142,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia // differentiating between IPv6 and IPv4. A v4-in-v6 // AAAA record will _look_ like an A record to us and // there's nothing we can do about that. - records, err := r.Backend.LookupIPAddr(ctx, value) + records, err := rslv.LookupIPAddr(ctx, value) if err != nil { return nil, err } @@ -155,7 +183,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia // matching the result of step 2. // First, lookup the TXT record - records, err := r.Backend.LookupTXT(ctx, "_dnsaddr."+value) + records, err := rslv.LookupTXT(ctx, "_dnsaddr."+value) if err != nil { return nil, err } @@ -235,37 +263,10 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia return results, nil } -// counts the number of components in the multiaddr -func addrLen(maddr ma.Multiaddr) int { - length := 0 - ma.ForEach(maddr, func(_ ma.Component) bool { - length++ - return true - }) - return length -} - -// trims `offset` components from the beginning of the multiaddr. -func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr { - _, after := ma.SplitFunc(maddr, func(c ma.Component) bool { - if offset == 0 { - return true - } - offset-- - return false - }) - return after +func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { + return r.getResolver(domain).LookupIPAddr(ctx, domain) } -// takes the cross product of two sets of multiaddrs -// -// assumes `a` is non-empty. -func cross(a, b []ma.Multiaddr) []ma.Multiaddr { - res := make([]ma.Multiaddr, 0, len(a)*len(b)) - for _, x := range a { - for _, y := range b { - res = append(res, x.Encapsulate(y)) - } - } - return res +func (r *Resolver) LookupTXT(ctx context.Context, txt string) ([]string, error) { + return r.getResolver(txt).LookupTXT(ctx, txt) } diff --git a/resolve_test.go b/resolve_test.go index 1334611..6cb230a 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -1,6 +1,7 @@ package madns import ( + "bytes" "context" "net" "testing" @@ -29,7 +30,7 @@ var txtd = "dnsaddr=" + txtmd.String() var txte = "dnsaddr=" + txtme.String() func makeResolver() *Resolver { - mock := &MockBackend{ + mock := &MockResolver{ IP: map[string][]net.IPAddr{ "example.com": []net.IPAddr{ip4a, ip4b, ip6a, ip6b}, }, @@ -38,7 +39,7 @@ func makeResolver() *Resolver { "_dnsaddr.matching.com": []string{txtc, txtd, txte, "not a dnsaddr", "dnsaddr=/foobar"}, }, } - resolver := &Resolver{Backend: mock} + resolver := &Resolver{def: mock} return resolver } @@ -234,3 +235,89 @@ func TestBadDomain(t *testing.T) { t.Error("expected malformed address to fail to parse") } } + +func TestCustomResolver(t *testing.T) { + ip1 := net.IPAddr{IP: net.ParseIP("1.2.3.4")} + ip2 := net.IPAddr{IP: net.ParseIP("2.3.4.5")} + ip3 := net.IPAddr{IP: net.ParseIP("3.4.5.6")} + ip4 := net.IPAddr{IP: net.ParseIP("4.5.6.8")} + ip5 := net.IPAddr{IP: net.ParseIP("5.6.8.9")} + ip6 := net.IPAddr{IP: net.ParseIP("6.8.9.10")} + def := &MockResolver{ + IP: map[string][]net.IPAddr{ + "example.com": []net.IPAddr{ip1}, + }, + } + custom1 := &MockResolver{ + IP: map[string][]net.IPAddr{ + "custom.test": []net.IPAddr{ip2}, + "another.custom.test": []net.IPAddr{ip3}, + "more.custom.test": []net.IPAddr{ip6}, + }, + } + custom2 := &MockResolver{ + IP: map[string][]net.IPAddr{ + "more.custom.test": []net.IPAddr{ip4}, + "some.more.custom.test": []net.IPAddr{ip5}, + }, + } + + rslv, err := NewResolver( + WithDefaultResolver(def), + WithDomainResolver("custom.test", custom1), + WithDomainResolver("more.custom.test", custom2), + ) + if err != nil { + t.Fatal(err) + } + + sameIP := func(ip1, ip2 net.IPAddr) bool { + return bytes.Equal(ip1.IP, ip2.IP) + } + + ctx := context.Background() + res, err := rslv.LookupIPAddr(ctx, "example.com") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip1) { + t.Fatal("expected result to be ip1") + } + + res, err = rslv.LookupIPAddr(ctx, "custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip2) { + t.Fatal("expected result to be ip2") + } + + res, err = rslv.LookupIPAddr(ctx, "another.custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip3) { + t.Fatal("expected result to be ip3") + } + + res, err = rslv.LookupIPAddr(ctx, "more.custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip4) { + t.Fatal("expected result to be ip4") + } + + res, err = rslv.LookupIPAddr(ctx, "some.more.custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip5) { + t.Fatal("expected result to be ip5") + } +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..2953ddd --- /dev/null +++ b/util.go @@ -0,0 +1,57 @@ +package madns + +import ( + "context" + + ma "github.com/multiformats/go-multiaddr" +) + +func Matches(maddr ma.Multiaddr) (matches bool) { + ma.ForEach(maddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: + matches = true + } + return !matches + }) + return matches +} + +func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { + return DefaultResolver.Resolve(ctx, maddr) +} + +// counts the number of components in the multiaddr +func addrLen(maddr ma.Multiaddr) int { + length := 0 + ma.ForEach(maddr, func(_ ma.Component) bool { + length++ + return true + }) + return length +} + +// trims `offset` components from the beginning of the multiaddr. +func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr { + _, after := ma.SplitFunc(maddr, func(c ma.Component) bool { + if offset == 0 { + return true + } + offset-- + return false + }) + return after +} + +// takes the cross product of two sets of multiaddrs +// +// assumes `a` is non-empty. +func cross(a, b []ma.Multiaddr) []ma.Multiaddr { + res := make([]ma.Multiaddr, 0, len(a)*len(b)) + for _, x := range a { + for _, y := range b { + res = append(res, x.Encapsulate(y)) + } + } + return res +}