Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanuelAaron committed Nov 1, 2024
1 parent c142e80 commit be4575a
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion master/internal/api_runs_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1790,7 +1790,7 @@ func TestGetRunGroups(t *testing.T) {
Sort: ptrs.Ptr("state=asc"),
}

hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}}
hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}, "stringH": "abc", "boolH": false}

exp := createTestExpWithProjectID(t, api, curUser, projectIDInt)

Expand Down Expand Up @@ -1839,4 +1839,40 @@ func TestGetRunGroups(t *testing.T) {
resp, err = api.GetRunGroups(ctx, req)
require.NoError(t, err)
require.Len(t, resp.Groups, 1)

// Add new task with different hyperperameter values
newHparams := map[string]any{"global_batch_size": 9, "test1": map[string]any{"test2": 8}, "stringH": "def", "boolH": true}

exp3 := createTestExpWithProjectID(t, api, curUser, projectIDInt)

task = &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, db.AddTask(ctx, task))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.PausedState,
ExperimentID: exp3.ID,
StartTime: time.Now(),
HParams: newHparams,
}, task.TaskID))

req = &apiv1.GetRunGroupsRequest{
ProjectId: &projectID,
Group: "state",
Sort: ptrs.Ptr("state=asc"),
}

resp, err = api.GetRunGroups(ctx, req)
require.NoError(t, err)
require.Len(t, resp.Groups, 2)
require.Equal(t, resp.Groups[0].GroupName, string(model.PausedState))
require.Equal(t, resp.Groups[1].GroupName, string(model.CanceledState))
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["global_batch_size"].
GetStructValue().Fields["number_val"].GetNumberValue(), float64(5))
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["test1.test2"].
GetStructValue().Fields["number_val"].GetNumberValue(), 4.5)
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"].
GetStructValue().Fields["text_val"].GetListValue().Values[0].GetStringValue(), "abc")
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"].
GetStructValue().Fields["text_val"].GetListValue().Values[1].GetStringValue(), "def")
require.Equal(t, resp.Groups[0].Hyperparameters.Fields["boolH"].
GetStructValue().Fields["bool_val"].GetBoolValue(), true)
}

0 comments on commit be4575a

Please sign in to comment.