diff --git a/internal/client/index.go b/internal/client/index.go index 8cdbad13937..cad6deb39c9 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -30,8 +30,8 @@ func macToKey(mac net.HardwareAddr) (key macKey) { } } -// Index stores all information about persistent clients. -type Index struct { +// index stores all information about persistent clients. +type index struct { // nameToUID maps client name to UID. nameToUID map[string]UID @@ -51,9 +51,9 @@ type Index struct { subnetToUID aghalg.SortedMap[netip.Prefix, UID] } -// NewIndex initializes the new instance of client index. -func NewIndex() (ci *Index) { - return &Index{ +// newIndex initializes the new instance of client index. +func newIndex() (ci *index) { + return &index{ nameToUID: map[string]UID{}, clientIDToUID: map[string]UID{}, ipToUID: map[netip.Addr]UID{}, @@ -63,9 +63,9 @@ func NewIndex() (ci *Index) { } } -// Add stores information about a persistent client in the index. c must be +// add stores information about a persistent client in the index. c must be // non-nil, have a UID, and contain at least one identifier. -func (ci *Index) Add(c *Persistent) { +func (ci *index) add(c *Persistent) { if (c.UID == UID{}) { panic("client must contain uid") } @@ -92,9 +92,9 @@ func (ci *Index) Add(c *Persistent) { ci.uidToClient[c.UID] = c } -// ClashesUID returns existing persistent client with the same UID as c. Note +// clashesUID returns existing persistent client with the same UID as c. Note // that this is only possible when configuration contains duplicate fields. -func (ci *Index) ClashesUID(c *Persistent) (err error) { +func (ci *index) clashesUID(c *Persistent) (err error) { p, ok := ci.uidToClient[c.UID] if ok { return fmt.Errorf("another client %q uses the same uid", p.Name) @@ -103,9 +103,9 @@ func (ci *Index) ClashesUID(c *Persistent) (err error) { return nil } -// Clashes returns an error if the index contains a different persistent client +// clashes returns an error if the index contains a different persistent client // with at least a single identifier contained by c. c must be non-nil. -func (ci *Index) Clashes(c *Persistent) (err error) { +func (ci *index) clashes(c *Persistent) (err error) { if p := ci.clashesName(c); p != nil { return fmt.Errorf("another client uses the same name %q", p.Name) } @@ -139,8 +139,8 @@ func (ci *Index) Clashes(c *Persistent) (err error) { // clashesName returns existing persistent client with the same name as c or // nil. c must be non-nil. -func (ci *Index) clashesName(c *Persistent) (existing *Persistent) { - existing, ok := ci.FindByName(c.Name) +func (ci *index) clashesName(c *Persistent) (existing *Persistent) { + existing, ok := ci.findByName(c.Name) if !ok { return nil } @@ -154,7 +154,7 @@ func (ci *Index) clashesName(c *Persistent) (existing *Persistent) { // clashesIP returns a previous client with the same IP address as c. c must be // non-nil. -func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { +func (ci *index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { for _, ip := range c.IPs { existing, ok := ci.ipToUID[ip] if ok && existing != c.UID { @@ -167,7 +167,7 @@ func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { // clashesSubnet returns a previous client with the same subnet as c. c must be // non-nil. -func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) { +func (ci *index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) { for _, s = range c.Subnets { var existing UID var ok bool @@ -193,7 +193,7 @@ func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) { // clashesMAC returns a previous client with the same MAC address as c. c must // be non-nil. -func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) { +func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) { for _, mac = range c.MACs { k := macToKey(mac) existing, ok := ci.macToUID[k] @@ -205,9 +205,9 @@ func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) return nil, nil } -// Find finds persistent client by string representation of the client ID, IP +// find finds persistent client by string representation of the client ID, IP // address, or MAC. -func (ci *Index) Find(id string) (c *Persistent, ok bool) { +func (ci *index) find(id string) (c *Persistent, ok bool) { uid, found := ci.clientIDToUID[id] if found { return ci.uidToClient[uid], true @@ -224,14 +224,14 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) { mac, err := net.ParseMAC(id) if err == nil { - return ci.FindByMAC(mac) + return ci.findByMAC(mac) } return nil, false } -// FindByName finds persistent client by name. -func (ci *Index) FindByName(name string) (c *Persistent, found bool) { +// findByName finds persistent client by name. +func (ci *index) findByName(name string) (c *Persistent, found bool) { uid, found := ci.nameToUID[name] if found { return ci.uidToClient[uid], true @@ -241,7 +241,7 @@ func (ci *Index) FindByName(name string) (c *Persistent, found bool) { } // findByIP finds persistent client by IP address. -func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { +func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) { uid, found := ci.ipToUID[ip] if found { return ci.uidToClient[uid], true @@ -266,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { return nil, false } -// FindByMAC finds persistent client by MAC. -func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { +// findByMAC finds persistent client by MAC. +func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { k := macToKey(mac) uid, found := ci.macToUID[k] if found { @@ -277,13 +277,13 @@ func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { return nil, false } -// FindByIPWithoutZone finds a persistent client by IP address without zone. It +// findByIPWithoutZone finds a persistent client by IP address without zone. It // strips the IPv6 zone index from the stored IP addresses before comparing, // because querylog entries don't have it. See TODO on [querylog.logEntry.IP]. // // Note that multiple clients can have the same IP address with different zones. // Therefore, the result of this method is indeterminate. -func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) { +func (ci *index) findByIPWithoutZone(ip netip.Addr) (c *Persistent) { if (ip == netip.Addr{}) { return nil } @@ -297,9 +297,9 @@ func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) { return nil } -// Delete removes information about persistent client from the index. c must be +// remove removes information about persistent client from the index. c must be // non-nil. -func (ci *Index) Delete(c *Persistent) { +func (ci *index) remove(c *Persistent) { delete(ci.nameToUID, c.Name) for _, id := range c.ClientIDs { @@ -322,24 +322,14 @@ func (ci *Index) Delete(c *Persistent) { delete(ci.uidToClient, c.UID) } -// Size returns the number of persistent clients. -func (ci *Index) Size() (n int) { +// size returns the number of persistent clients. +func (ci *index) size() (n int) { return len(ci.uidToClient) } -// Range calls f for each persistent client, unless cont is false. The order is -// undefined. -func (ci *Index) Range(f func(c *Persistent) (cont bool)) { - for _, c := range ci.uidToClient { - if !f(c) { - return - } - } -} - -// RangeByName is like [Index.Range] but sorts the persistent clients by name +// rangeByName is like [Index.Range] but sorts the persistent clients by name // before iterating ensuring a predictable order. -func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) { +func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) { cs := maps.Values(ci.uidToClient) slices.SortFunc(cs, func(a, b *Persistent) (n int) { return strings.Compare(a.Name, b.Name) @@ -352,10 +342,10 @@ func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) { } } -// CloseUpstreams closes upstream configurations of persistent clients. -func (ci *Index) CloseUpstreams() (err error) { +// closeUpstreams closes upstream configurations of persistent clients. +func (ci *index) closeUpstreams() (err error) { var errs []error - ci.RangeByName(func(c *Persistent) (cont bool) { + ci.rangeByName(func(c *Persistent) (cont bool) { err = c.CloseUpstreams() if err != nil { errs = append(errs, err) diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index f51f461cec7..f514b995b02 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -11,12 +11,12 @@ import ( // newIDIndex is a helper function that returns a client index filled with // persistent clients from the m. It also generates a UID for each client. -func newIDIndex(m []*Persistent) (ci *Index) { - ci = NewIndex() +func newIDIndex(m []*Persistent) (ci *index) { + ci = newIndex() for _, c := range m { c.UID = MustNewUID() - ci.Add(c) + ci.add(c) } return ci @@ -110,7 +110,7 @@ func TestClientIndex_Find(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { for _, id := range tc.ids { - c, ok := ci.Find(id) + c, ok := ci.find(id) require.True(t, ok) assert.Equal(t, tc.want, c) @@ -119,7 +119,7 @@ func TestClientIndex_Find(t *testing.T) { } t.Run("not_found", func(t *testing.T) { - _, ok := ci.Find(cliIPNone) + _, ok := ci.find(cliIPNone) assert.False(t, ok) }) } @@ -171,11 +171,11 @@ func TestClientIndex_Clashes(t *testing.T) { clone := tc.client.ShallowClone() clone.UID = MustNewUID() - err := ci.Clashes(clone) + err := ci.clashes(clone) require.Error(t, err) - ci.Delete(tc.client) - err = ci.Clashes(clone) + ci.remove(tc.client) + err = ci.clashes(clone) require.NoError(t, err) }) } @@ -293,7 +293,7 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := ci.FindByIPWithoutZone(tc.ip.WithZone("")) + c := ci.findByIPWithoutZone(tc.ip.WithZone("")) require.Equal(t, tc.want, c) }) } @@ -339,7 +339,7 @@ func TestClientIndex_RangeByName(t *testing.T) { ci := newIDIndex(tc.want) var got []*Persistent - ci.RangeByName(func(c *Persistent) (cont bool) { + ci.rangeByName(func(c *Persistent) (cont bool) { got = append(got, c) return true diff --git a/internal/client/persistent.go b/internal/client/persistent.go index b573b0fe540..ce68986f826 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -65,6 +65,7 @@ type Persistent struct { // upstream must be used. UpstreamConfig *proxy.CustomUpstreamConfig + // SafeSearch handles search engine hosts rewrites. SafeSearch filtering.SafeSearch // BlockedServices is the configuration of blocked services of a client. It @@ -74,29 +75,62 @@ type Persistent struct { // Name of the persistent client. Must not be empty. Name string - Tags []string + // Tags is a list of client tags that categorize the client. + Tags []string + + // Upstreams is a list of custom upstream DNS servers for the client. Upstreams []string + // IPs is a list of IP addresses that identify the client. The client must + // have at least one ID (IP, subnet, MAC, or ClientID). IPs []netip.Addr + + // Subnets identifying the client. The client must have at least one ID + // (IP, subnet, MAC, or ClientID). + // // TODO(s.chzhen): Use netutil.Prefix. - Subnets []netip.Prefix - MACs []net.HardwareAddr + Subnets []netip.Prefix + + // MACs identifying the client. The client must have at least one ID (IP, + // subnet, MAC, or ClientID). + MACs []net.HardwareAddr + + // ClientIDs identifying the client. The client must have at least one ID + // (IP, subnet, MAC, or ClientID). ClientIDs []string // UID is the unique identifier of the persistent client. UID UID - UpstreamsCacheSize uint32 + // UpstreamsCacheSize is the cache size for custom upstreams. + UpstreamsCacheSize uint32 + + // UpstreamsCacheEnabled specifies whether custom upstreams are used. UpstreamsCacheEnabled bool - UseOwnSettings bool - FilteringEnabled bool - SafeBrowsingEnabled bool - ParentalEnabled bool + // UseOwnSettings specifies whether custom filtering settings are used. + UseOwnSettings bool + + // FilteringEnabled specifies whether filtering is enabled. + FilteringEnabled bool + + // SafeBrowsingEnabled specifies whether safe browsing is enabled. + SafeBrowsingEnabled bool + + // ParentalEnabled specifies whether parental control is enabled. + ParentalEnabled bool + + // UseOwnBlockedServices specifies whether custom services are blocked. UseOwnBlockedServices bool - IgnoreQueryLog bool - IgnoreStatistics bool + // IgnoreQueryLog specifies whether the client requests are logged. + IgnoreQueryLog bool + + // IgnoreStatistics specifies whether the client requests are counted. + IgnoreStatistics bool + + // SafeSearchConf is the safe search filtering configuration. + // // TODO(d.kolyshev): Make SafeSearchConf a pointer. SafeSearchConf filtering.SafeSearchConfig } @@ -134,21 +168,6 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { return nil } -// SetTags sets the tags if they are known, otherwise logs an unknown tag. -func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) { - for _, t := range tags { - if !known.Has(t) { - log.Info("skipping unknown tag %q", t) - - continue - } - - c.Tags = append(c.Tags, t) - } - - slices.Sort(c.Tags) -} - // SetIDs parses a list of strings into typed fields and returns an error if // there is one. func (c *Persistent) SetIDs(ids []string) (err error) { diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index 89190285184..a96c3778626 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -4,6 +4,7 @@ import ( "net/netip" "testing" + "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -126,13 +127,19 @@ func TestPersistent_EqualIDs(t *testing.T) { } func TestPersistent_Validate(t *testing.T) { - // TODO(s.chzhen): Add test cases. + const ( + allowedTag = "allowed_tag" + notAllowedTag = "not_allowed_tag" + ) + + allowedTags := container.NewMapSet(allowedTag) + testCases := []struct { name string cli *Persistent wantErrMsg string }{{ - name: "basic", + name: "success", cli: &Persistent{ Name: "basic", IPs: []netip.Addr{ @@ -162,11 +169,24 @@ func TestPersistent_Validate(t *testing.T) { }, }, wantErrMsg: "uid required", + }, { + name: "not_allowed_tag", + cli: &Persistent{ + Name: "basic", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + UID: MustNewUID(), + Tags: []string{ + notAllowedTag, + }, + }, + wantErrMsg: `invalid tag: "` + notAllowedTag + `"`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := tc.cli.validate(nil) + err := tc.cli.validate(allowedTags) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } diff --git a/internal/client/storage.go b/internal/client/storage.go index d9abc529596..2053bdf9613 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -11,6 +11,14 @@ import ( "github.com/AdguardTeam/golibs/log" ) +// Config is the client storage configuration structure. +// +// TODO(s.chzhen): Expand. +type Config struct { + // AllowedTags is a list of all allowed client tags. + AllowedTags []string +} + // Storage contains information about persistent and runtime clients. type Storage struct { // allowedTags is a set of all allowed tags. @@ -20,18 +28,22 @@ type Storage struct { mu *sync.Mutex // index contains information about persistent clients. - index *Index + index *index // runtimeIndex contains information about runtime clients. + // + // TODO(s.chzhen): Use it. runtimeIndex *RuntimeIndex } -// NewStorage returns initialized client storage. -func NewStorage(allowedTags *container.MapSet[string]) (s *Storage) { +// NewStorage returns initialized client storage. conf must not be nil. +func NewStorage(conf *Config) (s *Storage) { + allowedTags := container.NewMapSet(conf.AllowedTags...) + return &Storage{ allowedTags: allowedTags, mu: &sync.Mutex{}, - index: NewIndex(), + index: newIndex(), runtimeIndex: NewRuntimeIndex(), } } @@ -49,40 +61,45 @@ func (s *Storage) Add(p *Persistent) (err error) { s.mu.Lock() defer s.mu.Unlock() - err = s.index.ClashesUID(p) + err = s.index.clashesUID(p) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err } - err = s.index.Clashes(p) + err = s.index.clashes(p) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err } - s.index.Add(p) + s.index.add(p) - log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.Size()) + log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.size()) return nil } -// FindByName finds persistent client by name. -func (s *Storage) FindByName(name string) (c *Persistent, found bool) { +// FindByName finds persistent client by name. And returns its shallow copy. +func (s *Storage) FindByName(name string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() - return s.index.FindByName(name) + p, ok = s.index.findByName(name) + if ok { + return p.ShallowClone(), ok + } + + return nil, false } // Find finds persistent client by string representation of the client ID, IP -// address, or MAC. And returns it shallow copy. +// address, or MAC. And returns its shallow copy. func (s *Storage) Find(id string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() - p, ok = s.index.Find(id) + p, ok = s.index.find(id) if ok { return p.ShallowClone(), ok } @@ -101,12 +118,12 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() - p, ok = s.index.Find(id) + p, ok = s.index.find(id) if ok { return p.ShallowClone(), ok } - p = s.index.FindByIPWithoutZone(ip) + p = s.index.findByIPWithoutZone(ip) if p != nil { return p.ShallowClone(), true } @@ -114,12 +131,17 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { return nil, false } -// FindByMAC finds persistent client by MAC. -func (s *Storage) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { +// FindByMAC finds persistent client by MAC and returns its shallow copy. +func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() - return s.index.FindByMAC(mac) + p, ok = s.index.findByMAC(mac) + if ok { + return p.ShallowClone(), ok + } + + return nil, false } // RemoveByName removes persistent client information. ok is false if no such @@ -128,7 +150,7 @@ func (s *Storage) RemoveByName(name string) (ok bool) { s.mu.Lock() defer s.mu.Unlock() - p, ok := s.index.FindByName(name) + p, ok := s.index.findByName(name) if !ok { return false } @@ -137,7 +159,7 @@ func (s *Storage) RemoveByName(name string) (ok bool) { log.Error("client storage: removing client %q: %s", p.Name, err) } - s.index.Delete(p) + s.index.remove(p) return true } @@ -156,7 +178,7 @@ func (s *Storage) Update(name string, p *Persistent) (err error) { s.mu.Lock() defer s.mu.Unlock() - stored, ok := s.index.FindByName(name) + stored, ok := s.index.findByName(name) if !ok { return fmt.Errorf("client %q is not found", name) } @@ -166,14 +188,14 @@ func (s *Storage) Update(name string, p *Persistent) (err error) { // TODO(s.chzhen): Remove when frontend starts handling UIDs. p.UID = stored.UID - err = s.index.Clashes(p) + err = s.index.clashes(p) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err } - s.index.Delete(stored) - s.index.Add(p) + s.index.remove(stored) + s.index.add(p) return nil } @@ -184,7 +206,7 @@ func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) { s.mu.Lock() defer s.mu.Unlock() - s.index.RangeByName(f) + s.index.rangeByName(f) } // Size returns the number of persistent clients. @@ -192,7 +214,7 @@ func (s *Storage) Size() (n int) { s.mu.Lock() defer s.mu.Unlock() - return s.index.Size() + return s.index.size() } // CloseUpstreams closes upstream configurations of persistent clients. @@ -200,11 +222,13 @@ func (s *Storage) CloseUpstreams() (err error) { s.mu.Lock() defer s.mu.Unlock() - return s.index.CloseUpstreams() + return s.index.closeUpstreams() } // ClientRuntime returns a copy of the saved runtime client by ip. If no such // client exists, returns nil. +// +// TODO(s.chzhen): Use it. func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { s.mu.Lock() defer s.mu.Unlock() @@ -214,6 +238,8 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { // AddRuntime saves the runtime client information in the storage. IP address // of a client must be unique. rc must not be nil. +// +// TODO(s.chzhen): Use it. func (s *Storage) AddRuntime(rc *Runtime) { s.mu.Lock() defer s.mu.Unlock() @@ -222,6 +248,8 @@ func (s *Storage) AddRuntime(rc *Runtime) { } // SizeRuntime returns the number of the runtime clients. +// +// TODO(s.chzhen): Use it. func (s *Storage) SizeRuntime() (n int) { s.mu.Lock() defer s.mu.Unlock() @@ -230,6 +258,8 @@ func (s *Storage) SizeRuntime() (n int) { } // RangeRuntime calls f for each runtime client in an undefined order. +// +// TODO(s.chzhen): Use it. func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.mu.Lock() defer s.mu.Unlock() @@ -238,6 +268,8 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { } // DeleteRuntime removes the runtime client by ip. +// +// TODO(s.chzhen): Use it. func (s *Storage) DeleteRuntime(ip netip.Addr) { s.mu.Lock() defer s.mu.Unlock() @@ -247,6 +279,8 @@ func (s *Storage) DeleteRuntime(ip netip.Addr) { // DeleteBySource removes all runtime clients that have information only from // the specified source and returns the number of removed clients. +// +// TODO(s.chzhen): Use it. func (s *Storage) DeleteBySource(src Source) (n int) { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index fef021085a8..abfc6d6287d 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -16,7 +16,9 @@ import ( func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { tb.Helper() - s = client.NewStorage(nil) + s = client.NewStorage(&client.Config{ + AllowedTags: nil, + }) for _, c := range m { c.UID = client.MustNewUID() @@ -57,7 +59,9 @@ func TestStorage_Add(t *testing.T) { UID: existingClientUID, } - s := client.NewStorage(nil) + s := client.NewStorage(&client.Config{ + AllowedTags: nil, + }) err := s.Add(existingClient) require.NoError(t, err) @@ -137,7 +141,9 @@ func TestStorage_RemoveByName(t *testing.T) { UID: client.MustNewUID(), } - s := client.NewStorage(nil) + s := client.NewStorage(&client.Config{ + AllowedTags: nil, + }) err := s.Add(existingClient) require.NoError(t, err) @@ -162,7 +168,9 @@ func TestStorage_RemoveByName(t *testing.T) { } t.Run("duplicate_remove", func(t *testing.T) { - s = client.NewStorage(nil) + s = client.NewStorage(&client.Config{ + AllowedTags: nil, + }) err = s.Add(existingClient) require.NoError(t, err) diff --git a/internal/home/clients.go b/internal/home/clients.go index 9d39451dae8..aee32f9253d 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -19,7 +19,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" @@ -45,14 +44,12 @@ type DHCP interface { // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { - // clientIndex stores information about persistent clients. - clientIndex *client.Index + // storage stores information about persistent clients. + storage *client.Storage // runtimeIndex stores information about runtime clients. runtimeIndex *client.RuntimeIndex - allTags *container.MapSet[string] - // dhcp is the DHCP service implementation. dhcp DHCP @@ -104,15 +101,15 @@ func (clients *clientsContainer) Init( filteringConf *filtering.Config, ) (err error) { // TODO(s.chzhen): Refactor it. - if clients.clientIndex != nil { + if clients.storage != nil { return errors.Error("clients container already initialized") } clients.runtimeIndex = client.NewRuntimeIndex() - clients.clientIndex = client.NewIndex() - - clients.allTags = container.NewMapSet(clientTags...) + clients.storage = client.NewStorage(&client.Config{ + AllowedTags: clientTags, + }) // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. clients.dhcp = dhcpServer @@ -217,7 +214,6 @@ type clientObject struct { // toPersistent returns an initialized persistent client if there are no errors. func (o *clientObject) toPersistent( filteringConf *filtering.Config, - allTags *container.MapSet[string], ) (cli *client.Persistent, err error) { cli = &client.Persistent{ Name: o.Name, @@ -274,7 +270,7 @@ func (o *clientObject) toPersistent( cli.BlockedServices = o.BlockedServices.Clone() - cli.SetTags(o.Tags, allTags) + cli.Tags = slices.Clone(o.Tags) return cli, nil } @@ -287,22 +283,14 @@ func (clients *clientsContainer) addFromConfig( ) (err error) { for i, o := range objects { var cli *client.Persistent - cli, err = o.toPersistent(filteringConf, clients.allTags) + cli, err = o.toPersistent(filteringConf) if err != nil { return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) } - // TODO(s.chzhen): Consider moving to the client index constructor. - err = clients.clientIndex.ClashesUID(cli) + err = clients.storage.Add(cli) if err != nil { - return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err) - } - - err = clients.add(cli) - if err != nil { - // TODO(s.chzhen): Return an error instead of logging if more - // stringent requirements are implemented. - log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err) + return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err) } } @@ -315,8 +303,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { clients.lock.Lock() defer clients.lock.Unlock() - objs = make([]*clientObject, 0, clients.clientIndex.Size()) - clients.clientIndex.RangeByName(func(cli *client.Persistent) (cont bool) { + objs = make([]*clientObject, 0, clients.storage.Size()) + clients.storage.RangeByName(func(cli *client.Persistent) (cont bool) { objs = append(objs, &clientObject{ Name: cli.Name, @@ -360,7 +348,7 @@ func (clients *clientsContainer) periodicUpdate() { // clientSource checks if client with this IP address already exists and returns // the source which updated it last. It returns [client.SourceNone] if the -// client doesn't exist. +// client doesn't exist. Note that it is only used in tests. func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) { clients.lock.Lock() defer clients.lock.Unlock() @@ -419,12 +407,8 @@ func (clients *clientsContainer) clientOrArtificial( } }() - cli, ok := clients.find(id) - if !ok { - cli = clients.clientIndex.FindByIPWithoutZone(ip) - } - - if cli != nil { + cli, ok := clients.storage.FindLoose(ip, id) + if ok { return &querylog.Client{ Name: cli.Name, IgnoreQueryLog: cli.IgnoreQueryLog, @@ -456,7 +440,7 @@ func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) return nil, false } - return c.ShallowClone(), true + return c, true } // shouldCountClient is a wrapper around [clientsContainer.find] to make it a @@ -530,7 +514,7 @@ func (clients *clientsContainer) UpstreamConfigByID( // findLocked searches for a client by its ID. clients.lock is expected to be // locked. func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) { - c, ok = clients.clientIndex.Find(id) + c, ok = clients.storage.Find(id) if ok { return c, true } @@ -552,7 +536,7 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, return nil, false } - return clients.clientIndex.FindByMAC(foundMAC) + return clients.storage.FindByMAC(foundMAC) } // runtimeClient returns a runtime client from internal index. Note that it @@ -586,114 +570,6 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru return rc } -// check validates the client. It also sorts the client tags. -func (clients *clientsContainer) check(c *client.Persistent) (err error) { - switch { - case c == nil: - return errors.Error("client is nil") - case c.Name == "": - return errors.Error("invalid name") - case c.IDsLen() == 0: - return errors.Error("id required") - default: - // Go on. - } - - for _, t := range c.Tags { - if !clients.allTags.Has(t) { - return fmt.Errorf("invalid tag: %q", t) - } - } - - // TODO(s.chzhen): Move to the constructor. - slices.Sort(c.Tags) - - _, err = proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{}) - if err != nil { - return fmt.Errorf("invalid upstream servers: %w", err) - } - - return nil -} - -// add adds a persistent client or returns an error. -func (clients *clientsContainer) add(c *client.Persistent) (err error) { - err = clients.check(c) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - err = clients.clientIndex.Clashes(c) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err - } - - clients.addLocked(c) - - log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size()) - - return nil -} - -// addLocked c to the indexes. clients.lock is expected to be locked. -func (clients *clientsContainer) addLocked(c *client.Persistent) { - clients.clientIndex.Add(c) -} - -// remove removes a client. ok is false if there is no such client. -func (clients *clientsContainer) remove(name string) (ok bool) { - clients.lock.Lock() - defer clients.lock.Unlock() - - c, ok := clients.clientIndex.FindByName(name) - if !ok { - return false - } - - clients.removeLocked(c) - - return true -} - -// removeLocked removes c from the indexes. clients.lock is expected to be -// locked. -func (clients *clientsContainer) removeLocked(c *client.Persistent) { - if err := c.CloseUpstreams(); err != nil { - log.Error("client container: removing client %s: %s", c.Name, err) - } - - // Update the ID index. - clients.clientIndex.Delete(c) -} - -// update updates a client by its name. -func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) { - err = clients.check(c) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - err = clients.clientIndex.Clashes(c) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err - } - - clients.removeLocked(prev) - clients.addLocked(c) - - return nil -} - // setWHOISInfo sets the WHOIS information for a client. clients.lock is // expected to be locked. func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { @@ -855,5 +731,5 @@ func (clients *clientsContainer) addFromSystemARP() { // close gracefully closes all the client-specific upstream configurations of // the persistent clients. func (clients *clientsContainer) close() (err error) { - return clients.clientIndex.CloseUpstreams() + return clients.storage.CloseUpstreams() } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index d371df7b6ff..2c90a1e0d29 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -72,7 +72,7 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli1IP, cliIPv6}, } - err := clients.add(c) + err := clients.storage.Add(c) require.NoError(t, err) c = &client.Persistent{ @@ -81,7 +81,7 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli2IP}, } - err = clients.add(c) + err = clients.storage.Add(c) require.NoError(t, err) c, ok := clients.find(cli1) @@ -107,7 +107,7 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_name", func(t *testing.T) { - err := clients.add(&client.Persistent{ + err := clients.storage.Add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, @@ -116,7 +116,7 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_ip", func(t *testing.T) { - err := clients.add(&client.Persistent{ + err := clients.storage.Add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), }) @@ -124,7 +124,7 @@ func TestClients(t *testing.T) { }) t.Run("update_fail_ip", func(t *testing.T) { - err := clients.update(&client.Persistent{Name: "client1"}, &client.Persistent{ + err := clients.storage.Update("client1", &client.Persistent{ Name: "client1", UID: client.MustNewUID(), }) @@ -139,11 +139,11 @@ func TestClients(t *testing.T) { cliNewIP = netip.MustParseAddr(cliNew) ) - prev, ok := clients.clientIndex.FindByName("client1") + prev, ok := clients.storage.FindByName("client1") require.True(t, ok) require.NotNil(t, prev) - err := clients.update(prev, &client.Persistent{ + err := clients.storage.Update("client1", &client.Persistent{ Name: "client1", UID: prev.UID, IPs: []netip.Addr{cliNewIP}, @@ -155,11 +155,11 @@ func TestClients(t *testing.T) { assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) - prev, ok = clients.clientIndex.FindByName("client1") + prev, ok = clients.storage.FindByName("client1") require.True(t, ok) require.NotNil(t, prev) - err = clients.update(prev, &client.Persistent{ + err = clients.storage.Update("client1", &client.Persistent{ Name: "client1-renamed", UID: prev.UID, IPs: []netip.Addr{cliNewIP}, @@ -173,7 +173,7 @@ func TestClients(t *testing.T) { assert.Equal(t, "client1-renamed", c.Name) assert.True(t, c.UseOwnSettings) - nilCli, ok := clients.clientIndex.FindByName("client1") + nilCli, ok := clients.storage.FindByName("client1") require.False(t, ok) assert.Nil(t, nilCli) @@ -184,7 +184,7 @@ func TestClients(t *testing.T) { }) t.Run("del_success", func(t *testing.T) { - ok := clients.remove("client1-renamed") + ok := clients.storage.RemoveByName("client1-renamed") require.True(t, ok) _, ok = clients.find("1.1.1.2") @@ -192,7 +192,7 @@ func TestClients(t *testing.T) { }) t.Run("del_fail", func(t *testing.T) { - ok := clients.remove("client3") + ok := clients.storage.RemoveByName("client3") assert.False(t, ok) }) @@ -261,7 +261,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") - err := clients.add(&client.Persistent{ + err := clients.storage.Add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, @@ -272,7 +272,7 @@ func TestClientsWHOIS(t *testing.T) { rc := clients.runtimeIndex.Client(ip) require.Nil(t, rc) - assert.True(t, clients.remove("client1")) + assert.True(t, clients.storage.RemoveByName("client1")) }) } @@ -283,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - err := clients.add(&client.Persistent{ + err := clients.storage.Add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -333,7 +333,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - err = clients.add(&client.Persistent{ + err = clients.storage.Add(&client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, @@ -341,7 +341,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the IP from the first client's IP range. - err = clients.add(&client.Persistent{ + err = clients.storage.Add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, @@ -354,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) // Add client with upstreams. - err := clients.add(&client.Persistent{ + err := clients.storage.Add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index fbec5c23714..a8b318353fb 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -96,7 +96,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http clients.lock.Lock() defer clients.lock.Unlock() - clients.clientIndex.Range(func(c *client.Persistent) (cont bool) { + clients.storage.RangeByName(func(c *client.Persistent) (cont bool) { cj := clientToJSON(c) data.Clients = append(data.Clients, cj) @@ -336,7 +336,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - err = clients.add(c) + err = clients.storage.Add(c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -364,7 +364,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - if !clients.remove(cj.Name) { + if !clients.storage.RemoveByName(cj.Name) { aghhttp.Error(r, w, http.StatusBadRequest, "Client not found") return @@ -399,30 +399,14 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - var prev *client.Persistent - var ok bool - - func() { - clients.lock.Lock() - defer clients.lock.Unlock() - - prev, ok = clients.clientIndex.FindByName(dj.Name) - }() - - if !ok { - aghhttp.Error(r, w, http.StatusBadRequest, "client not found") - - return - } - - c, err := clients.jsonToClient(dj.Data, prev) + c, err := clients.jsonToClient(dj.Data, nil) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - err = clients.update(prev, c) + err = clients.storage.Update(dj.Name, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go index aa2f40fb15d..7c1f3dfaf86 100644 --- a/internal/home/clientshttp_internal_test.go +++ b/internal/home/clientshttp_internal_test.go @@ -198,11 +198,11 @@ func TestClientsContainer_HandleDelClient(t *testing.T) { clients := newClientsContainer(t) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.add(clientOne) + err := clients.storage.Add(clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) - err = clients.add(clientTwo) + err = clients.storage.Add(clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) @@ -260,7 +260,7 @@ func TestClientsContainer_HandleUpdateClient(t *testing.T) { clients := newClientsContainer(t) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.add(clientOne) + err := clients.storage.Add(clientOne) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne}) @@ -342,11 +342,11 @@ func TestClientsContainer_HandleFindClient(t *testing.T) { } clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.add(clientOne) + err := clients.storage.Add(clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) - err = clients.add(clientTwo) + err = clients.storage.Add(clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 8413e2a33fa..4adaec81db2 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -13,17 +13,21 @@ import ( var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) -// newIDIndex is a helper function that returns a client index filled with -// persistent clients from the m. It also generates a UID for each client. -func newIDIndex(m []*client.Persistent) (ci *client.Index) { - ci = client.NewIndex() - - for _, c := range m { - c.UID = client.MustNewUID() - ci.Add(c) +// newStorage is a helper function that returns a client storage filled with +// persistent clients. It also generates a UID for each client. +func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { + tb.Helper() + + s = client.NewStorage(&client.Config{ + AllowedTags: nil, + }) + + for _, p := range clients { + p.UID = client.MustNewUID() + require.NoError(tb, s.Add(p)) } - return ci + return s } func TestApplyAdditionalFiltering(t *testing.T) { @@ -36,7 +40,8 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = newIDIndex([]*client.Persistent{{ + Context.clients.storage = newStorage(t, []*client.Persistent{{ + Name: "default", ClientIDs: []string{"default"}, UseOwnSettings: false, SafeSearchConf: filtering.SafeSearchConfig{Enabled: false}, @@ -44,6 +49,7 @@ func TestApplyAdditionalFiltering(t *testing.T) { SafeBrowsingEnabled: false, ParentalEnabled: false, }, { + Name: "custom_filtering", ClientIDs: []string{"custom_filtering"}, UseOwnSettings: true, SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, @@ -51,6 +57,7 @@ func TestApplyAdditionalFiltering(t *testing.T) { SafeBrowsingEnabled: true, ParentalEnabled: true, }, { + Name: "partial_custom_filtering", ClientIDs: []string{"partial_custom_filtering"}, UseOwnSettings: true, SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, @@ -121,16 +128,19 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = newIDIndex([]*client.Persistent{{ + Context.clients.storage = newStorage(t, []*client.Persistent{{ + Name: "default", ClientIDs: []string{"default"}, UseOwnBlockedServices: false, }, { + Name: "no_services", ClientIDs: []string{"no_services"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), }, UseOwnBlockedServices: true, }, { + Name: "services", ClientIDs: []string{"services"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), @@ -138,6 +148,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, UseOwnBlockedServices: true, }, { + Name: "invalid_services", ClientIDs: []string{"invalid_services"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.EmptyWeekly(), @@ -145,6 +156,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, UseOwnBlockedServices: true, }, { + Name: "allow_all", ClientIDs: []string{"allow_all"}, BlockedServices: &filtering.BlockedServices{ Schedule: schedule.FullWeekly(),