diff --git a/master/internal/api_runs_intg_test.go b/master/internal/api_runs_intg_test.go index 5e680cb8bdc..5099847db7e 100644 --- a/master/internal/api_runs_intg_test.go +++ b/master/internal/api_runs_intg_test.go @@ -1790,7 +1790,7 @@ func TestGetRunGroups(t *testing.T) { Sort: ptrs.Ptr("state=asc"), } - hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}} + hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}, "stringH": "abc", "boolH": false} exp := createTestExpWithProjectID(t, api, curUser, projectIDInt) @@ -1839,4 +1839,40 @@ func TestGetRunGroups(t *testing.T) { resp, err = api.GetRunGroups(ctx, req) require.NoError(t, err) require.Len(t, resp.Groups, 1) + + // Add new task with different hyperperameter values + newHparams := map[string]any{"global_batch_size": 9, "test1": map[string]any{"test2": 8}, "stringH": "def", "boolH": true} + + exp3 := createTestExpWithProjectID(t, api, curUser, projectIDInt) + + task = &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()} + require.NoError(t, db.AddTask(ctx, task)) + require.NoError(t, db.AddTrial(ctx, &model.Trial{ + State: model.PausedState, + ExperimentID: exp3.ID, + StartTime: time.Now(), + HParams: newHparams, + }, task.TaskID)) + + req = &apiv1.GetRunGroupsRequest{ + ProjectId: &projectID, + Group: "state", + Sort: ptrs.Ptr("state=asc"), + } + + resp, err = api.GetRunGroups(ctx, req) + require.NoError(t, err) + require.Len(t, resp.Groups, 2) + require.Equal(t, resp.Groups[0].GroupName, string(model.PausedState)) + require.Equal(t, resp.Groups[1].GroupName, string(model.CanceledState)) + require.Equal(t, resp.Groups[0].Hyperparameters.Fields["global_batch_size"]. + GetStructValue().Fields["number_val"].GetNumberValue(), float64(5)) + require.Equal(t, resp.Groups[0].Hyperparameters.Fields["test1.test2"]. + GetStructValue().Fields["number_val"].GetNumberValue(), 4.5) + require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"]. + GetStructValue().Fields["text_val"].GetListValue().Values[0].GetStringValue(), "abc") + require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"]. + GetStructValue().Fields["text_val"].GetListValue().Values[1].GetStringValue(), "def") + require.Equal(t, resp.Groups[0].Hyperparameters.Fields["boolH"]. + GetStructValue().Fields["bool_val"].GetBoolValue(), true) }