Skip to content

Commit

Permalink
fix(store): allow dynamic ORDER BY options, improve resource search, …
Browse files Browse the repository at this point in the history
…handle case where empty ORDER BY was appended to query (#152)

* fix(store): allow dynamic ORDER BY options

* fix(store): improve resource search, handle case where empty ORDER BY was appended to query

* chore: resolve comments

* test(store): update tests

* chore: refactor code

* test(store): update tests

* test(store): update tests

* chore: resolve comments
  • Loading branch information
abhishekv24 authored May 28, 2024
1 parent 5aeae47 commit 6e243b2
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 45 deletions.
29 changes: 21 additions & 8 deletions internal/store/postgres/appeal_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ func (r *AppealRepository) Find(ctx context.Context, filters *domain.ListAppeals
}

db := r.db.WithContext(ctx)
db = applyAppealFilter(db, filters)
var err error
db, err = applyAppealFilter(db, filters)
if err != nil {
return nil, err
}

var models []*model.Appeal
if err := db.Joins("Grant").Find(&models).Error; err != nil {
Expand All @@ -98,10 +102,13 @@ func (r *AppealRepository) GetAppealsTotalCount(ctx context.Context, filter *dom
appealFilters.Size = 0
appealFilters.Offset = 0

db = applyAppealFilter(db, &appealFilters)

var err error
db, err = applyAppealFilter(db, &appealFilters)
if err != nil {
return 0, err
}
var count int64
err := db.Model(&model.Appeal{}).Count(&count).Error
err = db.Model(&model.Appeal{}).Count(&count).Error

return count, err
}
Expand Down Expand Up @@ -165,7 +172,7 @@ func (r *AppealRepository) Update(ctx context.Context, a *domain.Appeal) error {
})
}

func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) *gorm.DB {
func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) (*gorm.DB, error) {
db = db.Joins("JOIN resources ON appeals.resource_id = resources.id")
if filters.Q != "" {
// NOTE: avoid adding conditions before this grouped where clause.
Expand Down Expand Up @@ -222,10 +229,16 @@ func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) *gorm.DB
db = db.Where(`"options" -> 'expiration_date' > ?`, filters.ExpirationDateGreaterThan)
}
if filters.OrderBy != nil {
db = addOrderByClause(db, filters.OrderBy, addOrderByClauseOptions{
var err error
db, err = addOrderByClause(db, filters.OrderBy, addOrderByClauseOptions{
statusColumnName: `"appeals"."status"`,
statusesOrder: AppealStatusDefaultSort,
})
},
[]string{"updated_at", "created_at"})

if err != nil {
return nil, err
}
}

db = db.Joins("Resource")
Expand All @@ -242,5 +255,5 @@ func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) *gorm.DB
db = db.Where(`"Resource"."urn" IN ?`, filters.ResourceURNs)
}

