Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanuelAaron committed Nov 1, 2024
1 parent be4575a commit ec80672
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions master/internal/api_runs_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1790,7 +1790,12 @@ func TestGetRunGroups(t *testing.T) {
Sort: ptrs.Ptr("state=asc"),
}

hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}, "stringH": "abc", "boolH": false}
hyperparameters := map[string]any{
"global_batch_size": 1,
"test1": map[string]any{"test2": 1},
"stringH": "abc",
"boolH": false,
}

exp := createTestExpWithProjectID(t, api, curUser, projectIDInt)

Expand Down Expand Up @@ -1841,7 +1846,12 @@ func TestGetRunGroups(t *testing.T) {
require.Len(t, resp.Groups, 1)

// Add new task with different hyperperameter values
newHparams := map[string]any{"global_batch_size": 9, "test1": map[string]any{"test2": 8}, "stringH": "def", "boolH": true}
newHparams := map[string]any{
"global_batch_size": 9,
"test1": map[string]any{"test2": 8},
"stringH": "def",
"boolH": true,
}

exp3 := createTestExpWithProjectID(t, api, curUser, projectIDInt)

Expand All @@ -1863,16 +1873,16 @@ func TestGetRunGroups(t *testing.T) {
resp, err = api.GetRunGroups(ctx, req)
require.NoError(t, err)
require.Len(t, resp.Groups, 2)
require.Equal(t, resp.Groups[0].GroupName, string(model.PausedState))
require.Equal(t, resp.Groups[1].GroupName, string(model.CanceledState))
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["global_batch_size"].
GetStructValue().Fields["number_val"].GetNumberValue(), float64(5))
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["test1.test2"].
GetStructValue().Fields["number_val"].GetNumberValue(), 4.5)
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"].
GetStructValue().Fields["text_val"].GetListValue().Values[0].GetStringValue(), "abc")
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"].
GetStructValue().Fields["text_val"].GetListValue().Values[1].GetStringValue(), "def")
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["boolH"].
GetStructValue().Fields["bool_val"].GetBoolValue(), true)
require.Equal(t, string(model.PausedState), resp.Groups[0].GroupName)
require.Equal(t, string(model.CanceledState), resp.Groups[1].GroupName)
require.InEpsilon(t, float64(5), resp.Groups[0].Hyperparameters.Fields["global_batch_size"].
GetStructValue().Fields["number_val"].GetNumberValue(), 0.00001)
require.InEpsilon(t, 4.5, resp.Groups[0].Hyperparameters.Fields["test1.test2"].
GetStructValue().Fields["number_val"].GetNumberValue(), 0.00001)
require.Equal(t, "abc", resp.Groups[0].Hyperparameters.Fields["stringH"].
GetStructValue().Fields["text_val"].GetListValue().Values[0].GetStringValue())
require.Equal(t, "def", resp.Groups[0].Hyperparameters.Fields["stringH"].
GetStructValue().Fields["text_val"].GetListValue().Values[1].GetStringValue())
require.True(t, resp.Groups[0].Hyperparameters.Fields["boolH"].
GetStructValue().Fields["bool_val"].GetBoolValue())
}

0 comments on commit ec80672

Please sign in to comment.