diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index 9875db08168..23968fcf6c1 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -40,19 +40,16 @@ func SetTaskConfigPolicies(ctx context.Context, }) } -// SetTaskConfigPoliciesTx adds the task invariant config and constraints config policies to +// SetTaskConfigPoliciesTx adds the task invariant config and constraints policies to // the database. func SetTaskConfigPoliciesTx(ctx context.Context, tx *bun.Tx, tcp *model.TaskConfigPolicies, ) error { q := db.Bun().NewInsert().Model(tcp) - if tcp.InvariantConfig == nil { - q = q.ExcludeColumn("invariant_config") - } - if tcp.Constraints == nil { - q = q.ExcludeColumn("constraints") - } + q = q.Set("last_updated_by = ?, last_updated_time = ?", tcp.LastUpdatedBy, tcp.LastUpdatedTime) + q = q.Set("invariant_config = ?", tcp.InvariantConfig) + q = q.Set("constraints = ?", tcp.Constraints) if tcp.WorkspaceID == nil { q = q.On("CONFLICT (workload_type) WHERE workspace_id IS NULL DO UPDATE") @@ -60,14 +57,6 @@ func SetTaskConfigPoliciesTx(ctx context.Context, tx *bun.Tx, q = q.On("CONFLICT (workspace_id, workload_type) WHERE workspace_id IS NOT NULL DO UPDATE") } - q = q.Set("last_updated_by = ?, last_updated_time = ?", tcp.LastUpdatedBy, tcp.LastUpdatedTime) - if tcp.InvariantConfig != nil { - q = q.Set("invariant_config = ?", tcp.InvariantConfig) - } - if tcp.Constraints != nil { - q = q.Set("constraints = ?", tcp.Constraints) - } - _, err := q.Exec(ctx) if err != nil { return fmt.Errorf("error setting task config policies: %w", err) diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 3b136dc4668..9035e6cd7c2 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -6,6 +6,7 @@ package configpolicy import ( "context" "encoding/json" + "regexp" "testing" "time" @@ -206,6 +207,174 @@ func TestSetTaskConfigPolicies(t *testing.T) { require.ErrorContains(t, err, "violates foreign key constraint") } +func TestUpdateTaskConfigPolicies(t *testing.T) { + ctx := context.Background() + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, cleanup := db.MustResolveNewPostgresDatabase(t) + defer cleanup() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + + user := db.RequireMockUser(t, pgDB) + + workspaceIDs := []int32{} + + defer func() { + if len(workspaceIDs) > 0 { + err := db.CleanupMockWorkspace(workspaceIDs) + if err != nil { + log.Errorf("error when cleaning up mock workspaces") + } + } + }() + + config1JSON := ` +{ + "resources": { + "priority": 99 + }, + "max_restarts": 20 +} +` + config2JSON := ` +{ + "resources": { + "priority": 100 + }, + "max_restarts": 25 +} +` + + constraints1JSON := ` +{ + "resources": { + "max_slots": 50 + }, + "priority_limit": 99 +} +` + constraints2JSON := ` +{ + "resources": { + "max_slots": 80 + }, + "priority_limit": 100 +} +` + whitespace := regexp.MustCompile(`[\s]`) + + config1 := whitespace.ReplaceAllString(config1JSON, "") + config2 := whitespace.ReplaceAllString(config2JSON, "") + constraints1 := whitespace.ReplaceAllString(constraints1JSON, "") + constraints2 := whitespace.ReplaceAllString(constraints2JSON, "") + + tests := []struct { + name string + tcps *model.TaskConfigPolicies + tcpsUpdated *model.TaskConfigPolicies + }{ + { + "config to config and constraints", &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config1, + }, &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config2, + Constraints: &constraints2, + }, + }, + { + "constraints to config and constraints", &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + Constraints: &constraints1, + }, &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config2, + Constraints: &constraints2, + }, + }, + { + "config and constraints to config and constraints", &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config1, + Constraints: &constraints1, + }, &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config2, + Constraints: &constraints2, + }, + }, + { + "config and constraints to only config", &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config1, + Constraints: &constraints1, + }, &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config2, + }, + }, + { + "config and constraints to only constraints", &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + InvariantConfig: &config1, + Constraints: &constraints1, + }, &model.TaskConfigPolicies{ + LastUpdatedBy: user.ID, + WorkloadType: model.ExperimentType, + Constraints: &constraints2, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w := model.Workspace{Name: uuid.NewString(), UserID: user.ID} + _, err := db.Bun().NewInsert().Model(&w).Exec(ctx) + require.NoError(t, err) + workspaceIDs = append(workspaceIDs, int32(w.ID)) + + test.tcps.WorkspaceID = &w.ID + test.tcpsUpdated.WorkspaceID = &w.ID + + // Set config policies. + err = SetTaskConfigPolicies(ctx, test.tcps) + require.NoError(t, err) + + // Update config policies. + err = SetTaskConfigPolicies(ctx, test.tcpsUpdated) + require.NoError(t, err) + + // Verify config policies are updated properly. + tcps, err := GetTaskConfigPolicies(ctx, &w.ID, test.tcps.WorkloadType) + require.NoError(t, err) + + if test.tcpsUpdated.InvariantConfig != nil { + require.NotNil(t, tcps.InvariantConfig) + invariantConfig := whitespace.ReplaceAllString( + *tcps.InvariantConfig, + "") + require.Equal(t, *test.tcpsUpdated.InvariantConfig, + invariantConfig) + } + if test.tcpsUpdated.Constraints != nil { + require.NotNil(t, tcps.Constraints) + constraints := whitespace.ReplaceAllString(*tcps.Constraints, + "") + require.Equal(t, *test.tcpsUpdated.Constraints, constraints) + } + }) + } +} + // Test the enforcement of the primary key on the task_config_polciies table. func TestTaskConfigPoliciesUnique(t *testing.T) { ctx := context.Background()