Skip to content

Commit

Permalink
fix bugs detected by log policies tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jgongd committed Oct 19, 2024
1 parent 597b727 commit 8095267
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 125 deletions.
15 changes: 7 additions & 8 deletions e2e_tests/tests/cluster/test_log_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_log_policy_cancel_retries(should_match: bool) -> None:
regex = r"(.*) this should not match (.*)"

config = {
"log_policies": [{"pattern": regex, "actions": [{"type": "cancel_retries"}]}],
"log_policies": [{"pattern": regex, "action": {"type": "cancel_retries"}}],
"max_restarts": 1,
}
exp_ref = noop.create_experiment(sess, [noop.Exit(7)], config=config)
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_log_policy_exclude_node_k8s(should_match: bool) -> None:
assert agents[0].slots is not None

config = {
"log_policies": [{"pattern": regex, "actions": [{"type": "exclude_node"}]}],
"log_policies": [{"pattern": regex, "action": {"type": "exclude_node"}}],
"resources": {"slots_per_trial": len(agents[0].slots)},
"max_restarts": 1,
}
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_log_policy_exclude_node_single_agent(should_match: bool) -> None:
assert agents[0].slots is not None

config = {
"log_policies": [{"pattern": regex, "actions": [{"type": "exclude_node"}]}],
"log_policies": [{"pattern": regex, "action": {"type": "exclude_node"}}],
"resources": {"slots_per_trial": len(agents[0].slots)},
"max_restarts": 1,
}
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_log_policy_exclude_slurm(should_match: bool) -> None:
regex = r"(.*) this should not match (.*)"

config = {
"log_policies": [{"pattern": regex, "actions": [{"type": "exclude_node"}]}],
"log_policies": [{"pattern": regex, "action": {"type": "exclude_node"}}],
"max_restarts": 1,
}
exp_ref = noop.create_experiment(sess, [noop.Exit(7)], config=config)
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_log_signal(should_match: bool) -> None:

expected_signal = "Test Signal"
config = {
"log_policies": [{"pattern": regex, "signal": expected_signal}],
"log_policies": [{"pattern": regex, "actions": [{"signal": expected_signal}]}],
"max_restarts": 1,
}

Expand Down Expand Up @@ -193,7 +193,7 @@ def test_signal_clear_after_exp_continue() -> None:

expected_signal = "Test Signal"
config = {
"log_policies": [{"pattern": regex, "signal": expected_signal}],
"log_policies": [{"pattern": regex, "actions": [{"signal": expected_signal}]}],
"max_restarts": 0,
}

