Skip to content

Commit

Permalink
chore: add checkpoint and max slots config policy enforcements in PATCH experiment (#10125)
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 authored Oct 25, 2024
1 parent b3f928b commit 233e095
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 15 deletions.
43 changes: 33 additions & 10 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,7 @@ func (a *apiServer) PatchExperiment(
}
activeConfig.SetResources(resources)
}

newCheckpointStorage := req.Experiment.CheckpointStorage

if newCheckpointStorage != nil {
Expand All @@ -1231,6 +1232,23 @@ func (a *apiServer) PatchExperiment(
storage.SetSaveTrialBest(int(newCheckpointStorage.SaveTrialBest))
storage.SetSaveTrialLatest(int(newCheckpointStorage.SaveTrialLatest))
activeConfig.SetCheckpointStorage(storage)

// Only allow checkpoint storage changes if it is not specified as an invariant config.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to get workspace %s",
activeConfig.Workspace()))
}

enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig](
ctx, &w.ID, "invariant_config", "checkpoint_storage",
model.ExperimentType)
if err != nil {
return nil, fmt.Errorf("unable to fetch task config policies: %w", err)
}
if enforcedChkptConf != nil {
activeConfig.SetCheckpointStorage(*enforcedChkptConf)
}
}

// `patch` represents the allowed mutations that can be performed on an experiment, in JSON
Expand Down Expand Up @@ -1470,20 +1488,16 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string
fmt.Sprintf("override config must have single searcher type got '%s' instead", overrideName))
}

// Determine which workspace the experiment is in.
wkspName := activeConfig.Workspace()
if wkspName == "" {
wkspName = model.DefaultWorkspaceName
}
ctx := context.TODO()
w, err := workspace.WorkspaceByName(ctx, wkspName)
// Merge the config with the optionally specified invariant config specified by task config
// policies.
w, err := getWorkspaceByConfig(activeConfig)
if err != nil {
return nil, false, status.Errorf(codes.Internal,
fmt.Sprintf("failed to get workspace %s", activeConfig.Workspace()))
}
// Merge the config with the optionally specified invariant config specified by task config
// policies.
configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(ctx,

configWithInvariantDefaults, err := configpolicy.MergeWithInvariantExperimentConfigs(
context.TODO(),
w.ID, mergedConfig)
if err != nil {
return nil, false,
Expand All @@ -1499,6 +1513,15 @@ func (a *apiServer) parseAndMergeContinueConfig(expID int, overrideConfig string
return bytes.([]byte), isSingle, nil
}

// getWorkspaceByConfig looks up the workspace named by the experiment config. When the config
// carries an empty workspace name, model.DefaultWorkspaceName is used for the lookup instead.
func getWorkspaceByConfig(config expconf.ExperimentConfig) (*model.Workspace, error) {
	name := model.DefaultWorkspaceName
	if configured := config.Workspace(); configured != "" {
		name = configured
	}
	return workspace.WorkspaceByName(context.TODO(), name)
}

var errContinueHPSearchCompleted = status.Error(codes.FailedPrecondition,
"experiment has been completed, cannot continue this experiment")

Expand Down
24 changes: 24 additions & 0 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2347,3 +2347,27 @@ func TestDeleteExperimentsFiltered(t *testing.T) {
}
t.Error("expected experiments to delete after 15 seconds and they did not")
}

// TestGetWorkspaceByConfig exercises getWorkspaceByConfig with an empty workspace name and with
// the name of a workspace created for this test.
func TestGetWorkspaceByConfig(t *testing.T) {
	api, _, ctx := setupAPITest(t, nil)

	created, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{
		Name: uuid.New().String(),
	})
	require.NoError(t, err)
	createdName := created.Workspace.Name

	t.Run("no workspace name", func(t *testing.T) {
		emptyNameConf := expconf.ExperimentConfig{RawWorkspace: ptrs.Ptr("")}

		got, err := getWorkspaceByConfig(emptyNameConf)
		require.NoError(t, err)

		// An empty name must resolve to the Uncategorized workspace.
		require.Equal(t, 1, got.ID)
	})

	t.Run("has workspace name", func(t *testing.T) {
		namedConf := expconf.ExperimentConfig{RawWorkspace: &createdName}

		got, err := getWorkspaceByConfig(namedConf)
		require.NoError(t, err)
		require.Equal(t, createdName, got.Name)
	})
}
62 changes: 60 additions & 2 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package configpolicy
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"

Expand All @@ -14,8 +15,9 @@ import (
)

const (
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
wkspIDQuery = "workspace_id = ?"
wkspIDGlobalQuery = "workspace_id IS ?"
invalidPolicyTypeErr = "invalid policy type"
// DefaultInvariantConfigStr is the default invariant config val used for tests.
DefaultInvariantConfigStr = `{
"description": "random description",
Expand Down Expand Up @@ -108,3 +110,59 @@ func DeleteConfigPolicies(ctx context.Context,
}
return nil
}

// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in
// order of precedence. Global scope has highest precedence, then workspace. Returns (nil, nil) if
// the field is set in neither scope; returns a non-nil error for an unsupported policyType, a
// failed query, or a value that cannot be unmarshaled into T.
//
// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string
// exactly as you wish for it to be accessed in the database. For example, if you want to access
// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT
// "resources -> max_slots". Because the value is spliced into the query unescaped, it must never
// come from untrusted input.
//
// **NOTE** When using this function to retrieve an object of Kind Pointer, set T as the Type of
// object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so
// that when its pointer is returned, you get an object of type *int.
func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
	error,
) {
	if policyType != "invariant_config" && policyType != "constraints" {
		return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType)
	}

	var confBytes []byte
	err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
		// selectField reads the requested field for a single scope. A missing row is not an
		// error here: it only means that scope has no policy, so nil bytes are returned.
		selectField := func(scopeQuery string, args ...any) ([]byte, error) {
			var raw []byte
			err := tx.NewSelect().Table("task_config_policies").
				ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
				Where(scopeQuery, args...).
				Where("workload_type = ?", workloadType).Scan(ctx, &raw)
			if err == sql.ErrNoRows {
				return nil, nil
			}
			if err != nil {
				return nil, err
			}
			return raw, nil
		}

		globalBytes, err := selectField("workspace_id IS NULL")
		if err != nil {
			return err
		}
		if len(globalBytes) > 0 {
			// The global scope wins, so the workspace scope does not need to be consulted.
			// Previously a missing workspace row discarded this value.
			confBytes = globalBytes
			return nil
		}

		if wkspID == nil {
			// No workspace scope was requested; nothing further to look up.
			return nil
		}
		wkspBytes, err := selectField("workspace_id = ?", wkspID)
		if err != nil {
			return err
		}
		confBytes = wkspBytes
		return nil
	})
	// Check the error before the "nothing found" case so that query failures are reported
	// instead of being mistaken for an absent policy.
	if err != nil {
		return nil, fmt.Errorf("error getting config field %s: %w", field, err)
	}
	if len(confBytes) == 0 {
		return nil, nil
	}

	var conf T
	if err := json.Unmarshal(confBytes, &conf); err != nil {
		return nil, fmt.Errorf("error unmarshaling config field: %w", err)
	}

	return &conf, nil
}
Loading

0 comments on commit 233e095

Please sign in to comment.