Skip to content

Commit

Permalink
[management] Refactor posture check to use store methods (#2874)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga authored Nov 25, 2024
1 parent 9810386 commit ca12bc6
Show file tree
Hide file tree
Showing 18 changed files with 563 additions and 234 deletions.
2 changes: 1 addition & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager
Expand Down
2 changes: 1 addition & 1 deletion management/server/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}

if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, accountID)
}

Expand Down
19 changes: 17 additions & 2 deletions management/server/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,27 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil
}

// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}

// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
if err != nil {
return false, err
}

for _, group := range groups {
if group.HasPeers() {
return true, nil
}
}

return false, nil
}
3 changes: 2 additions & 1 deletion management/server/http/posture_checks_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}

if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
Expand Down
6 changes: 3 additions & 3 deletions management/server/http/posture_checks_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return p, nil
},
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks

if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
}

return nil
return postureChecks, nil
},
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]
Expand Down
6 changes: 3 additions & 3 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
Expand Down Expand Up @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}

// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
}
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}

// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
Expand Down
10 changes: 5 additions & 5 deletions management/server/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return nil, err
}

if anyGroupHasPeers(account, newNSGroup.Groups) {
if am.anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
Expand Down Expand Up @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}

if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
Expand Down Expand Up @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}

if anyGroupHasPeers(account, nsGroup.Groups) {
if am.anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
Expand Down Expand Up @@ -279,9 +279,9 @@ func validateDomain(domain string) error {
}

// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
}
25 changes: 21 additions & 4 deletions management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}

postureChecks := am.getPeerPostureChecks(account, newPeer)
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}

customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
Expand Down Expand Up @@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}
postureChecks = am.getPeerPostureChecks(account, peer)

postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}

customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
Expand Down Expand Up @@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
if err != nil {
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)

postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}

customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
Expand Down Expand Up @@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()

postureChecks := am.getPeerPostureChecks(account, p)
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
return
}

remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
Expand Down
6 changes: 3 additions & 3 deletions management/server/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po

am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())

if anyGroupHasPeers(account, policy.ruleGroups()) {
if am.anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, accountID)
}

Expand Down Expand Up @@ -469,15 +469,15 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil
}
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups())

return updateAccountPeers, nil
}

// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)

return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}

func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
Expand Down
6 changes: 0 additions & 6 deletions management/server/posture/checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"regexp"

"github.com/hashicorp/go-version"
"github.com/rs/xid"

"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
Expand Down Expand Up @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
}

func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}

postureChecks := Checks{
ID: postureChecksID,
Name: name,
Expand Down
Loading

0 comments on commit ca12bc6

Please sign in to comment.