Skip to content

Commit

Permalink
[client] Persist route selection (#2810)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal authored Dec 2, 2024
1 parent ecb44ff commit 5142dc5
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 40 deletions.
10 changes: 10 additions & 0 deletions client/firewall/iptables/rulestore_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
return err
}
s.ips = temp.IPs

if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}

return nil
}

Expand Down Expand Up @@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
return err
}
s.ipsets = temp.IPSets

if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}

return nil
}
13 changes: 11 additions & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,17 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer

e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
e.routeManager = routemanager.NewManager(
e.ctx,
e.config.WgPrivateKey.PublicKey().String(),
e.config.DNSRouteInterval,
e.wgInterface,
e.statusRecorder,
e.relayManager,
initialRoutes,
e.stateManager,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
Expand Down
5 changes: 4 additions & 1 deletion client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
nil)

wgIface := &iface.MockWGIface{
NameFunc: func() string { return "utun102" },
RemovePeerFunc: func(peerKey string) error {
return nil
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
Expand Down
38 changes: 34 additions & 4 deletions client/internal/routemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (

// Manager is a route manager interface
type Manager interface {
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
Expand All @@ -59,6 +59,7 @@ type DefaultManager struct {
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
stateManager *statemanager.Manager
}

func NewManager(
Expand All @@ -69,6 +70,7 @@ func NewManager(
statusRecorder *peer.Status,
relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
stateManager *statemanager.Manager,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier()
Expand All @@ -80,12 +82,12 @@ func NewManager(
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: notifier,
stateManager: stateManager,
}

dm.routeRefCounter = refcounter.New(
Expand Down Expand Up @@ -121,7 +123,7 @@ func NewManager(
}

// Init sets up the routing
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
Expand All @@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook

ips := resolveURLsToIPs(initialAddresses)

beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}

m.routeSelector = m.initSelector()

log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil
}

func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
var state *SelectorState
m.stateManager.RegisterState(state)

// restore selector state if it exists
if err := m.stateManager.LoadState(state); err != nil {
log.Warnf("failed to load state: %v", err)
return routeselector.NewRouteSelector()
}

if state := m.stateManager.GetState(state); state != nil {
if selector, ok := state.(*SelectorState); ok {
return (*routeselector.RouteSelector)(selector)
}

log.Warnf("failed to convert state with type %T to SelectorState", state)
}

return routeselector.NewRouteSelector()
}

func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
Expand Down Expand Up @@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}

if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}

// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
Expand Down
4 changes: 2 additions & 2 deletions client/internal/routemanager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {

statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)

_, _, err = routeManager.Init(nil)
_, _, err = routeManager.Init()

require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil)
Expand Down
2 changes: 1 addition & 1 deletion client/internal/routemanager/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager)
}

func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil
}

Expand Down
13 changes: 13 additions & 0 deletions client/internal/routemanager/refcounter/refcounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
}

// LoadData loads the data from the existing counter
// The passed counter should not be used any longer after calling this function.
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.mu.Lock()
defer rm.mu.Unlock()
existingCounter.mu.Lock()
defer existingCounter.mu.Unlock()

rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
Expand Down Expand Up @@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {

// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.mu.Lock()
defer rm.mu.Unlock()

var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
Expand All @@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap

if temp.RefCountMap == nil {
temp.RefCountMap = map[Key]Ref[O]{}
}
if temp.IDMap == nil {
temp.IDMap = map[string][]Key{}
}

return nil
}

Expand Down
19 changes: 19 additions & 0 deletions client/internal/routemanager/state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package routemanager

import (
"github.com/netbirdio/netbird/client/internal/routeselector"
)

type SelectorState routeselector.RouteSelector

func (s *SelectorState) Name() string {
return "routeselector_state"
}

func (s *SelectorState) MarshalJSON() ([]byte, error) {
return (*routeselector.RouteSelector)(s).MarshalJSON()
}

func (s *SelectorState) UnmarshalJSON(data []byte) error {
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
}
67 changes: 67 additions & 0 deletions client/internal/routeselector/routeselector.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package routeselector

import (
"encoding/json"
"fmt"
"slices"
"sync"

"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
Expand All @@ -12,6 +14,7 @@ import (
)

type RouteSelector struct {
mu sync.RWMutex
selectedRoutes map[route.NetID]struct{}
selectAll bool
}
Expand All @@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector {

// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()

if !appendRoute {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
Expand All @@ -46,13 +52,19 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al

// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()

rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{}
}

// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()

if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{}
Expand All @@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.

// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()

rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{}
}

// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()

if rs.selectAll {
return true
}
Expand All @@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {

// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
defer rs.mu.RUnlock()

if rs.selectAll {
return maps.Clone(routes)
}
Expand All @@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
return filtered
}

// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
defer rs.mu.RUnlock()

return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}{
SelectAll: rs.selectAll,
SelectedRoutes: rs.selectedRoutes,
})
}

// UnmarshalJSON implements the json.Unmarshaler interface
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.mu.Lock()
defer rs.mu.Unlock()

// Check for null or empty JSON
if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true
return nil
}

var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}

if err := json.Unmarshal(data, &temp); err != nil {
return err
}

rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll

if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}

return nil
}
Loading

0 comments on commit 5142dc5

Please sign in to comment.