Expand All @@ -216,8 +216,7 @@ def test_signal_clear_after_exp_continue() -> None:
"e",
"continue",
str(exp_ref.id),
"--config",
"hyperparameters.crash_on_startup=false",
*noop.cli_config_overrides([noop.Exit(0)]),
],
)
exp.wait_for_experiment_state(sess, exp_ref.id, bindings.experimentv1State.COMPLETED)
Expand Down
18 changes: 9 additions & 9 deletions master/internal/db/postgres_experiments_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,29 +439,29 @@ func TestActiveLogPatternPolicies(t *testing.T) {
eccErrorSignal := "ECC Error"
cudaOOMSignal := "CUDA OOM"
expected := expconf.LogPoliciesConfig{
expconf.LogPolicy{
RawPattern: ".*uncorrectable ECC error encountered.*",
RawActions: expconf.LogActionsV0{expconf.LogActionV0{Type: expconf.LogActionTypeSignal, Signal: &eccErrorSignal}},
},
expconf.LogPolicy{
RawPattern: ".*CUDA out of memory.*",
RawActions: expconf.LogActionsV0{expconf.LogActionV0{Type: expconf.LogActionTypeSignal, Signal: &cudaOOMSignal}},
},
expconf.LogPolicy{
RawPattern: ".*uncorrectable ECC error encountered.*",
RawActions: expconf.LogActionsV0{expconf.LogActionV0{Type: expconf.LogActionTypeSignal, Signal: &eccErrorSignal}},
},
}

require.Equal(t, expected, policies)

activeConfig, err := db.ActiveExperimentConfig(exp.ID)
require.NoError(t, err)
activeConfig.RawLogPolicies = expconf.LogPoliciesConfig{
expconf.LogPolicy{
RawPattern: "sub",
RawActions: expconf.LogActionsV0{expconf.LogActionV0{Type: expconf.LogActionTypeCancelRetries}},
},
expconf.LogPolicy{
RawPattern: `\d{5}$`,
RawActions: expconf.LogActionsV0{expconf.LogActionV0{Type: expconf.LogActionTypeExcludeNode}},
},
expconf.LogPolicy{
RawPattern: "sub",
RawActions: expconf.LogActionsV0{expconf.LogActionV0{Type: expconf.LogActionTypeCancelRetries}},
},
}

v, err := json.Marshal(activeConfig)
Expand All @@ -476,7 +476,7 @@ func TestActiveLogPatternPolicies(t *testing.T) {

policies, err = ActiveLogPolicies(ctx, exp.ID)
require.NoError(t, err)
require.Equal(t, activeConfig.RawLogPolicies, &policies)
require.Equal(t, activeConfig.RawLogPolicies, policies)
}

func TestGetNonTerminalExperimentCount(t *testing.T) {
Expand Down
78 changes: 38 additions & 40 deletions master/internal/logpattern/logpattern.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,50 +76,48 @@ func (l *LogPatternPolicies) monitor(ctx context.Context,
}

if compiledRegex.MatchString(log.Log) {
if actions := policy.Actions(); len(actions) > 0 {
for _, a := range actions {
switch a.Type {
case expconf.LogActionTypeCancelRetries:
if err := addDontRetry(
ctx, model.TaskID(log.TaskID), *log.AgentID, policy.Pattern(), log.Log,
); err != nil {
return fmt.Errorf("adding don't retry: %w", err)
}
for _, a := range policy.Actions() {
switch a.Type {
case expconf.LogActionTypeCancelRetries:
if err := addDontRetry(
ctx, model.TaskID(log.TaskID), *log.AgentID, policy.Pattern(), log.Log,
); err != nil {
return fmt.Errorf("adding don't retry: %w", err)
}

case expconf.LogActionTypeExcludeNode:
if err := addRetryOnDifferentNode(
ctx, model.TaskID(log.TaskID), *log.AgentID, policy.Pattern(), log.Log,
); err != nil {
return fmt.Errorf("adding retry on different node: %w", err)
}
case expconf.LogActionTypeSignal:
signal := a.Signal
if signal != nil {
err = db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.NewUpdate().Model(&model.Task{}).
Set("log_signal = ?", signal).
Where("task_id = ?", log.TaskID).
Exec(ctx); err != nil {
return fmt.Errorf("updating log signal of task %s: %w", log.TaskID, err)
}
if _, err := tx.NewUpdate().Model(&model.Run{}).
Table("run_id_task_id").
Set("log_signal = ?", signal).
Where("run.id = run_id_task_id.run_id").
Where("run_id_task_id.task_id = ?", log.TaskID).
Exec(ctx); err != nil {
return fmt.Errorf("updating log signal of task %s: %w", log.TaskID, err)
}

return nil
})
if err != nil {
return fmt.Errorf("updating log signal: %w", err)
case expconf.LogActionTypeExcludeNode:
if err := addRetryOnDifferentNode(
ctx, model.TaskID(log.TaskID), *log.AgentID, policy.Pattern(), log.Log,
); err != nil {
return fmt.Errorf("adding retry on different node: %w", err)
}
case expconf.LogActionTypeSignal:
signal := a.Signal
if signal != nil {
err = db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if _, err := tx.NewUpdate().Model(&model.Task{}).
Set("log_signal = ?", signal).
Where("task_id = ?", log.TaskID).
Exec(ctx); err != nil {
return fmt.Errorf("updating log signal of task %s: %w", log.TaskID, err)
}
if _, err := tx.NewUpdate().Model(&model.Run{}).
Table("run_id_task_id").
Set("log_signal = ?", signal).
Where("run.id = run_id_task_id.run_id").
Where("run_id_task_id.task_id = ?", log.TaskID).
Exec(ctx); err != nil {
return fmt.Errorf("updating log signal of task %s: %w", log.TaskID, err)
}

return nil
})
if err != nil {
return fmt.Errorf("updating log signal: %w", err)
}
default:
return fmt.Errorf("unrecognized log pattern policy type")
}
default:
return fmt.Errorf("unrecognized log pattern policy type")
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion master/pkg/model/task_container_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ func (c TaskContainerDefaultsConfig) Merge(
if res.LogPolicies == nil {
res.LogPolicies = other.LogPolicies
} else {
res.LogPolicies = res.LogPolicies.Merge(other.LogPolicies)
res.LogPolicies = other.LogPolicies.Merge(res.LogPolicies)
}
}

Expand Down
4 changes: 2 additions & 2 deletions master/pkg/model/task_container_defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ func TestLogPatternUnmarshal(t *testing.T) {
var tcd TaskContainerDefaultsConfig
require.NoError(t, json.Unmarshal([]byte(string(`{
"log_policies": [
{"pattern": "test", "actions": [{"type": "exclude_node"}]},
{"pattern": "test2", "actions": [{"type": "cancel_retries"}]}
{"pattern": "test", "actions": ["exclude_node"]},
{"pattern": "test2", "actions": ["cancel_retries"]}
]
}`)), &tcd))

Expand Down
79 changes: 25 additions & 54 deletions master/pkg/schemas/expconf/log_pattern_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,26 @@ import (
//go:generate ../gen.sh
type LogPoliciesConfigV0 []LogPolicyV0

// WithDefaults implements the Defaultable pseudo-interface.
//
// A nil receiver is replaced by the two built-in signal policies (CUDA OOM and
// ECC error). Any non-nil receiver, including an empty but non-nil slice, is
// returned unchanged.
func (b LogPoliciesConfigV0) WithDefaults() LogPoliciesConfigV0 {
	// Only a nil config is defaulted; everything else passes through as-is.
	if b != nil {
		return b
	}

	// Signal is a pointer field, so copy the signal names into locals whose
	// addresses can be taken.
	oomSignal, eccSignal := CUDAOOMSignal, ECCErrorSignal

	return LogPoliciesConfigV0{
		{
			RawPattern: CUDAOOMPattern,
			RawActions: []LogActionV0{{Type: LogActionTypeSignal, Signal: &oomSignal}},
		},
		{
			RawPattern: ECCErrorPattern,
			RawActions: []LogActionV0{{Type: LogActionTypeSignal, Signal: &eccSignal}},
		},
	}
}

// Merge implements the Mergable pseudo-interface.
// We append all LogPolicyV0s to the output slice, but if any of them share the same pattern, we merge
// their actions and save them as one LogPolicyV0.
func (b LogPoliciesConfigV0) Merge(
other LogPoliciesConfigV0,
src LogPoliciesConfigV0,
) LogPoliciesConfigV0 {
var out LogPoliciesConfigV0

patternTosrcLp := make(map[string]LogPolicyV0)
for _, lp := range b {
for _, lp := range src {
patternTosrcLp[lp.RawPattern] = lp
}

for _, otherLp := range other {
for _, otherLp := range b {
pattern := otherLp.RawPattern
if srcLp, ok := patternTosrcLp[pattern]; ok {
// Merge actions of two LogPolicies if they have the same pattern.
patternTosrcLp[pattern] = LogPolicyV0{
RawPattern: pattern,
RawActions: srcLp.RawActions.merge(otherLp.RawActions),
RawActions: otherLp.RawActions.merge(srcLp.RawActions),
}
} else {
// Source LogPoliciesConfig doesn't have this pattern.
Expand Down Expand Up @@ -101,13 +79,8 @@ func (b *LogPoliciesConfigV0) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &jsonItems); err != nil {
return errors.Wrapf(err, "failed to parse runtime items")
}
// By distinguishing [] from nil input, a user can override the default log policies and get empty
// log policies. If a user provides [], the default values won't be applied.
if jsonItems == nil {
return nil
} else if len(jsonItems) == 0 {
*b = make([]LogPolicyV0, 0)
return nil
}

// Merge LogPolicyV0s with the same pattern into one.
Expand All @@ -123,7 +96,8 @@ func (b *LogPoliciesConfigV0) UnmarshalJSON(data []byte) error {
patternToLp[pattern] = LogPolicyV0{RawPattern: pattern, RawActions: mergedActions}
}

var temp LogPoliciesConfigV0
// If the input data is [] and we used `var temp LogPolicies`, the function would return nil.
temp := make(LogPoliciesConfigV0, 0)
for _, lp := range patternToLp {
temp = append(temp, lp)
}
Expand Down Expand Up @@ -234,11 +208,11 @@ func (b *LogPolicyV0) UnmarshalJSON(data []byte) error {

// Merge LogActionsV0. The value of LogActionTypeSignal from other takes precedence.
// Union merge the other LogAction types.
func (s LogActionsV0) merge(other LogActionsV0) LogActionsV0 {
func (l LogActionsV0) merge(src LogActionsV0) LogActionsV0 {
// Store unique actions except signal, and find source signal.
actions := set.New[LogActionV0]()
var srcSignal *LogActionV0
for _, a := range s {
for _, a := range src {
if a.Type == LogActionTypeSignal {
srcSignal = &a
continue
Expand All @@ -248,7 +222,7 @@ func (s LogActionsV0) merge(other LogActionsV0) LogActionsV0 {

// Store unique actions except signal, and find other signal.
var otherSignal *LogActionV0
for _, a := range other {
for _, a := range l {
if a.Type == LogActionTypeSignal {
otherSignal = &a
continue
Expand All @@ -266,9 +240,9 @@ func (s LogActionsV0) merge(other LogActionsV0) LogActionsV0 {
}

// Sort LogActionsV0 by type so the output order is deterministic, which makes testing easier.
func (s LogActionsV0) sort() {
sort.Slice(s, func(i, j int) bool {
return s[i].Type < s[j].Type
func (l LogActionsV0) sort() {
sort.Slice(l, func(i, j int) bool {
return l[i].Type < l[j].Type
})
}

Expand Down Expand Up @@ -308,28 +282,25 @@ func (s LogActionV0) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements the json.Unmarshaler interface.
func (s *LogActionV0) UnmarshalJSON(data []byte) error {
var action string
if err := json.Unmarshal(data, &action); err != nil {
return fmt.Errorf(
"failed to unmarshal log action type CancelRetries and ExcludeNode: %w, data: %q", err, string(data),
)
}

// Handle all the types besides signal
switch LogActionType(action) {
case LogActionTypeCancelRetries:
*s = LogActionV0{Type: LogActionTypeCancelRetries}
return nil
case LogActionTypeExcludeNode:
*s = LogActionV0{Type: LogActionTypeExcludeNode}
return nil
// A non-nil err means the input is not a JSON string, so it cannot be cancel_retries or exclude_node.
if err := json.Unmarshal(data, &action); err == nil {
// Handle all the types besides signal
switch LogActionType(action) {
case LogActionTypeCancelRetries:
*s = LogActionV0{Type: LogActionTypeCancelRetries}
return nil
case LogActionTypeExcludeNode:
*s = LogActionV0{Type: LogActionTypeExcludeNode}
return nil
}
}

// Handle Signal
temp := struct {
Signal *string `json:"signal"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return fmt.Errorf("failed to unmarshal log action type signal: %w, data: %q", err, string(data))
if err := json.Unmarshal(data, &temp); err != nil || temp.Signal == nil {
return fmt.Errorf("failed to unmarshal log action: %w, data: %q", err, string(data))
}
*s = LogActionV0{Type: LogActionTypeSignal, Signal: temp.Signal}

Expand Down
Loading

0 comments on commit 8095267

Please sign in to comment.