diff --git a/go.mod b/go.mod index 569e7bc4502..a4241539e3a 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome go 1.22.5 require ( - github.com/AdguardTeam/dnsproxy v0.71.2 - github.com/AdguardTeam/golibs v0.24.0 + github.com/AdguardTeam/dnsproxy v0.72.0 + github.com/AdguardTeam/golibs v0.24.1 github.com/AdguardTeam/urlfilter v0.19.0 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.3.0 diff --git a/go.sum b/go.sum index b3f06967a4b..b6d345c1576 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -github.com/AdguardTeam/dnsproxy v0.71.2 h1:dFG2wga4GDdj1eI3rU2wqjQ6QGQm9MjLRb5ZzyH3Vgg= -github.com/AdguardTeam/dnsproxy v0.71.2/go.mod h1:huI5zyWhlimHBhg0jt2CMinXzsEHymI+WlvxIfmfEGA= -github.com/AdguardTeam/golibs v0.24.0 h1:qAnOq7BQtwSVo7Co9q703/n+nZ2Ap6smkugU9G9MomY= -github.com/AdguardTeam/golibs v0.24.0/go.mod h1:9/vJcYznW7RlmCT/Qzi8XNZGj+ZbWfHZJmEXKnRpCAU= +github.com/AdguardTeam/dnsproxy v0.72.0 h1:Psn7uCMVR/dCx8Te2Iy05bWWRNArSBF9j38VXNtt6+4= +github.com/AdguardTeam/dnsproxy v0.72.0/go.mod h1:5ehzbfInAu07not4beAM+FlFPqntw18T1sQCK/kIQR8= +github.com/AdguardTeam/golibs v0.24.1 h1:/ulkfm65wi33p72ybxiOt3lSdP0nr1GggSoaT4sHbns= +github.com/AdguardTeam/golibs v0.24.1/go.mod h1:9/vJcYznW7RlmCT/Qzi8XNZGj+ZbWfHZJmEXKnRpCAU= github.com/AdguardTeam/urlfilter v0.19.0 h1:q7eH13+yNETlpD/VD3u5rLQOripcUdEktqZFy+KiQLk= github.com/AdguardTeam/urlfilter v0.19.0/go.mod h1:+N54ZvxqXYLnXuvpaUhK2exDQW+djZBRSb6F6j0rkBY= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= diff --git a/internal/aghos/syslog.go b/internal/aghos/syslog.go index a67a56354f3..a38b3e68ad1 100644 --- a/internal/aghos/syslog.go +++ b/internal/aghos/syslog.go @@ -1,6 +1,6 @@ package aghos // ConfigureSyslog reroutes standard logger output to syslog. -func ConfigureSyslog(serviceName string) error { +func ConfigureSyslog(serviceName string) (err error) { return configureSyslog(serviceName) } diff --git a/internal/aghos/syslog_others.go b/internal/aghos/syslog_others.go index 1659ae498c7..9c72e66d214 100644 --- a/internal/aghos/syslog_others.go +++ b/internal/aghos/syslog_others.go @@ -8,11 +8,15 @@ import ( "github.com/AdguardTeam/golibs/log" ) -func configureSyslog(serviceName string) error { +// configureSyslog sets standard log output to syslog. +func configureSyslog(serviceName string) (err error) { w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName) if err != nil { + // Don't wrap the error, because it's informative enough as is. return err } + log.SetOutput(w) + return nil } diff --git a/internal/aghos/syslog_windows.go b/internal/aghos/syslog_windows.go index c8e86e78d21..0a78b28d54f 100644 --- a/internal/aghos/syslog_windows.go +++ b/internal/aghos/syslog_windows.go @@ -19,23 +19,30 @@ func (w *eventLogWriter) Write(b []byte) (int, error) { return len(b), w.el.Info(1, string(b)) } -func configureSyslog(serviceName string) error { - // Note that the eventlog src is the same as the service name - // Otherwise, we will get "the description for event id cannot be found" warning in every log record +// configureSyslog sets standard log output to event log. +func configureSyslog(serviceName string) (err error) { + // Note that the eventlog src is the same as the service name, otherwise we + // will get "the description for event id cannot be found" warning in every + // log record. // Continue if we receive "registry key already exists" or if we get // ERROR_ACCESS_DENIED so that we can log without administrative permissions // for pre-existing eventlog sources. - if err := eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error); err != nil { - if !strings.Contains(err.Error(), "registry key already exists") && err != windows.ERROR_ACCESS_DENIED { - return err - } + err = eventlog.InstallAsEventCreate(serviceName, eventlog.Info|eventlog.Warning|eventlog.Error) + if err != nil && + !strings.Contains(err.Error(), "registry key already exists") && + err != windows.ERROR_ACCESS_DENIED { + // Don't wrap the error, because it's informative enough as is. + return err } + el, err := eventlog.Open(serviceName) if err != nil { + // Don't wrap the error, because it's informative enough as is. return err } log.SetOutput(&eventLogWriter{el: el}) + return nil } diff --git a/internal/dhcpsvc/db.go b/internal/dhcpsvc/db.go index b247e653750..f1ee7d557ce 100644 --- a/internal/dhcpsvc/db.go +++ b/internal/dhcpsvc/db.go @@ -106,6 +106,9 @@ func (srv *DHCPServer) dbLoad(ctx context.Context) (err error) { return nil } + defer func() { + err = errors.WithDeferred(err, file.Close()) + }() dl := &dataLeases{} err = json.NewDecoder(file).Decode(dl) diff --git a/internal/dhcpsvc/dhcpsvc.go b/internal/dhcpsvc/dhcpsvc.go index a7d76ab5408..b6c777868da 100644 --- a/internal/dhcpsvc/dhcpsvc.go +++ b/internal/dhcpsvc/dhcpsvc.go @@ -50,7 +50,7 @@ type Interface interface { IPByHost(host string) (ip netip.Addr) // Leases returns all the active DHCP leases. The returned slice should be - // a clone. + // a clone. The order of leases is undefined. // // TODO(e.burkov): Consider implementing iterating methods with appropriate // signatures instead of cloning the whole list. diff --git a/internal/dhcpsvc/interface.go b/internal/dhcpsvc/interface.go index 13dadb4a0b5..87c3de4d016 100644 --- a/internal/dhcpsvc/interface.go +++ b/internal/dhcpsvc/interface.go @@ -3,42 +3,74 @@ package dhcpsvc import ( "fmt" "log/slog" - "slices" + "net" "time" ) -// netInterface is a common part of any network interface within the DHCP -// server. +// macKey contains hardware address as byte array of 6, 8, or 20 bytes. +// +// TODO(e.burkov): Move to aghnet or even to netutil. +type macKey any + +// macToKey converts mac into macKey, which is used as the key for the lease +// maps. mac must be a valid hardware address of length 6, 8, or 20 bytes, see +// [netutil.ValidateMAC]. +func macToKey(mac net.HardwareAddr) (key macKey) { + switch len(mac) { + case 6: + return [6]byte(mac) + case 8: + return [8]byte(mac) + case 20: + return [20]byte(mac) + default: + panic(fmt.Errorf("invalid mac address %#v", mac)) + } +} + +// netInterface is a common part of any interface within the DHCP server. // // TODO(e.burkov): Add other methods as [DHCPServer] evolves. type netInterface struct { // logger logs the events related to the network interface. logger *slog.Logger + // leases is the set of DHCP leases assigned to this interface. + leases map[macKey]*Lease + // name is the name of the network interface. name string - // leases is a set of leases sorted by hardware address. - leases []*Lease - // leaseTTL is the default Time-To-Live value for leases. leaseTTL time.Duration } +// newNetInterface creates a new netInterface with the given name, leaseTTL, and +// logger. +func newNetInterface(name string, l *slog.Logger, leaseTTL time.Duration) (iface *netInterface) { + return &netInterface{ + logger: l, + leases: map[macKey]*Lease{}, + name: name, + leaseTTL: leaseTTL, + } +} + // reset clears all the slices in iface for reuse. func (iface *netInterface) reset() { - iface.leases = iface.leases[:0] + clear(iface.leases) } -// insertLease inserts the given lease into iface. It returns an error if the +// addLease inserts the given lease into iface. It returns an error if the // lease can't be inserted. -func (iface *netInterface) insertLease(l *Lease) (err error) { - i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC) +func (iface *netInterface) addLease(l *Lease) (err error) { + mk := macToKey(l.HWAddr) + _, found := iface.leases[mk] if found { return fmt.Errorf("lease for mac %s already exists", l.HWAddr) } - iface.leases = slices.Insert(iface.leases, i, l) + iface.leases[mk] = l return nil } @@ -46,12 +78,13 @@ func (iface *netInterface) insertLease(l *Lease) (err error) { // updateLease replaces an existing lease within iface with the given one. It // returns an error if there is no lease with such hardware address. func (iface *netInterface) updateLease(l *Lease) (prev *Lease, err error) { - i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC) + mk := macToKey(l.HWAddr) + prev, found := iface.leases[mk] if !found { return nil, fmt.Errorf("no lease for mac %s", l.HWAddr) } - prev, iface.leases[i] = iface.leases[i], l + iface.leases[mk] = l return prev, nil } @@ -59,12 +92,13 @@ func (iface *netInterface) updateLease(l *Lease) (prev *Lease, err error) { // removeLease removes an existing lease from iface. It returns an error if // there is no lease equal to l. func (iface *netInterface) removeLease(l *Lease) (err error) { - i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC) + mk := macToKey(l.HWAddr) + _, found := iface.leases[mk] if !found { return fmt.Errorf("no lease for mac %s", l.HWAddr) } - iface.leases = slices.Delete(iface.leases, i, i+1) + delete(iface.leases, mk) return nil } diff --git a/internal/dhcpsvc/lease.go b/internal/dhcpsvc/lease.go index a920a4f2077..a855b7d5c98 100644 --- a/internal/dhcpsvc/lease.go +++ b/internal/dhcpsvc/lease.go @@ -1,7 +1,6 @@ package dhcpsvc import ( - "bytes" "net" "net/netip" "slices" @@ -45,8 +44,3 @@ func (l *Lease) Clone() (clone *Lease) { IsStatic: l.IsStatic, } } - -// compareLeaseMAC compares two [Lease]s by hardware address. -func compareLeaseMAC(a, b *Lease) (res int) { - return bytes.Compare(a.HWAddr, b.HWAddr) -} diff --git a/internal/dhcpsvc/leaseindex.go b/internal/dhcpsvc/leaseindex.go index 855d6b84a51..5502d2cf2b4 100644 --- a/internal/dhcpsvc/leaseindex.go +++ b/internal/dhcpsvc/leaseindex.go @@ -61,7 +61,7 @@ func (idx *leaseIndex) add(l *Lease, iface *netInterface) (err error) { return fmt.Errorf("lease for hostname %s already exists", l.Hostname) } - err = iface.insertLease(l) + err = iface.addLease(l) if err != nil { return err } diff --git a/internal/dhcpsvc/server.go b/internal/dhcpsvc/server.go index 745895ff9c7..c8bab6e603d 100644 --- a/internal/dhcpsvc/server.go +++ b/internal/dhcpsvc/server.go @@ -41,10 +41,10 @@ type DHCPServer struct { leases *leaseIndex // interfaces4 is the set of IPv4 interfaces sorted by interface name. - interfaces4 netInterfacesV4 + interfaces4 dhcpInterfacesV4 // interfaces6 is the set of IPv6 interfaces sorted by interface name. - interfaces6 netInterfacesV6 + interfaces6 dhcpInterfacesV6 // icmpTimeout is the timeout for checking another DHCP server's presence. icmpTimeout time.Duration @@ -63,28 +63,9 @@ func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) { return nil, nil } - // TODO(e.burkov): Add validations scoped to the network interfaces set. - ifaces4 := make(netInterfacesV4, 0, len(conf.Interfaces)) - ifaces6 := make(netInterfacesV6, 0, len(conf.Interfaces)) - var errs []error - - mapsutil.SortedRange(conf.Interfaces, func(name string, iface *InterfaceConfig) (cont bool) { - var i4 *netInterfaceV4 - i4, err = newNetInterfaceV4(ctx, l, name, iface.IPv4) - if err != nil { - errs = append(errs, fmt.Errorf("interface %q: ipv4: %w", name, err)) - } else if i4 != nil { - ifaces4 = append(ifaces4, i4) - } - - i6 := newNetInterfaceV6(ctx, l, name, iface.IPv6) - if i6 != nil { - ifaces6 = append(ifaces6, i6) - } - - return true - }) - if err = errors.Join(errs...); err != nil { + ifaces4, ifaces6, err := newInterfaces(ctx, l, conf.Interfaces) + if err != nil { + // Don't wrap the error since it's informative enough as is. return nil, err } @@ -112,6 +93,43 @@ func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) { return srv, nil } +// newInterfaces creates interfaces for the given map of interface names to +// their configurations. +func newInterfaces( + ctx context.Context, + l *slog.Logger, + ifaces map[string]*InterfaceConfig, +) (v4 dhcpInterfacesV4, v6 dhcpInterfacesV6, err error) { + defer func() { err = errors.Annotate(err, "creating interfaces: %w") }() + + // TODO(e.burkov): Add validations scoped to the network interfaces set. + v4 = make(dhcpInterfacesV4, 0, len(ifaces)) + v6 = make(dhcpInterfacesV6, 0, len(ifaces)) + + var errs []error + mapsutil.SortedRange(ifaces, func(name string, iface *InterfaceConfig) (cont bool) { + var i4 *dhcpInterfaceV4 + i4, err = newDHCPInterfaceV4(ctx, l, name, iface.IPv4) + if err != nil { + errs = append(errs, fmt.Errorf("interface %q: ipv4: %w", name, err)) + } else if i4 != nil { + v4 = append(v4, i4) + } + + i6 := newDHCPInterfaceV6(ctx, l, name, iface.IPv6) + if i6 != nil { + v6 = append(v6, i6) + } + + return true + }) + if err = errors.Join(errs...); err != nil { + return nil, nil, err + } + + return v4, v6, nil +} + // type check // // TODO(e.burkov): Uncomment when the [Interface] interface is implemented. @@ -127,16 +145,11 @@ func (srv *DHCPServer) Leases() (leases []*Lease) { srv.leasesMu.RLock() defer srv.leasesMu.RUnlock() - for _, iface := range srv.interfaces4 { - for _, lease := range iface.leases { - leases = append(leases, lease.Clone()) - } - } - for _, iface := range srv.interfaces6 { - for _, lease := range iface.leases { - leases = append(leases, lease.Clone()) - } - } + srv.leases.rangeLeases(func(l *Lease) (cont bool) { + leases = append(leases, l.Clone()) + + return true + }) return leases } @@ -200,10 +213,10 @@ func (srv *DHCPServer) Reset(ctx context.Context) (err error) { // expects the DHCPServer.leasesMu to be locked. func (srv *DHCPServer) resetLeases() { for _, iface := range srv.interfaces4 { - iface.reset() + iface.common.reset() } for _, iface := range srv.interfaces6 { - iface.reset() + iface.common.reset() } srv.leases.clear() } diff --git a/internal/dhcpsvc/server_test.go b/internal/dhcpsvc/server_test.go index 0166a9b7d61..94509e37c0c 100644 --- a/internal/dhcpsvc/server_test.go +++ b/internal/dhcpsvc/server_test.go @@ -4,6 +4,7 @@ import ( "io/fs" "net/netip" "os" + "path" "path/filepath" "strings" "testing" @@ -19,14 +20,14 @@ import ( var testdata = os.DirFS("testdata") // newTempDB copies the leases database file located in the testdata FS, under -// tb.Name()/leases.db, to a temporary directory and returns the path to the +// tb.Name()/leases.json, to a temporary directory and returns the path to the // copied file. func newTempDB(tb testing.TB) (dst string) { tb.Helper() const filename = "leases.json" - data, err := fs.ReadFile(testdata, filepath.Join(tb.Name(), filename)) + data, err := fs.ReadFile(testdata, path.Join(tb.Name(), filename)) require.NoError(tb, err) dst = filepath.Join(tb.TempDir(), filename) @@ -121,7 +122,7 @@ func TestNew(t *testing.T) { DBFilePath: leasesPath, }, name: "gateway_within_range", - wantErrMsg: `interface "eth0": ipv4: ` + + wantErrMsg: `creating interfaces: interface "eth0": ipv4: ` + `gateway ip 192.168.0.100 in the ip range 192.168.0.1-192.168.0.254`, }, { conf: &dhcpsvc.Config{ @@ -137,7 +138,7 @@ func TestNew(t *testing.T) { DBFilePath: leasesPath, }, name: "bad_start", - wantErrMsg: `interface "eth0": ipv4: ` + + wantErrMsg: `creating interfaces: interface "eth0": ipv4: ` + `range start 127.0.0.1 is not within 192.168.0.1/24`, }} @@ -567,5 +568,5 @@ func TestServer_Leases(t *testing.T) { HWAddr: mustParseMAC(t, "BB:BB:BB:BB:BB:BB"), IsStatic: true, }} - assert.Equal(t, wantLeases, srv.Leases()) + assert.ElementsMatch(t, wantLeases, srv.Leases()) } diff --git a/internal/dhcpsvc/v4.go b/internal/dhcpsvc/v4.go index 106241054a1..b5194a9f90e 100644 --- a/internal/dhcpsvc/v4.go +++ b/internal/dhcpsvc/v4.go @@ -82,8 +82,12 @@ func (c *IPv4Config) validate() (err error) { return errors.Join(errs...) } -// netInterfaceV4 is a DHCP interface for IPv4 address family. -type netInterfaceV4 struct { +// dhcpInterfaceV4 is a DHCP interface for IPv4 address family. +type dhcpInterfaceV4 struct { + // common is the common part of any network interface within the DHCP + // server. + common *netInterface + // gateway is the IP address of the network gateway. gateway netip.Addr @@ -101,21 +105,17 @@ type netInterfaceV4 struct { // explicitOpts are the user-configured options. It must not have // intersections with implicitOpts. explicitOpts layers.DHCPOptions - - // netInterface is embedded here to provide some common network interface - // logic. - netInterface } -// newNetInterfaceV4 creates a new DHCP interface for IPv4 address family with +// newDHCPInterfaceV4 creates a new DHCP interface for IPv4 address family with // the given configuration. It returns an error if the given configuration // can't be used. -func newNetInterfaceV4( +func newDHCPInterfaceV4( ctx context.Context, l *slog.Logger, name string, conf *IPv4Config, -) (i *netInterfaceV4, err error) { +) (i *dhcpInterfaceV4, err error) { l = l.With( keyInterface, name, keyFamily, netutil.AddrFamilyIPv4, @@ -144,35 +144,31 @@ func newNetInterfaceV4( return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace) } - i = &netInterfaceV4{ + i = &dhcpInterfaceV4{ gateway: conf.GatewayIP, subnet: subnet, addrSpace: addrSpace, - netInterface: netInterface{ - name: name, - leaseTTL: conf.LeaseDuration, - logger: l, - }, + common: newNetInterface(name, l, conf.LeaseDuration), } i.implicitOpts, i.explicitOpts = conf.options(ctx, l) return i, nil } -// netInterfacesV4 is a slice of network interfaces of IPv4 address family. -type netInterfacesV4 []*netInterfaceV4 +// dhcpInterfacesV4 is a slice of network interfaces of IPv4 address family. +type dhcpInterfacesV4 []*dhcpInterfaceV4 // find returns the first network interface within ifaces containing ip. It // returns false if there is no such interface. -func (ifaces netInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) { - i := slices.IndexFunc(ifaces, func(iface *netInterfaceV4) (contains bool) { +func (ifaces dhcpInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) { + i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV4) (contains bool) { return iface.subnet.Contains(ip) }) if i < 0 { return nil, false } - return &ifaces[i].netInterface, true + return ifaces[i].common, true } // options returns the implicit and explicit options for the interface. The two diff --git a/internal/dhcpsvc/v6.go b/internal/dhcpsvc/v6.go index a1ee56acd23..dd75184e167 100644 --- a/internal/dhcpsvc/v6.go +++ b/internal/dhcpsvc/v6.go @@ -62,10 +62,12 @@ func (c *IPv6Config) validate() (err error) { return errors.Join(errs...) } -// netInterfaceV6 is a DHCP interface for IPv6 address family. -// -// TODO(e.burkov): Add options. -type netInterfaceV6 struct { +// dhcpInterfaceV6 is a DHCP interface for IPv6 address family. +type dhcpInterfaceV6 struct { + // common is the common part of any network interface within the DHCP + // server. + common *netInterface + // rangeStart is the first IP address in the range. rangeStart netip.Addr @@ -78,10 +80,6 @@ type netInterfaceV6 struct { // intersections with implicitOpts. explicitOpts layers.DHCPv6Options - // netInterface is embedded here to provide some common network interface - // logic. - netInterface - // raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO // flags. raSLAACOnly bool @@ -90,16 +88,16 @@ type netInterfaceV6 struct { raAllowSLAAC bool } -// newNetInterfaceV6 creates a new DHCP interface for IPv6 address family with +// newDHCPInterfaceV6 creates a new DHCP interface for IPv6 address family with // the given configuration. // // TODO(e.burkov): Validate properly. -func newNetInterfaceV6( +func newDHCPInterfaceV6( ctx context.Context, l *slog.Logger, name string, conf *IPv6Config, -) (i *netInterfaceV6) { +) (i *dhcpInterfaceV6) { l = l.With(keyInterface, name, keyFamily, netutil.AddrFamilyIPv6) if !conf.Enabled { l.DebugContext(ctx, "disabled") @@ -107,13 +105,9 @@ func newNetInterfaceV6( return nil } - i = &netInterfaceV6{ - rangeStart: conf.RangeStart, - netInterface: netInterface{ - name: name, - leaseTTL: conf.LeaseDuration, - logger: l, - }, + i = &dhcpInterfaceV6{ + rangeStart: conf.RangeStart, + common: newNetInterface(name, l, conf.LeaseDuration), raSLAACOnly: conf.RASLAACOnly, raAllowSLAAC: conf.RAAllowSLAAC, } @@ -122,12 +116,12 @@ func newNetInterfaceV6( return i } -// netInterfacesV4 is a slice of network interfaces of IPv4 address family. -type netInterfacesV6 []*netInterfaceV6 +// dhcpInterfacesV6 is a slice of network interfaces of IPv6 address family. +type dhcpInterfacesV6 []*dhcpInterfaceV6 // find returns the first network interface within ifaces containing ip. It // returns false if there is no such interface. -func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool) { +func (ifaces dhcpInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool) { // prefLen is the length of prefix to match ip against. // // TODO(e.burkov): DHCPv6 inherits the weird behavior of legacy @@ -136,7 +130,7 @@ func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool // be used instead. const prefLen = netutil.IPv6BitLen - 8 - i := slices.IndexFunc(ifaces, func(iface *netInterfaceV6) (contains bool) { + i := slices.IndexFunc(ifaces, func(iface *dhcpInterfaceV6) (contains bool) { return !ip.Less(iface.rangeStart) && netip.PrefixFrom(iface.rangeStart, prefLen).Contains(ip) }) @@ -144,7 +138,7 @@ func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool return nil, false } - return &ifaces[i].netInterface, true + return ifaces[i].common, true } // options returns the implicit and explicit options for the interface. The two diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index 4ea579453a9..b896a2d16dd 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" @@ -217,7 +218,8 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { } srv := &Server{ - conf: ServerConfig{TLSConfig: tlsConf}, + conf: ServerConfig{TLSConfig: tlsConf}, + logger: slogutil.NewDiscardLogger(), } var ( diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 3caa2b0fa26..a10c403d05e 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -22,6 +22,7 @@ import ( "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/timeutil" @@ -301,6 +302,8 @@ type ServerConfig struct { // UpstreamMode is a enumeration of upstream mode representations. See // [proxy.UpstreamModeType]. +// +// TODO(d.kolyshev): Consider using [proxy.UpstreamMode]. type UpstreamMode string const ( @@ -339,6 +342,10 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { MessageConstructor: s, } + if s.logger != nil { + conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy") + } + if srvConf.EDNSClientSubnet.UseCustom { // TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy. conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice()) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index ac8807ab2bc..9ae6fc69075 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "io" + "log/slog" "net" "net/http" "net/netip" @@ -27,6 +28,7 @@ import ( "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil/sysresolv" "github.com/AdguardTeam/golibs/stringutil" @@ -121,6 +123,12 @@ type Server struct { // access drops disallowed clients. access *accessManager + // logger is used for logging during server routines. + // + // TODO(d.kolyshev): Make it never nil. + // TODO(d.kolyshev): Use this logger. + logger *slog.Logger + // localDomainSuffix is the suffix used to detect internal hosts. It // must be a valid domain name plus dots on each side. localDomainSuffix string @@ -197,6 +205,10 @@ type DNSCreateParams struct { PrivateNets netutil.SubnetSet Anonymizer *aghnet.IPMut EtcHosts *aghnet.HostsContainer + + // Logger is used as a base logger. It must not be nil. + Logger *slog.Logger + LocalDomain string } @@ -233,6 +245,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { stats: p.Stats, queryLog: p.QueryLog, privateNets: p.PrivateNets, + logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"), // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), etcHosts: etcHosts, @@ -719,6 +732,10 @@ func (s *Server) prepareInternalProxy() (err error) { MessageConstructor: s, } + if s.logger != nil { + conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy") + } + err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) if err != nil { return fmt.Errorf("invalid upstream mode: %w", err) diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 9e4942cc98c..c326f8aae9d 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -28,6 +28,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" @@ -99,6 +100,7 @@ func createTestServer( DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -339,7 +341,10 @@ func TestServer_timeout(t *testing.T) { ServePlainDNS: true, } - s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) + s, err := NewServer(DNSCreateParams{ + DNSFilter: createTestDNSFilter(t), + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(t, err) err = s.Prepare(srvConf) @@ -349,7 +354,10 @@ func TestServer_timeout(t *testing.T) { }) t.Run("default", func(t *testing.T) { - s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) + s, err := NewServer(DNSCreateParams{ + DNSFilter: createTestDNSFilter(t), + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(t, err) s.conf.Config.UpstreamMode = UpstreamModeLoadBalance @@ -376,7 +384,9 @@ func TestServer_Prepare_fallbacks(t *testing.T) { ServePlainDNS: true, } - s, err := NewServer(DNSCreateParams{}) + s, err := NewServer(DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(t, err) err = s.Prepare(srvConf) @@ -962,6 +972,7 @@ func TestBlockedCustomIP(t *testing.T) { DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -1127,6 +1138,7 @@ func TestRewrite(t *testing.T) { DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -1256,6 +1268,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { }, }, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), LocalDomain: localDomain, }) require.NoError(t, err) @@ -1341,6 +1354,7 @@ func TestPTRResponseFromHosts(t *testing.T) { DHCPServer: dhcp, DNSFilter: flt, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -1392,24 +1406,29 @@ func TestNewServer(t *testing.T) { in DNSCreateParams wantErrMsg string }{{ - name: "success", - in: DNSCreateParams{}, + name: "success", + in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), + }, wantErrMsg: "", }, { name: "success_local_tld", in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), LocalDomain: "mynet", }, wantErrMsg: "", }, { name: "success_local_domain", in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), LocalDomain: "my.local.net", }, wantErrMsg: "", }, { name: "bad_local_domain", in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), LocalDomain: "!!!", }, wantErrMsg: `local domain: bad domain name "!!!": ` + diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 9e172a32705..57d265f719b 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -57,6 +58,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { }, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -229,6 +231,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { DHCPServer: &testDHCP{}, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 432e0d58c64..5c125def96d 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -159,7 +159,7 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { q := pctx.Req.Question[0] qt := q.Qtype if s.conf.AAAADisabled && qt == dns.TypeAAAA { - _ = proxy.CheckDisabledAAAARequest(pctx, true) + pctx.Res = s.newMsgNODATA(pctx.Req) return resultCodeFinish } diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index dd5d9a7786c..b23cd34f230 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" @@ -430,6 +431,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { dnsFilter: createTestDNSFilter(t), dhcpServer: dhcp, localDomainSuffix: localDomainSuffix, + logger: slogutil.NewDiscardLogger(), } req := &dns.Msg{ @@ -565,6 +567,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { dnsFilter: createTestDNSFilter(t), dhcpServer: testDHCP, localDomainSuffix: tc.suffix, + logger: slogutil.NewDiscardLogger(), } req := &dns.Msg{ diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index 668b885b5cf..8626c18003a 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -202,6 +203,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) { ql := &testQueryLog{} st := &testStats{} srv := &Server{ + logger: slogutil.NewDiscardLogger(), queryLog: ql, stats: st, anonymizer: aghnet.NewIPMut(nil), diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index 6fbe0638983..00e10125b6c 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -150,12 +150,12 @@ func setProxyUpstreamMode( ) (err error) { switch upstreamMode { case UpstreamModeParallel: - conf.UpstreamMode = proxy.UModeParallel + conf.UpstreamMode = proxy.UpstreamModeParallel case UpstreamModeFastestAddr: - conf.UpstreamMode = proxy.UModeFastestAddr + conf.UpstreamMode = proxy.UpstreamModeFastestAddr conf.FastestPingTimeout = fastestTimeout case UpstreamModeLoadBalance: - conf.UpstreamMode = proxy.UModeLoadBalance + conf.UpstreamMode = proxy.UpstreamModeLoadBalance default: return fmt.Errorf("unexpected value %q", upstreamMode) } diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 3d012751d64..f94457d01dd 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -433,7 +433,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request // moment we'll allow setting up TLS in the initial configuration or the // configuration itself will use HTTPS protocol, because the underlying // functions potentially restart the HTTPS server. - err = startMods() + err = startMods(web.logger) if err != nil { Context.firstRun = true copyInstallSettings(config, curConfig) diff --git a/internal/home/dns.go b/internal/home/dns.go index 64dfc1aad46..101495961b3 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -2,6 +2,7 @@ package home import ( "fmt" + "log/slog" "net" "net/netip" "net/url" @@ -43,8 +44,8 @@ func onConfigModified() { // initDNS updates all the fields of the [Context] needed to initialize the DNS // server and initializes it at last. It also must not be called unless -// [config] and [Context] are initialized. -func initDNS() (err error) { +// [config] and [Context] are initialized. l must not be nil. +func initDNS(l *slog.Logger) (err error) { anonymizer := config.anonymizer() statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config) @@ -114,13 +115,16 @@ func initDNS() (err error) { anonymizer, httpRegister, tlsConf, + l, ) } // initDNSServer initializes the [context.dnsServer]. To only use the internal -// proxy, none of the arguments are required, but tlsConf still must not be nil, -// in other cases all the arguments also must not be nil. It also must not be -// called unless [config] and [Context] are initialized. +// proxy, none of the arguments are required, but tlsConf and l still must not +// be nil, in other cases all the arguments also must not be nil. It also must +// not be called unless [config] and [Context] are initialized. +// +// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter. func initDNSServer( filters *filtering.DNSFilter, sts stats.Interface, @@ -129,8 +133,10 @@ func initDNSServer( anonymizer *aghnet.IPMut, httpReg aghhttp.RegisterFunc, tlsConf *tlsConfigSettings, + l *slog.Logger, ) (err error) { Context.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ + Logger: l, DNSFilter: filters, Stats: sts, QueryLog: qlog, diff --git a/internal/home/home.go b/internal/home/home.go index fdbb4b73023..6f83fc164de 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "fmt" "io/fs" + "log/slog" "net/http" "net/netip" "net/url" @@ -90,6 +91,8 @@ func (c *homeContext) getDataDir() string { } // Context - a global context object +// +// TODO(a.garipov): Refactor. var Context homeContext // Main is the entry point @@ -483,7 +486,12 @@ func checkPorts() (err error) { return nil } -func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webAPI, err error) { +func initWeb( + opts options, + clientBuildFS fs.FS, + upd *updater.Updater, + l *slog.Logger, +) (web *webAPI, err error) { var clientFS fs.FS if opts.localFrontend { log.Info("warning: using local frontend files") @@ -525,7 +533,7 @@ func initWeb(opts options, clientBuildFS fs.FS, upd *updater.Updater) (web *webA serveHTTP3: config.DNS.ServeHTTP3, } - web = newWebAPI(webConf) + web = newWebAPI(webConf, l) if web == nil { return nil, fmt.Errorf("initializing web: %w", err) } @@ -548,10 +556,15 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { // Configure config filename. initConfigFilename(opts) + ls := getLogSettings(opts) + // Configure log level and output. - err = configureLogger(opts) + err = configureLogger(ls) fatalOnError(err) + // TODO(a.garipov): Use slog everywhere. + slogLogger := newSlogLogger(ls) + // Print the first message after logger is configured. log.Info(version.Full()) log.Debug("current working directory is %s", Context.workDir) @@ -605,7 +618,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { // TODO(e.burkov): This could be made earlier, probably as the option's // effect. - cmdlineUpdate(opts, upd) + cmdlineUpdate(opts, upd, slogLogger) if !Context.firstRun { // Save the updated config. @@ -633,11 +646,11 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { onConfigModified() } - Context.web, err = initWeb(opts, clientBuildFS, upd) + Context.web, err = initWeb(opts, clientBuildFS, upd, slogLogger) fatalOnError(err) if !Context.firstRun { - err = initDNS() + err = initDNS(slogLogger) fatalOnError(err) Context.tls.start() @@ -698,9 +711,10 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) { return aghnet.NewIPMut(anonFunc) } -// startMods initializes and starts the DNS server after installation. -func startMods() (err error) { - err = initDNS() +// startMods initializes and starts the DNS server after installation. l must +// not be nil. +func startMods(l *slog.Logger) (err error) { + err = initDNS(l) if err != nil { return err } @@ -960,8 +974,8 @@ type jsonError struct { Message string `json:"message"` } -// cmdlineUpdate updates current application and exits. -func cmdlineUpdate(opts options, upd *updater.Updater) { +// cmdlineUpdate updates current application and exits. l must not be nil. +func cmdlineUpdate(opts options, upd *updater.Updater, l *slog.Logger) { if !opts.performUpdate { return } @@ -971,7 +985,7 @@ func cmdlineUpdate(opts options, upd *updater.Updater) { // // TODO(e.burkov): We could probably initialize the internal resolver // separately. - err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}) + err := initDNSServer(nil, nil, nil, nil, nil, nil, &tlsConfigSettings{}, l) fatalOnError(err) log.Info("cmdline update: performing update") diff --git a/internal/home/log.go b/internal/home/log.go index fd18d1ec022..0b3a14a8105 100644 --- a/internal/home/log.go +++ b/internal/home/log.go @@ -3,11 +3,13 @@ package home import ( "cmp" "fmt" + "log/slog" "path/filepath" "runtime" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "gopkg.in/natefinch/lumberjack.v2" "gopkg.in/yaml.v3" ) @@ -16,10 +18,21 @@ import ( // for logger output. const configSyslog = "syslog" -// configureLogger configures logger level and output. -func configureLogger(opts options) (err error) { - ls := getLogSettings(opts) +// newSlogLogger returns new [*slog.Logger] configured with the given settings. +func newSlogLogger(ls *logSettings) (l *slog.Logger) { + if !ls.Enabled { + return slogutil.NewDiscardLogger() + } + return slogutil.New(&slogutil.Config{ + Format: slogutil.FormatAdGuardLegacy, + AddTimestamp: true, + Verbose: ls.Verbose, + }) +} + +// configureLogger configures logger level and output. +func configureLogger(ls *logSettings) (err error) { // Configure logger level. if !ls.Enabled { log.SetLevel(log.OFF) @@ -60,7 +73,7 @@ func configureLogger(opts options) (err error) { MaxAge: ls.MaxAge, }) - return nil + return err } // getLogSettings returns a log settings object properly initialized from opts. diff --git a/internal/home/web.go b/internal/home/web.go index 3c403e4bd1f..d3d1fc41595 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "io/fs" + "log/slog" "net/http" "net/netip" "runtime" @@ -90,17 +91,22 @@ type webAPI struct { // TODO(a.garipov): Refactor all these servers. httpServer *http.Server + // logger is a slog logger used in webAPI. It must not be nil. + logger *slog.Logger + // httpsServer is the server that handles HTTPS traffic. If it is not nil, // [Web.http3Server] must also not be nil. httpsServer httpsServer } -// newWebAPI creates a new instance of the web UI and API server. -func newWebAPI(conf *webConfig) (w *webAPI) { +// newWebAPI creates a new instance of the web UI and API server. l must not be +// nil. +func newWebAPI(conf *webConfig, l *slog.Logger) (w *webAPI) { log.Info("web: initializing") w = &webAPI{ - conf: conf, + conf: conf, + logger: l, } clientFS := http.FileServer(http.FS(conf.clientFS))