Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/filters-list-per-client'
Browse files Browse the repository at this point in the history
  • Loading branch information
hoang-rio committed Jul 2, 2024
2 parents 45c386e + a8af4ca commit 952bd7a
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 306 deletions.
80 changes: 35 additions & 45 deletions internal/client/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
}
}

// Index stores all information about persistent clients.
type Index struct {
// index stores all information about persistent clients.
type index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID

Expand All @@ -51,9 +51,9 @@ type Index struct {
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}

// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
// newIndex initializes the new instance of client index.
func newIndex() (ci *index) {
return &index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
Expand All @@ -63,9 +63,9 @@ func NewIndex() (ci *Index) {
}
}

// Add stores information about a persistent client in the index. c must be
// add stores information about a persistent client in the index. c must be
// non-nil, have a UID, and contain at least one identifier.
func (ci *Index) Add(c *Persistent) {
func (ci *index) add(c *Persistent) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
Expand All @@ -92,9 +92,9 @@ func (ci *Index) Add(c *Persistent) {
ci.uidToClient[c.UID] = c
}

// ClashesUID returns existing persistent client with the same UID as c. Note
// clashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (err error) {
func (ci *index) clashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
Expand All @@ -103,9 +103,9 @@ func (ci *Index) ClashesUID(c *Persistent) (err error) {
return nil
}

// Clashes returns an error if the index contains a different persistent client
// clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) {
func (ci *index) clashes(c *Persistent) (err error) {
if p := ci.clashesName(c); p != nil {
return fmt.Errorf("another client uses the same name %q", p.Name)
}
Expand Down Expand Up @@ -139,8 +139,8 @@ func (ci *Index) Clashes(c *Persistent) (err error) {

// clashesName returns existing persistent client with the same name as c or
// nil. c must be non-nil.
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.FindByName(c.Name)
func (ci *index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.findByName(c.Name)
if !ok {
return nil
}
Expand All @@ -154,7 +154,7 @@ func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {

// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
func (ci *index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
Expand All @@ -167,7 +167,7 @@ func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {

// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
func (ci *index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
Expand All @@ -193,7 +193,7 @@ func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {

// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
func (ci *index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
Expand All @@ -205,9 +205,9 @@ func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr)
return nil, nil
}

// Find finds persistent client by string representation of the client ID, IP
// find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
func (ci *index) find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
Expand All @@ -224,14 +224,14 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {

mac, err := net.ParseMAC(id)
if err == nil {
return ci.FindByMAC(mac)
return ci.findByMAC(mac)
}

return nil, false
}

// FindByName finds persistent client by name.
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
// findByName finds persistent client by name.
func (ci *index) findByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
if found {
return ci.uidToClient[uid], true
Expand All @@ -241,7 +241,7 @@ func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
}

// findByIP finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
func (ci *index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
Expand All @@ -266,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false
}

// FindByMAC finds persistent client by MAC.
func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
// findByMAC finds persistent client by MAC.
func (ci *index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
Expand All @@ -277,13 +277,13 @@ func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
return nil, false
}

// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// findByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
func (ci *index) findByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
Expand All @@ -297,9 +297,9 @@ func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
return nil
}

// Delete removes information about persistent client from the index. c must be
// remove removes information about persistent client from the index. c must be
// non-nil.
func (ci *Index) Delete(c *Persistent) {
func (ci *index) remove(c *Persistent) {
delete(ci.nameToUID, c.Name)

for _, id := range c.ClientIDs {
Expand All @@ -322,24 +322,14 @@ func (ci *Index) Delete(c *Persistent) {
delete(ci.uidToClient, c.UID)
}

// Size returns the number of persistent clients.
func (ci *Index) Size() (n int) {
// size returns the number of persistent clients.
func (ci *index) size() (n int) {
return len(ci.uidToClient)
}

// Range calls f for each persistent client, unless cont is false. The order is
// undefined.
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
for _, c := range ci.uidToClient {
if !f(c) {
return
}
}
}

// RangeByName is like [Index.Range] but sorts the persistent clients by name
// rangeByName is like [Index.Range] but sorts the persistent clients by name
// before iterating ensuring a predictable order.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
func (ci *index) rangeByName(f func(c *Persistent) (cont bool)) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
Expand All @@ -352,10 +342,10 @@ func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
}
}

// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
// closeUpstreams closes upstream configurations of persistent clients.
func (ci *index) closeUpstreams() (err error) {
var errs []error
ci.RangeByName(func(c *Persistent) (cont bool) {
ci.rangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)
Expand Down
20 changes: 10 additions & 10 deletions internal/client/index_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import (

// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*Persistent) (ci *Index) {
ci = NewIndex()
func newIDIndex(m []*Persistent) (ci *index) {
ci = newIndex()

for _, c := range m {
c.UID = MustNewUID()
ci.Add(c)
ci.add(c)
}

return ci
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestClientIndex_Find(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.Find(id)
c, ok := ci.find(id)
require.True(t, ok)

assert.Equal(t, tc.want, c)
Expand All @@ -119,7 +119,7 @@ func TestClientIndex_Find(t *testing.T) {
}

t.Run("not_found", func(t *testing.T) {
_, ok := ci.Find(cliIPNone)
_, ok := ci.find(cliIPNone)
assert.False(t, ok)
})
}
Expand Down Expand Up @@ -171,11 +171,11 @@ func TestClientIndex_Clashes(t *testing.T) {
clone := tc.client.ShallowClone()
clone.UID = MustNewUID()

err := ci.Clashes(clone)
err := ci.clashes(clone)
require.Error(t, err)

ci.Delete(tc.client)
err = ci.Clashes(clone)
ci.remove(tc.client)
err = ci.clashes(clone)
require.NoError(t, err)
})
}
Expand Down Expand Up @@ -293,7 +293,7 @@ func TestIndex_FindByIPWithoutZone(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
c := ci.findByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
Expand Down Expand Up @@ -339,7 +339,7 @@ func TestClientIndex_RangeByName(t *testing.T) {
ci := newIDIndex(tc.want)

var got []*Persistent
ci.RangeByName(func(c *Persistent) (cont bool) {
ci.rangeByName(func(c *Persistent) (cont bool) {
got = append(got, c)

return true
Expand Down
71 changes: 45 additions & 26 deletions internal/client/persistent.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type Persistent struct {
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig

// SafeSearch handles search engine hosts rewrites.
SafeSearch filtering.SafeSearch

// BlockedServices is the configuration of blocked services of a client. It
Expand All @@ -74,32 +75,65 @@ type Persistent struct {
// Name of the persistent client. Must not be empty.
Name string

Tags []string
Filters []filtering.FilterYAML
WhitelistFilters []filtering.FilterYAML
UserRules []string
Upstreams []string
// Tags is a list of client tags that categorize the client.
Tags []string

// Upstreams is a list of custom upstream DNS servers for the client.
Upstreams []string

// IPs is a list of IP addresses that identify the client. The client must
// have at least one ID (IP, subnet, MAC, or ClientID).
IPs []netip.Addr

// Subnets identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
//
// TODO(s.chzhen): Use netutil.Prefix.
Subnets []netip.Prefix
MACs []net.HardwareAddr
Subnets []netip.Prefix

// MACs identifying the client. The client must have at least one ID (IP,
// subnet, MAC, or ClientID).
MACs []net.HardwareAddr

// ClientIDs identifying the client. The client must have at least one ID
// (IP, subnet, MAC, or ClientID).
ClientIDs []string

// UID is the unique identifier of the persistent client.
UID UID

UpstreamsCacheSize uint32
// UpstreamsCacheSize is the cache size for custom upstreams.
UpstreamsCacheSize uint32

// UpstreamsCacheEnabled specifies whether custom upstreams are used.
UpstreamsCacheEnabled bool

UseOwnSettings bool
FilteringEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
// UseOwnSettings specifies whether custom filtering settings are used.
UseOwnSettings bool

// FilteringEnabled specifies whether filtering is enabled.
FilteringEnabled bool

// SafeBrowsingEnabled specifies whether safe browsing is enabled.
SafeBrowsingEnabled bool

// ParentalEnabled specifies whether parental control is enabled.
ParentalEnabled bool

// UseOwnBlockedServices specifies whether custom services are blocked.
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool

// IgnoreQueryLog specifies whether the client requests are logged.
IgnoreQueryLog bool

// IgnoreStatistics specifies whether the client requests are counted.
IgnoreStatistics bool

// SafeSearchConf is the safe search filtering configuration.
//
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
UseGlobalFilters bool
Expand Down Expand Up @@ -138,21 +172,6 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
return nil
}

// SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
for _, t := range tags {
if !known.Has(t) {
log.Info("skipping unknown tag %q", t)

continue
}

c.Tags = append(c.Tags, t)
}

slices.Sort(c.Tags)
}

// SetFilters sets with allow and blocked
func (c *Persistent) SetFilters(allowFilters, blockFilters []filtering.FilterYAML) {
c.WhitelistFilters = append(c.WhitelistFilters, allowFilters...)
Expand Down
Loading

0 comments on commit 952bd7a

Please sign in to comment.