From ec80672c5c84ca21196843895e64a7c27959ef22 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Fri, 1 Nov 2024 13:55:42 -0400 Subject: [PATCH] fix tests --- master/internal/api_runs_intg_test.go | 38 +++++++++++++++++---------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/master/internal/api_runs_intg_test.go b/master/internal/api_runs_intg_test.go index 5099847db7e..35b7f82f013 100644 --- a/master/internal/api_runs_intg_test.go +++ b/master/internal/api_runs_intg_test.go @@ -1790,7 +1790,12 @@ func TestGetRunGroups(t *testing.T) { Sort: ptrs.Ptr("state=asc"), } - hyperparameters := map[string]any{"global_batch_size": 1, "test1": map[string]any{"test2": 1}, "stringH": "abc", "boolH": false} + hyperparameters := map[string]any{ + "global_batch_size": 1, + "test1": map[string]any{"test2": 1}, + "stringH": "abc", + "boolH": false, + } exp := createTestExpWithProjectID(t, api, curUser, projectIDInt) @@ -1841,7 +1846,12 @@ func TestGetRunGroups(t *testing.T) { require.Len(t, resp.Groups, 1) // Add new task with different hyperperameter values - newHparams := map[string]any{"global_batch_size": 9, "test1": map[string]any{"test2": 8}, "stringH": "def", "boolH": true} + newHparams := map[string]any{ + "global_batch_size": 9, + "test1": map[string]any{"test2": 8}, + "stringH": "def", + "boolH": true, + } exp3 := createTestExpWithProjectID(t, api, curUser, projectIDInt) @@ -1863,16 +1873,16 @@ func TestGetRunGroups(t *testing.T) { resp, err = api.GetRunGroups(ctx, req) require.NoError(t, err) require.Len(t, resp.Groups, 2) - require.Equal(t, resp.Groups[0].GroupName, string(model.PausedState)) - require.Equal(t, resp.Groups[1].GroupName, string(model.CanceledState)) - require.Equal(t, resp.Groups[0].Hyperparameters.Fields["global_batch_size"]. - GetStructValue().Fields["number_val"].GetNumberValue(), float64(5)) - require.Equal(t, resp.Groups[0].Hyperparameters.Fields["test1.test2"]. - GetStructValue().Fields["number_val"].GetNumberValue(), 4.5) - require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"]. - GetStructValue().Fields["text_val"].GetListValue().Values[0].GetStringValue(), "abc") - require.Equal(t, resp.Groups[0].Hyperparameters.Fields["stringH"]. - GetStructValue().Fields["text_val"].GetListValue().Values[1].GetStringValue(), "def") - require.Equal(t, resp.Groups[0].Hyperparameters.Fields["boolH"]. - GetStructValue().Fields["bool_val"].GetBoolValue(), true) + require.Equal(t, string(model.PausedState), resp.Groups[0].GroupName) + require.Equal(t, string(model.CanceledState), resp.Groups[1].GroupName) + require.InEpsilon(t, float64(5), resp.Groups[0].Hyperparameters.Fields["global_batch_size"]. + GetStructValue().Fields["number_val"].GetNumberValue(), 0.00001) + require.InEpsilon(t, 4.5, resp.Groups[0].Hyperparameters.Fields["test1.test2"]. + GetStructValue().Fields["number_val"].GetNumberValue(), 0.00001) + require.Equal(t, "abc", resp.Groups[0].Hyperparameters.Fields["stringH"]. + GetStructValue().Fields["text_val"].GetListValue().Values[0].GetStringValue()) + require.Equal(t, "def", resp.Groups[0].Hyperparameters.Fields["stringH"]. + GetStructValue().Fields["text_val"].GetListValue().Values[1].GetStringValue()) + require.True(t, resp.Groups[0].Hyperparameters.Fields["boolH"]. + GetStructValue().Fields["bool_val"].GetBoolValue()) }