diff --git a/cmd/warrant/main.go b/cmd/warrant/main.go index 43b0941d..5a7e09b8 100644 --- a/cmd/warrant/main.go +++ b/cmd/warrant/main.go @@ -48,7 +48,7 @@ type ServiceEnv struct { Datastore database.Database } -func (env ServiceEnv) DB() database.Database { +func (env *ServiceEnv) DB() database.Database { return env.Datastore } @@ -56,8 +56,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error { ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) defer cancelFunc() - if cfg.GetDatastore().MySQL.Hostname != "" || cfg.GetDatastore().MySQL.DSN != "" { - db := database.NewMySQL(*cfg.GetDatastore().MySQL) + if cfg.GetDatastore().GetMySQL().Hostname != "" || cfg.GetDatastore().GetMySQL().DSN != "" { + db := database.NewMySQL(*cfg.GetDatastore().GetMySQL()) err := db.Connect(ctx) if err != nil { return err @@ -74,8 +74,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error { return nil } - if cfg.GetDatastore().Postgres.Hostname != "" { - db := database.NewPostgres(*cfg.GetDatastore().Postgres) + if cfg.GetDatastore().GetPostgres().Hostname != "" { + db := database.NewPostgres(*cfg.GetDatastore().GetPostgres()) err := db.Connect(ctx) if err != nil { return err @@ -92,8 +92,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error { return nil } - if cfg.GetDatastore().SQLite.Database != "" { - db := database.NewSQLite(*cfg.GetDatastore().SQLite) + if cfg.GetDatastore().GetSQLite().Database != "" { + db := database.NewSQLite(*cfg.GetDatastore().GetSQLite()) err := db.Connect(ctx) if err != nil { return err @@ -113,8 +113,8 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error { return errors.New("invalid database configuration provided") } -func NewServiceEnv() ServiceEnv { - return ServiceEnv{ +func NewServiceEnv() *ServiceEnv { + return &ServiceEnv{ Datastore: nil, } } @@ -155,22 +155,22 @@ func main() { querySvc := query.NewService(svcEnv, objectTypeSvc, warrantSvc, objectSvc) // Init feature service - featureSvc := feature.NewService(&svcEnv, objectSvc) + featureSvc := feature.NewService(svcEnv, objectSvc) // Init permission service - permissionSvc := permission.NewService(&svcEnv, objectSvc) + permissionSvc := permission.NewService(svcEnv, objectSvc) // Init pricing tier service - pricingTierSvc := pricingtier.NewService(&svcEnv, objectSvc) + pricingTierSvc := pricingtier.NewService(svcEnv, objectSvc) // Init role service - roleSvc := role.NewService(&svcEnv, objectSvc) + roleSvc := role.NewService(svcEnv, objectSvc) // Init tenant service - tenantSvc := tenant.NewService(&svcEnv, objectSvc) + tenantSvc := tenant.NewService(svcEnv, objectSvc) // Init user service - userSvc := user.NewService(&svcEnv, objectSvc) + userSvc := user.NewService(svcEnv, objectSvc) svcs := []service.Service{ checkSvc, diff --git a/pkg/authz/objecttype/model.go b/pkg/authz/objecttype/model.go index 8b5b40d5..52c2f40b 100644 --- a/pkg/authz/objecttype/model.go +++ b/pkg/authz/objecttype/model.go @@ -33,12 +33,12 @@ type Model interface { } type ObjectType struct { - ID int64 `mysql:"id" postgres:"id" sqlite:"id"` - TypeId string `mysql:"typeId" postgres:"type_id" sqlite:"typeId"` + ID int64 `mysql:"id" postgres:"id" sqlite:"id"` + TypeId string `mysql:"typeId" postgres:"type_id" sqlite:"typeId"` Definition string `mysql:"definition" postgres:"definition" sqlite:"definition"` - CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"` - UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"` - DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"` + CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"` + UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"` + DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"` } func (objectType ObjectType) GetID() int64 { diff --git a/pkg/authz/warrant/handlers.go b/pkg/authz/warrant/handlers.go index 92a80147..dd8ee01c 100644 --- a/pkg/authz/warrant/handlers.go +++ b/pkg/authz/warrant/handlers.go @@ -175,9 +175,5 @@ func buildFilterOptions(r *http.Request) *FilterParams { filterOptions.SubjectRelation = queryParams.Get("subjectRelation") } - if queryParams.Has("policy") { - filterOptions.Policy = Policy(queryParams.Get("policy")) - } - return &filterOptions } diff --git a/pkg/authz/warrant/list.go b/pkg/authz/warrant/list.go index e9c59f6e..2391dd53 100644 --- a/pkg/authz/warrant/list.go +++ b/pkg/authz/warrant/list.go @@ -29,7 +29,6 @@ type FilterParams struct { SubjectType string `json:"subjectType,omitempty"` SubjectId string `json:"subjectId,omitempty"` SubjectRelation string `json:"subjectRelation,omitempty"` - Policy Policy `json:"policy,omitempty"` } func (fp FilterParams) String() string { @@ -58,10 +57,6 @@ func (fp FilterParams) String() string { s = fmt.Sprintf("%s&subjectRelation=%s", s, fp.SubjectRelation) } - if fp.Policy != "" { - s = fmt.Sprintf("%s&policy=%s", s, fp.Policy) - } - return strings.TrimPrefix(s, "&") } @@ -79,7 +74,6 @@ func (parser WarrantListParamParser) GetSupportedSortBys() []string { } func (parser WarrantListParamParser) ParseValue(val string, sortBy string) (interface{}, error) { - // TODO: add support for more sortBy columns switch sortBy { case "createdAt": value, err := time.Parse(time.RFC3339, val) diff --git a/pkg/authz/warrant/model.go b/pkg/authz/warrant/model.go index 3d182066..ad8a2dcb 100644 --- a/pkg/authz/warrant/model.go +++ b/pkg/authz/warrant/model.go @@ -37,18 +37,18 @@ type Model interface { } type Warrant struct { - ID int64 `mysql:"id" postgres:"id" sqlite:"id"` - ObjectType string `mysql:"objectType" postgres:"object_type" sqlite:"objectType"` - ObjectId string `mysql:"objectId" postgres:"object_id" sqlite:"objectId"` - Relation string `mysql:"relation" postgres:"relation" sqlite:"relation"` - SubjectType string `mysql:"subjectType" postgres:"subject_type" sqlite:"subjectType"` - SubjectId string `mysql:"subjectId" postgres:"subject_id" sqlite:"subjectId"` + ID int64 `mysql:"id" postgres:"id" sqlite:"id"` + ObjectType string `mysql:"objectType" postgres:"object_type" sqlite:"objectType"` + ObjectId string `mysql:"objectId" postgres:"object_id" sqlite:"objectId"` + Relation string `mysql:"relation" postgres:"relation" sqlite:"relation"` + SubjectType string `mysql:"subjectType" postgres:"subject_type" sqlite:"subjectType"` + SubjectId string `mysql:"subjectId" postgres:"subject_id" sqlite:"subjectId"` SubjectRelation string `mysql:"subjectRelation" postgres:"subject_relation" sqlite:"subjectRelation"` - Policy Policy `mysql:"policy" postgres:"policy" sqlite:"policy"` - PolicyHash string `mysql:"policyHash" postgres:"policy_hash" sqlite:"policyHash"` - CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"` - UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"` - DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"` + Policy Policy `mysql:"policy" postgres:"policy" sqlite:"policy"` + PolicyHash string `mysql:"policyHash" postgres:"policy_hash" sqlite:"policyHash"` + CreatedAt time.Time `mysql:"createdAt" postgres:"created_at" sqlite:"createdAt"` + UpdatedAt time.Time `mysql:"updatedAt" postgres:"updated_at" sqlite:"updatedAt"` + DeletedAt *time.Time `mysql:"deletedAt" postgres:"deleted_at" sqlite:"deletedAt"` } func (warrant Warrant) GetID() int64 { diff --git a/pkg/authz/warrant/mysql.go b/pkg/authz/warrant/mysql.go index 1d5ef71d..25db940d 100644 --- a/pkg/authz/warrant/mysql.go +++ b/pkg/authz/warrant/mysql.go @@ -229,11 +229,6 @@ func (repo MySQLRepository) List(ctx context.Context, filterParams FilterParams, replacements = append(replacements, filterParams.SubjectRelation) } - if filterParams.Policy != "" { - query = fmt.Sprintf("%s AND policyHash = ?", query) - replacements = append(replacements, filterParams.Policy.Hash()) - } - if listParams.NextCursor != nil { comparisonOp := "<" if listParams.SortOrder == service.SortOrderAsc { diff --git a/pkg/authz/warrant/postgres.go b/pkg/authz/warrant/postgres.go index 352c3ea2..fbf2232e 100644 --- a/pkg/authz/warrant/postgres.go +++ b/pkg/authz/warrant/postgres.go @@ -234,11 +234,6 @@ func (repo PostgresRepository) List(ctx context.Context, filterParams FilterPara replacements = append(replacements, filterParams.SubjectRelation) } - if filterParams.Policy != "" { - query = fmt.Sprintf("%s AND policy_hash = ?", query) - replacements = append(replacements, filterParams.Policy.Hash()) - } - if listParams.NextCursor != nil { comparisonOp := "<" if listParams.SortOrder == service.SortOrderAsc { diff --git a/pkg/authz/warrant/sqlite.go b/pkg/authz/warrant/sqlite.go index f03271c8..f79629df 100644 --- a/pkg/authz/warrant/sqlite.go +++ b/pkg/authz/warrant/sqlite.go @@ -238,11 +238,6 @@ func (repo SQLiteRepository) List(ctx context.Context, filterParams FilterParams replacements = append(replacements, filterParams.SubjectRelation) } - if filterParams.Policy != "" { - query = fmt.Sprintf("%s AND policyHash = ?", query) - replacements = append(replacements, filterParams.Policy.Hash()) - } - if listParams.NextCursor != nil { comparisonOp := "<" if listParams.SortOrder == service.SortOrderAsc { diff --git a/pkg/config/config.go b/pkg/config/config.go index 01cd7af3..1f281c4d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -41,17 +41,17 @@ type Config interface { GetLogLevel() int8 GetEnableAccessLog() bool GetAutoMigrate() bool - GetDatastore() *DatastoreConfig + GetDatastore() DatastoreConfig } type WarrantConfig struct { - Port int `mapstructure:"port"` - LogLevel int8 `mapstructure:"logLevel"` - EnableAccessLog bool `mapstructure:"enableAccessLog"` - AutoMigrate bool `mapstructure:"autoMigrate"` - Datastore *DatastoreConfig `mapstructure:"datastore"` - Authentication *AuthConfig `mapstructure:"authentication"` - Check *CheckConfig `mapstructure:"check"` + Port int `mapstructure:"port"` + LogLevel int8 `mapstructure:"logLevel"` + EnableAccessLog bool `mapstructure:"enableAccessLog"` + AutoMigrate bool `mapstructure:"autoMigrate"` + Datastore *WarrantDatastoreConfig `mapstructure:"datastore"` + Authentication *AuthConfig `mapstructure:"authentication"` + Check *CheckConfig `mapstructure:"check"` } func (warrantConfig WarrantConfig) GetPort() int { @@ -70,7 +70,7 @@ func (warrantConfig WarrantConfig) GetAutoMigrate() bool { return warrantConfig.AutoMigrate } -func (warrantConfig WarrantConfig) GetDatastore() *DatastoreConfig { +func (warrantConfig WarrantConfig) GetDatastore() DatastoreConfig { return warrantConfig.Datastore } @@ -82,12 +82,30 @@ func (warrantConfig WarrantConfig) GetCheck() *CheckConfig { return warrantConfig.Check } -type DatastoreConfig struct { +type DatastoreConfig interface { + GetMySQL() *MySQLConfig + GetPostgres() *PostgresConfig + GetSQLite() *SQLiteConfig +} + +type WarrantDatastoreConfig struct { MySQL *MySQLConfig `mapstructure:"mysql"` Postgres *PostgresConfig `mapstructure:"postgres"` SQLite *SQLiteConfig `mapstructure:"sqlite"` } +func (warrantDatastoreConfig WarrantDatastoreConfig) GetMySQL() *MySQLConfig { + return warrantDatastoreConfig.MySQL +} + +func (warrantDatastoreConfig WarrantDatastoreConfig) GetPostgres() *PostgresConfig { + return warrantDatastoreConfig.Postgres +} + +func (warrantDatastoreConfig WarrantDatastoreConfig) GetSQLite() *SQLiteConfig { + return warrantDatastoreConfig.SQLite +} + type MySQLConfig struct { Username string `mapstructure:"username"` Password string `mapstructure:"password"` diff --git a/pkg/database/sql.go b/pkg/database/sql.go index f16e512d..24a433e8 100644 --- a/pkg/database/sql.go +++ b/pkg/database/sql.go @@ -64,8 +64,8 @@ func (q SqlTx) ExecContext(ctx context.Context, query string, args ...interface{ query = q.Tx.Rebind(query) result, err := q.Tx.ExecContext(ctx, query, args...) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return result, err default: return result, errors.Wrap(err, "SqlTx error") @@ -78,8 +78,8 @@ func (q SqlTx) GetContext(ctx context.Context, dest interface{}, query string, a query = q.Tx.Rebind(query) err := q.Tx.GetContext(ctx, dest, query, args...) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return err default: return errors.Wrap(err, "SqlTx error") @@ -92,8 +92,8 @@ func (q SqlTx) NamedExecContext(ctx context.Context, query string, arg interface query = q.Tx.Rebind(query) result, err := q.Tx.NamedExecContext(ctx, query, arg) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return result, err default: return result, errors.Wrap(err, "SqlTx error") @@ -119,8 +119,8 @@ func (q SqlTx) SelectContext(ctx context.Context, dest interface{}, query string query = q.Tx.Rebind(query) err := q.Tx.SelectContext(ctx, dest, query, args...) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return err default: return errors.Wrap(err, "SqlTx error") @@ -203,8 +203,8 @@ func (ds SQL) ExecContext(ctx context.Context, query string, args ...interface{} result, err := queryable.ExecContext(ctx, query, args...) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return result, err default: return result, errors.Wrap(err, "Error when calling sql ExecContext") @@ -222,8 +222,8 @@ func (ds SQL) GetContext(ctx context.Context, dest interface{}, query string, ar err := queryable.GetContext(ctx, dest, query, args...) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return err default: return errors.Wrap(err, "Error when calling sql GetContext") @@ -241,8 +241,8 @@ func (ds SQL) NamedExecContext(ctx context.Context, query string, arg interface{ result, err := queryable.NamedExecContext(ctx, query, arg) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return result, err default: return result, errors.Wrap(err, "Error when calling sql NamedExecContext") @@ -278,8 +278,8 @@ func (ds SQL) SelectContext(ctx context.Context, dest interface{}, query string, err := queryable.SelectContext(ctx, dest, query, args...) if err != nil { - switch err { - case sql.ErrNoRows: + switch { + case errors.Is(err, sql.ErrNoRows): return err default: return errors.Wrap(err, "Error when calling sql SelectContext") diff --git a/tests/v1/warrants-list.json b/tests/v1/warrants-list.json index 28522465..7b95f1fe 100644 --- a/tests/v1/warrants-list.json +++ b/tests/v1/warrants-list.json @@ -295,28 +295,6 @@ ] } }, - { - "name": "listWarrantsFilterByPolicy", - "request": { - "method": "GET", - "url": "/v1/warrants?policy=tenant%20%3D%3D%20%22tenant-a%22%20%26%26%20organization%20%3D%3D%20%22org-a%22" - }, - "expectedResponse": { - "statusCode": 200, - "body": [ - { - "objectType": "role", - "objectId": "senior-accountant", - "relation": "member", - "subject": { - "objectType": "user", - "objectId": "user-a" - }, - "policy": "tenant == \"tenant-a\" \u0026\u0026 organization == \"org-a\"" - } - ] - } - }, { "name": "removeRoleSeniorAccountantFromUserAWithPolicy", "request": { diff --git a/tests/v2/warrants-list.json b/tests/v2/warrants-list.json index d4444e8b..0e98196d 100644 --- a/tests/v2/warrants-list.json +++ b/tests/v2/warrants-list.json @@ -373,30 +373,6 @@ } } }, - { - "name": "listWarrantsFilterByPolicy", - "request": { - "method": "GET", - "url": "/v2/warrants?policy=tenant%20%3D%3D%20%22tenant-a%22%20%26%26%20organization%20%3D%3D%20%22org-a%22" - }, - "expectedResponse": { - "statusCode": 200, - "body": { - "results": [ - { - "objectType": "role", - "objectId": "senior-accountant", - "relation": "member", - "subject": { - "objectType": "user", - "objectId": "user-a" - }, - "policy": "tenant == \"tenant-a\" \u0026\u0026 organization == \"org-a\"" - } - ] - } - } - }, { "name": "removeRoleSeniorAccountantFromUserAWithPolicy", "request": {