diff --git a/internal/store/postgres/appeal_repository.go b/internal/store/postgres/appeal_repository.go index e5ed11eed..2721e63a9 100644 --- a/internal/store/postgres/appeal_repository.go +++ b/internal/store/postgres/appeal_repository.go @@ -71,7 +71,11 @@ func (r *AppealRepository) Find(ctx context.Context, filters *domain.ListAppeals } db := r.db.WithContext(ctx) - db = applyAppealFilter(db, filters) + var err error + db, err = applyAppealFilter(db, filters) + if err != nil { + return nil, err + } var models []*model.Appeal if err := db.Joins("Grant").Find(&models).Error; err != nil { @@ -98,10 +102,13 @@ func (r *AppealRepository) GetAppealsTotalCount(ctx context.Context, filter *dom appealFilters.Size = 0 appealFilters.Offset = 0 - db = applyAppealFilter(db, &appealFilters) - + var err error + db, err = applyAppealFilter(db, &appealFilters) + if err != nil { + return 0, err + } var count int64 - err := db.Model(&model.Appeal{}).Count(&count).Error + err = db.Model(&model.Appeal{}).Count(&count).Error return count, err } @@ -165,7 +172,7 @@ func (r *AppealRepository) Update(ctx context.Context, a *domain.Appeal) error { }) } -func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) *gorm.DB { +func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) (*gorm.DB, error) { db = db.Joins("JOIN resources ON appeals.resource_id = resources.id") if filters.Q != "" { // NOTE: avoid adding conditions before this grouped where clause. @@ -222,10 +229,16 @@ func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) *gorm.DB db = db.Where(`"options" -> 'expiration_date' > ?`, filters.ExpirationDateGreaterThan) } if filters.OrderBy != nil { - db = addOrderByClause(db, filters.OrderBy, addOrderByClauseOptions{ + var err error + db, err = addOrderByClause(db, filters.OrderBy, addOrderByClauseOptions{ statusColumnName: `"appeals"."status"`, statusesOrder: AppealStatusDefaultSort, - }) + }, + []string{"updated_at", "created_at"}) + + if err != nil { + return nil, err + } } db = db.Joins("Resource") @@ -242,5 +255,5 @@ func applyAppealFilter(db *gorm.DB, filters *domain.ListAppealsFilter) *gorm.DB db = db.Where(`"Resource"."urn" IN ?`, filters.ResourceURNs) } - return db + return db, nil } diff --git a/internal/store/postgres/approval_repository.go b/internal/store/postgres/approval_repository.go index 636c98ad0..94a31fdab 100644 --- a/internal/store/postgres/approval_repository.go +++ b/internal/store/postgres/approval_repository.go @@ -37,7 +37,11 @@ func (r *ApprovalRepository) ListApprovals(ctx context.Context, filter *domain.L records := []*domain.Approval{} db := r.db.WithContext(ctx) - db = applyFilter(db, filter) + var err error + db, err = applyFilter(db, filter) + if err != nil { + return nil, err + } if filter.Size > 0 { db = db.Limit(filter.Size) } @@ -68,8 +72,11 @@ func (r *ApprovalRepository) GetApprovalsTotalCount(ctx context.Context, filter f := *filter f.Size = 0 f.Offset = 0 - db = applyFilter(db, &f) - + var err error + db, err = applyFilter(db, &f) + if err != nil { + return 0, err + } var count int64 if err := db.Model(&model.Approval{}).Count(&count).Error; err != nil { return 0, err @@ -139,7 +146,7 @@ func (r *ApprovalRepository) DeleteApprover(ctx context.Context, approvalID, ema return nil } -func applyFilter(db *gorm.DB, filter *domain.ListApprovalsFilter) *gorm.DB { +func applyFilter(db *gorm.DB, filter *domain.ListApprovalsFilter) (*gorm.DB, error) { db = db.Joins("Appeal"). Joins("Appeal.Resource"). Joins(`JOIN "approvers" ON "approvals"."id" = "approvers"."approval_id"`) @@ -176,11 +183,17 @@ func applyFilter(db *gorm.DB, filter *domain.ListApprovalsFilter) *gorm.DB { } if filter.OrderBy != nil { - db = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{ + var err error + db, err = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{ statusColumnName: `"approvals"."status"`, statusesOrder: AppealStatusDefaultSort, - }) + }, + []string{"updated_at", "created_at"}) + + if err != nil { + return nil, err + } } - return db + return db, nil } diff --git a/internal/store/postgres/grant_repository.go b/internal/store/postgres/grant_repository.go index 275892f5e..40e04eaa4 100644 --- a/internal/store/postgres/grant_repository.go +++ b/internal/store/postgres/grant_repository.go @@ -30,7 +30,11 @@ func NewGrantRepository(db *gorm.DB) *GrantRepository { func (r *GrantRepository) List(ctx context.Context, filter domain.ListGrantsFilter) ([]domain.Grant, error) { db := r.db.WithContext(ctx) - db = applyGrantFilter(db, filter) + var err error + db, err = applyGrantFilter(db, filter) + if err != nil { + return nil, err + } var models []model.Grant if err := db.Joins("Resource").Joins("Appeal").Find(&models).Error; err != nil { @@ -56,9 +60,13 @@ func (r *GrantRepository) GetGrantsTotalCount(ctx context.Context, filter domain grantFilters.Size = 0 grantFilters.Offset = 0 - db = applyGrantFilter(db, grantFilters) + var err error + db, err = applyGrantFilter(db, grantFilters) + if err != nil { + return 0, err + } var count int64 - err := db.Model(&model.Grant{}).Count(&count).Error + err = db.Model(&model.Grant{}).Count(&count).Error return count, err } @@ -212,7 +220,7 @@ func upsertResources(tx *gorm.DB, models []*model.Grant) error { return nil } -func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) *gorm.DB { +func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) (*gorm.DB, error) { db = db.Joins("JOIN resources ON grants.resource_id = resources.id") if filter.Q != "" { // NOTE: avoid adding conditions before this grouped where clause. @@ -262,10 +270,16 @@ func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) *gorm.DB { db = db.Where(`"grants"."created_at" <= ?`, filter.CreatedAtLte) } if filter.OrderBy != nil { - db = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{ + var err error + db, err = addOrderByClause(db, filter.OrderBy, addOrderByClauseOptions{ statusColumnName: `"grants"."status"`, statusesOrder: GrantStatusDefaultSort, - }) + }, + []string{"updated_at", "created_at"}) + + if err != nil { + return nil, err + } } if !filter.ExpirationDateLessThan.IsZero() { db = db.Where(`"grants"."expiration_date" < ?`, filter.ExpirationDateLessThan) @@ -286,5 +300,5 @@ func applyGrantFilter(db *gorm.DB, filter domain.ListGrantsFilter) *gorm.DB { if filter.ResourceURNs != nil { db = db.Where(`"Resource"."urn" IN ?`, filter.ResourceURNs) } - return db + return db, nil } diff --git a/internal/store/postgres/resource_repository.go b/internal/store/postgres/resource_repository.go index e81ec04f4..e84436a49 100644 --- a/internal/store/postgres/resource_repository.go +++ b/internal/store/postgres/resource_repository.go @@ -36,7 +36,11 @@ func (r *ResourceRepository) Find(ctx context.Context, filter domain.ListResourc } db := r.db.WithContext(ctx) - db = applyResourceFilter(db, filter) + var err error + db, err = applyResourceFilter(db, filter) + if err != nil { + return nil, err + } var models []*model.Resource if err := db.Find(&models).Error; err != nil { return nil, err @@ -61,14 +65,18 @@ func (r *ResourceRepository) GetResourcesTotalCount(ctx context.Context, filter f := filter f.Size = 0 f.Offset = 0 - db = applyResourceFilter(db, f) + var err error + db, err = applyResourceFilter(db, f) + if err != nil { + return 0, err + } var count int64 - err := db.Model(&model.Resource{}).Count(&count).Error + err = db.Model(&model.Resource{}).Count(&count).Error return count, err } -func applyResourceFilter(db *gorm.DB, filter domain.ListResourcesFilter) *gorm.DB { +func applyResourceFilter(db *gorm.DB, filter domain.ListResourcesFilter) (*gorm.DB, error) { if filter.Q != "" { // NOTE: avoid adding conditions before this grouped where clause. // Otherwise, it will be wrapped in parentheses and the query will be invalid. @@ -125,17 +133,24 @@ func applyResourceFilter(db *gorm.DB, filter domain.ListResourcesFilter) *gorm.D } if len(sortOrder) != 0 { - db = addOrderByClause(db, sortOrder, addOrderByClauseOptions{ + var err error + db, err = addOrderByClause(db, sortOrder, addOrderByClauseOptions{ statusColumnName: "", statusesOrder: []string{}, - }) + searchQuery: filter.Q, + }, + []string{"updated_at", "created_at", "name", "urn", "global_urn"}) + + if err != nil { + return nil, err + } } for path, v := range filter.Details { pathArr := "{" + strings.Join(strings.Split(path, "."), ",") + "}" db = db.Where(`"details" #>> ? = ?`, pathArr, v) } - return db + return db, nil } // GetOne record by ID diff --git a/internal/store/postgres/resource_repository_test.go b/internal/store/postgres/resource_repository_test.go index 6191b45e4..85190dc4f 100644 --- a/internal/store/postgres/resource_repository_test.go +++ b/internal/store/postgres/resource_repository_test.go @@ -92,6 +92,24 @@ func (s *ResourceRepositoryTestSuite) TestFind() { CreatedAt: time.Now().Add(10 * time.Minute), GlobalURN: "global_urn_2", }, + { + ProviderType: s.dummyProvider.Type, + ProviderURN: s.dummyProvider.URN, + Type: "test_type_2", + URN: "test_exact_urn_match", + Name: "test_exact_name_match", + CreatedAt: time.Now().Add(15 * time.Minute), + GlobalURN: "global_urn_4", + }, + { + ProviderType: s.dummyProvider.Type, + ProviderURN: s.dummyProvider.URN, + Type: "test_type_2", + URN: "test_exact_urn_match_2", + Name: "test_exact_name_match_2", + CreatedAt: time.Now().Add(20 * time.Minute), + GlobalURN: "global_urn_3", + }, } err := s.repository.BulkUpsert(context.Background(), dummyResources) s.Require().NoError(err) @@ -130,7 +148,7 @@ func (s *ResourceRepositoryTestSuite) TestFind() { filters: domain.ListResourcesFilter{ ResourceType: "test_type", }, - expectedResult: dummyResources, + expectedResult: []*domain.Resource{dummyResources[0], dummyResources[1]}, }, { name: "filter by name", @@ -175,7 +193,7 @@ func (s *ResourceRepositoryTestSuite) TestFind() { Size: 1, Offset: 0, }, - expectedResult: []*domain.Resource{dummyResources[1]}, + expectedResult: []*domain.Resource{dummyResources[3]}, }, { name: "filter by size and offset 1", @@ -183,21 +201,35 @@ func (s *ResourceRepositoryTestSuite) TestFind() { Size: 1, Offset: 1, }, - expectedResult: []*domain.Resource{dummyResources[0]}, + expectedResult: []*domain.Resource{dummyResources[2]}, }, { name: "filter by size only", filters: domain.ListResourcesFilter{ Size: 1, }, - expectedResult: []*domain.Resource{dummyResources[1]}, + expectedResult: []*domain.Resource{dummyResources[3]}, }, { name: "Order by created at desc", filters: domain.ListResourcesFilter{ OrderBy: []string{"created_at:desc"}, }, - expectedResult: []*domain.Resource{dummyResources[1], dummyResources[0]}, + expectedResult: []*domain.Resource{dummyResources[3], dummyResources[2], dummyResources[1], dummyResources[0]}, + }, + { + name: "filter by urns", + filters: domain.ListResourcesFilter{ + ResourceURNs: []string{"test_urn_1", "test_urn_2"}, + }, + expectedResult: []*domain.Resource{dummyResources[0], dummyResources[1]}, + }, + { + name: "filter by resource types", + filters: domain.ListResourcesFilter{ + ResourceTypes: []string{"test_type"}, + }, + expectedResult: []*domain.Resource{dummyResources[0], dummyResources[1]}, }, } @@ -217,6 +249,28 @@ func (s *ResourceRepositoryTestSuite) TestFind() { } }) + s.Run("should return exact name matching resource on top", func() { + exact_match_filter := domain.ListResourcesFilter{ + Q: "test_exact_name_match", + OrderBy: []string{"name:exact_asc"}, + } + + actualResult, actualError := s.repository.Find(context.Background(), exact_match_filter) + s.NoError(actualError) + s.Equal("test_exact_name_match", actualResult[0].Name) + }) + + s.Run("should return error when invalid order by direction is passed", func() { + exact_match_filter := domain.ListResourcesFilter{ + Q: "test_exact_name_match", + OrderBy: []string{"name:test"}, + } + + actualResult, actualError := s.repository.Find(context.Background(), exact_match_filter) + s.Error(actualError) + s.Nil(actualResult) + }) + s.Run("should return error if filters validation returns an error", func() { invalidFilters := domain.ListResourcesFilter{ IDs: []string{}, @@ -277,6 +331,12 @@ func (s *ResourceRepositoryTestSuite) TestGetResourcesTotalCount() { s.Nil(actualError) }) + + s.Run("should return error", func() { + _, actualError := s.repository.GetResourcesTotalCount(context.Background(), domain.ListResourcesFilter{OrderBy: []string{"name:test"}}) + + s.Error(actualError) + }) } func (s *ResourceRepositoryTestSuite) TestBulkUpsert() { diff --git a/internal/store/postgres/utils.go b/internal/store/postgres/utils.go index 7e3a5fa8a..3b944796f 100644 --- a/internal/store/postgres/utils.go +++ b/internal/store/postgres/utils.go @@ -12,9 +12,10 @@ import ( type addOrderByClauseOptions struct { statusColumnName string statusesOrder []string + searchQuery string } -func addOrderByClause(db *gorm.DB, conditions []string, options addOrderByClauseOptions) *gorm.DB { +func addOrderByClause(db *gorm.DB, conditions []string, options addOrderByClauseOptions, allowedColumns []string) (*gorm.DB, error) { var orderByClauses []string var vars []interface{} @@ -24,15 +25,20 @@ func addOrderByClause(db *gorm.DB, conditions []string, options addOrderByClause vars = append(vars, options.statusesOrder) } else { columnOrder := strings.Split(orderBy, ":") - column := columnOrder[0] - if utils.ContainsString([]string{"updated_at", "created_at"}, column) { - if len(columnOrder) == 1 { - orderByClauses = append(orderByClauses, fmt.Sprintf(`"%s"`, column)) - } else if len(columnOrder) == 2 { - order := columnOrder[1] - if utils.ContainsString([]string{"asc", "desc"}, order) { - orderByClauses = append(orderByClauses, fmt.Sprintf(`"%s" %s`, column, order)) - } + columnName := columnOrder[0] + if !utils.ContainsString(allowedColumns, columnName) { + return nil, fmt.Errorf("cannot order by column %q", columnName) + } + if len(columnOrder) == 1 { + orderByClauses = append(orderByClauses, fmt.Sprintf(`"%s"`, columnName)) + } else if len(columnOrder) == 2 { + orderDirection := columnOrder[1] + if utils.ContainsString([]string{"asc", "desc"}, orderDirection) { + orderByClauses = append(orderByClauses, fmt.Sprintf(`"%s" %s`, columnName, orderDirection)) + } else if orderDirection == "exact_asc" && columnName == "name" { + orderByClauses = append(orderByClauses, fmt.Sprintf(`(CASE WHEN lower("%s") = '%s' THEN 1 ELSE 2 END)`, columnName, options.searchQuery)) + } else { + return nil, fmt.Errorf("invalid order by direction: %s", orderDirection) } } } @@ -44,7 +50,7 @@ func addOrderByClause(db *gorm.DB, conditions []string, options addOrderByClause Vars: vars, WithoutParentheses: true, }, - }) + }), nil } func addOrderBy(db *gorm.DB, orderBy string) *gorm.DB {