Skip to content

Commit

Permalink
[management] Refactor policy to use store methods (#2878)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga authored Nov 26, 2024
1 parent ca12bc6 commit f118d81
Show file tree
Hide file tree
Showing 18 changed files with 548 additions and 321 deletions.
2 changes: 1 addition & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
Expand Down
57 changes: 26 additions & 31 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return
}

policy := Policy{
ID: "policy",
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Expand All @@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
})
require.NoError(t, err)

updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
Expand Down Expand Up @@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)

policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
Expand All @@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
}()

if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
t.Errorf("delete default rule: %v", err)
return
}
Expand All @@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}

policy := Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Expand All @@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}

if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
})
if err != nil {
t.Errorf("save policy: %v", err)
return
}
Expand Down Expand Up @@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {

require.NoError(t, err, "failed to save group")

policy := Policy{
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}

policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Expand All @@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}

if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}

if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
})
if err != nil {
t.Errorf("save policy: %v", err)
return
}
Expand Down
5 changes: 2 additions & 3 deletions management/server/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})

// adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Expand All @@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}, false)
})
assert.NoError(t, err)

// Saving a group linked to policy should update account peers and send peer update
Expand Down
26 changes: 13 additions & 13 deletions management/server/http/policies_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ import (
"strconv"

"github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"

"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
Expand Down Expand Up @@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}

isUpdate := policyID != ""

if policyID == "" {
policyID = xid.New().String()
}

policy := server.Policy{
policy := &server.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
ruleID = *rule.Id
}

pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
ID: ruleID,
PolicyID: policyID,
Name: rule.Name,
Destinations: rule.Destinations,
Sources: rule.Sources,
Expand Down Expand Up @@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks
}

if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
Expand All @@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}

resp := toPolicyResponse(allGroups, &policy)
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
Expand Down
4 changes: 2 additions & 2 deletions management/server/http/policies_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
return policy, nil
},
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
Expand Down
8 changes: 4 additions & 4 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
Expand Down Expand Up @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}

// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
return am.SavePolicyFunc(ctx, accountID, userID, policy)
}
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}

// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
Expand Down
31 changes: 15 additions & 16 deletions management/server/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
var (
group1 nbgroup.Group
group2 nbgroup.Group
policy Policy
)

group1.ID = xid.New().String()
group2.ID = xid.New().String()
group1.Name = "src"
group2.Name = "dst"
policy.ID = xid.New().String()
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)

Expand All @@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}

policy.Name = "test"
policy.Enabled = true
policy.Rules = []*PolicyRule{
{
Enabled: true,
Sources: []string{group1.ID},
Destinations: []string{group2.ID},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
policy := &Policy{
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{group1.ID},
Destinations: []string{group2.ID},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
Expand Down Expand Up @@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}

policy.Enabled = false
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
Expand Down Expand Up @@ -1445,8 +1445,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {

// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Expand All @@ -1457,7 +1456,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
}, false)
})
require.NoError(t, err)

done := make(chan struct{})
Expand Down
Loading

0 comments on commit f118d81

Please sign in to comment.