return db
return db, nil
}
27 changes: 20 additions & 7 deletions internal/store/postgres/approval_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ func (r *ApprovalRepository) ListApprovals(ctx context.Context, filter *domain.L
records := []*domain.Approval{}

db := r.db.WithContext(ctx)
db = applyFilter(db, filter)
var err error
db, err = applyFilter(db, filter)
if err != nil {
return nil, err
}
if filter.Size > 0 {
db = db.Limit(filter.Size)
}
Expand Down Expand Up @@ -68,8 +72,11 @@ func (r *ApprovalRepository) GetApprovalsTotalCount(ctx context.Context, filter
f := *filter
f.Size = 0
f.Offset = 0
db = applyFilter(db, &f)

var err error
db, err = applyFilter(db, &f)
if err != nil {
return 0, err
}
var count int64
if err := db.Model(&model.Approval{}).Count(&count).Error; err != nil {
return 0, err
Expand Down Expand Up @@ -139,7 +146,7 @@ func (r *ApprovalRepository) DeleteApprover(ctx context.Context, approvalID, ema
return nil
}

func applyFilter(db *gorm.DB, filter *domain.ListApprovalsFilter) *gorm.DB {
func applyFilter(db *gorm.DB, filter *domain.ListApprovalsFilter) (*gorm.DB, error) {
db = db.Joins("Appeal").
Joins("Appeal.Resource").
Joins(`JOIN "approvers" ON "approvals"."id" = "approvers"."approval_id"`)
Expand Down Expand Up @@ -176,11 +183,17 @@ func applyFilter(db *gorm.DB, filter *domain.ListApprovalsFilter) *gorm.DB {
}

if filter.OrderBy != nil {
db = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{
var err error
db, err = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{
statusColumnName: `"approvals"."status"`,
statusesOrder: AppealStatusDefaultSort,
})
},
[]string{"updated_at", "created_at"})

if err != nil {
return nil, err
}
}

return db
return db, nil
}
28 changes: 21 additions & 7 deletions internal/store/postgres/grant_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ func NewGrantRepository(db *gorm.DB) *GrantRepository {

func (r *GrantRepository) List(ctx context.Context, filter domain.ListGrantsFilter) ([]domain.Grant, error) {
db := r.db.WithContext(ctx)
db = applyGrantFilter(db, filter)
var err error
db, err = applyGrantFilter(db, filter)
if err != nil {
return nil, err
}

var models []model.Grant
if err := db.Joins("Resource").Joins("Appeal").Find(&models).Error; err != nil {
Expand All @@ -56,9 +60,13 @@ func (r *GrantRepository) GetGrantsTotalCount(ctx context.Context, filter domain
grantFilters.Size = 0
grantFilters.Offset = 0

db = applyGrantFilter(db, grantFilters)
var err error
db, err = applyGrantFilter(db, grantFilters)
if err != nil {
return 0, err
}
var count int64
err := db.Model(&model.Grant{}).Count(&count).Error
err = db.Model(&model.Grant{}).Count(&count).Error

return count, err
}
Expand Down Expand Up @@ -212,7 +220,7 @@ func upsertResources(tx *gorm.DB, models []*model.Grant) error {
return nil
}

func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) *gorm.DB {
func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) (*gorm.DB, error) {
db = db.Joins("JOIN resources ON grants.resource_id = resources.id")
if filter.Q != "" {
// NOTE: avoid adding conditions before this grouped where clause.
Expand Down Expand Up @@ -262,10 +270,16 @@ func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) *gorm.DB {
db = db.Where(`"grants"."created_at" <= ?`, filter.CreatedAtLte)
}
if filter.OrderBy != nil {
db = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{
var err error
db, err = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{
statusColumnName: `"grants"."status"`,
statusesOrder: GrantStatusDefaultSort,
})
},
[]string{"updated_at", "created_at"})

if err != nil {
return nil, err
}
}
if !filter.ExpirationDateLessThan.IsZero() {
db = db.Where(`"grants"."expiration_date" < ?`, filter.ExpirationDateLessThan)
Expand All @@ -286,5 +300,5 @@ func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) *gorm.DB {
if filter.ResourceURNs != nil {
db = db.Where(`"Resource"."urn" IN ?`, filter.ResourceURNs)
}
return db
return db, nil
}
29 changes: 22 additions & 7 deletions internal/store/postgres/resource_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ func (r *ResourceRepository) Find(ctx context.Context, filter domain.ListResourc
}

db := r.db.WithContext(ctx)
db = applyResourceFilter(db, filter)
var err error
db, err = applyResourceFilter(db, filter)
if err != nil {
return nil, err
}
var models []*model.Resource
if err := db.Find(&models).Error; err != nil {
return nil, err
Expand All @@ -61,14 +65,18 @@ func (r *ResourceRepository) GetResourcesTotalCount(ctx context.Context, filter
f := filter
f.Size = 0
f.Offset = 0
db = applyResourceFilter(db, f)
var err error
db, err = applyResourceFilter(db, f)
if err != nil {
return 0, err
}
var count int64
err := db.Model(&model.Resource{}).Count(&count).Error
err = db.Model(&model.Resource{}).Count(&count).Error

return count, err
}

func applyResourceFilter(db *gorm.DB, filter domain.ListResourcesFilter) *gorm.DB {
func applyResourceFilter(db *gorm.DB, filter domain.ListResourcesFilter) (*gorm.DB, error) {
if filter.Q != "" {
// NOTE: avoid adding conditions before this grouped where clause.
// Otherwise, it will be wrapped in parentheses and the query will be invalid.
Expand Down Expand Up @@ -125,17 +133,24 @@ func applyResourceFilter(db *gorm.DB, filter domain.ListResourcesFilter) *gorm.D
}

if len(sortOrder) != 0 {
db = addOrderByClause(db, sortOrder, addOrderByClauseOptions{
var err error
db, err = addOrderByClause(db, sortOrder, addOrderByClauseOptions{
statusColumnName: "",
statusesOrder: []string{},
})
searchQuery: filter.Q,
},
[]string{"updated_at", "created_at", "name", "urn", "global_urn"})

if err != nil {
return nil, err
}
}

for path, v := range filter.Details {
pathArr := "{" + strings.Join(strings.Split(path, "."), ",") + "}"
db = db.Where(`"details" #>> ? = ?`, pathArr, v)
}
return db
return db, nil
}

// GetOne record by ID
Expand Down
70 changes: 65 additions & 5 deletions internal/store/postgres/resource_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,24 @@ func (s *ResourceRepositoryTestSuite) TestFind() {
CreatedAt: time.Now().Add(10 * time.Minute),
GlobalURN: "global_urn_2",
},
{
ProviderType: s.dummyProvider.Type,
ProviderURN: s.dummyProvider.URN,
Type: "test_type_2",
URN: "test_exact_urn_match",
Name: "test_exact_name_match",
CreatedAt: time.Now().Add(15 * time.Minute),
GlobalURN: "global_urn_4",
},
{
ProviderType: s.dummyProvider.Type,
ProviderURN: s.dummyProvider.URN,
Type: "test_type_2",
URN: "test_exact_urn_match_2",
Name: "test_exact_name_match_2",
CreatedAt: time.Now().Add(20 * time.Minute),
GlobalURN: "global_urn_3",
},
}
err := s.repository.BulkUpsert(context.Background(), dummyResources)
s.Require().NoError(err)
Expand Down Expand Up @@ -130,7 +148,7 @@ func (s *ResourceRepositoryTestSuite) TestFind() {
filters: domain.ListResourcesFilter{
ResourceType: "test_type",
},
expectedResult: dummyResources,
expectedResult: []*domain.Resource{dummyResources[0], dummyResources[1]},
},
{
name: "filter by name",
Expand Down Expand Up @@ -175,29 +193,43 @@ func (s *ResourceRepositoryTestSuite) TestFind() {
Size: 1,
Offset: 0,
},
expectedResult: []*domain.Resource{dummyResources[1]},
expectedResult: []*domain.Resource{dummyResources[3]},
},
{
name: "filter by size and offset 1",
filters: domain.ListResourcesFilter{
Size: 1,
Offset: 1,
},
expectedResult: []*domain.Resource{dummyResources[0]},
expectedResult: []*domain.Resource{dummyResources[2]},
},
{
name: "filter by size only",
filters: domain.ListResourcesFilter{
Size: 1,
},
expectedResult: []*domain.Resource{dummyResources[1]},
expectedResult: []*domain.Resource{dummyResources[3]},
},
{
name: "Order by created at desc",
filters: domain.ListResourcesFilter{
OrderBy: []string{"created_at:desc"},
},
expectedResult: []*domain.Resource{dummyResources[1], dummyResources[0]},
expectedResult: []*domain.Resource{dummyResources[3], dummyResources[2], dummyResources[1], dummyResources[0]},
},
{
name: "filter by urns",
filters: domain.ListResourcesFilter{
ResourceURNs: []string{"test_urn_1", "test_urn_2"},
},
expectedResult: []*domain.Resource{dummyResources[0], dummyResources[1]},
},
{
name: "filter by resource types",
filters: domain.ListResourcesFilter{
ResourceTypes: []string{"test_type"},
},
expectedResult: []*domain.Resource{dummyResources[0], dummyResources[1]},
},
}

Expand All @@ -217,6 +249,28 @@ func (s *ResourceRepositoryTestSuite) TestFind() {
}
})

s.Run("should return exact name matching resource on top", func() {
exact_match_filter := domain.ListResourcesFilter{
Q: "test_exact_name_match",
OrderBy: []string{"name:exact_asc"},
}

actualResult, actualError := s.repository.Find(context.Background(), exact_match_filter)
s.NoError(actualError)
s.Equal("test_exact_name_match", actualResult[0].Name)
})

s.Run("should return error when invalid order by direction is passed", func() {
exact_match_filter := domain.ListResourcesFilter{
Q: "test_exact_name_match",
OrderBy: []string{"name:test"},
}

actualResult, actualError := s.repository.Find(context.Background(), exact_match_filter)
s.Error(actualError)
s.Nil(actualResult)
})

s.Run("should return error if filters validation returns an error", func() {
invalidFilters := domain.ListResourcesFilter{
IDs: []string{},
Expand Down Expand Up @@ -277,6 +331,12 @@ func (s *ResourceRepositoryTestSuite) TestGetResourcesTotalCount() {

s.Nil(actualError)
})

s.Run("should return error", func() {
_, actualError := s.repository.GetResourcesTotalCount(context.Background(), domain.ListResourcesFilter{OrderBy: []string{"name:test"}})

s.Error(actualError)
})
}

func (s *ResourceRepositoryTestSuite) TestBulkUpsert() {
Expand Down
Loading

0 comments on commit 6e243b2

Please sign in to comment.