diff --git a/docs/release-notes/9966-fix-grid.rst b/docs/release-notes/9966-fix-grid.rst new file mode 100644 index 00000000000..f36dc1b8dc6 --- /dev/null +++ b/docs/release-notes/9966-fix-grid.rst @@ -0,0 +1,7 @@ +:orphan: + +**Fixes** + +- Previously, during a grid search, if a hyperparameter contained an empty nested hyperparameter + (that is, just an empty map), that hyperparameter would not appear in the hparams passed to the + trial. diff --git a/master/pkg/searcher/grid.go b/master/pkg/searcher/grid.go index 2e6e512648f..a83c9cc6e03 100644 --- a/master/pkg/searcher/grid.go +++ b/master/pkg/searcher/grid.go @@ -260,8 +260,13 @@ func getGridAxes(route []string, h expconf.Hyperparameter) []gridAxis { return []gridAxis{axis} case h.RawNestedHyperparameter != nil: axes := []gridAxis{} - // Use h.Each for deterministic ordering. nested := expconf.Hyperparameters(*h.RawNestedHyperparameter) + // Make sure empty maps don't disappear after sampling. + if len(nested) == 0 { + axes = append(axes, gridAxis{axisValue{route, map[string]interface{}{}}}) + return axes + } + // Use h.Each for deterministic ordering. nested.Each(func(name string, subparam expconf.HyperparameterV0) { // make a completely clean copy of route var subroute []string diff --git a/master/pkg/searcher/grid_test.go b/master/pkg/searcher/grid_test.go index b81ca83e371..cffb52f314d 100644 --- a/master/pkg/searcher/grid_test.go +++ b/master/pkg/searcher/grid_test.go @@ -74,6 +74,20 @@ func TestHyperparameterGridMethod(t *testing.T) { len(getGridAxes([]string{"x"}, expconf.Hyperparameter{RawConstHyperparameter: &constParam})[0]), 1, ) + // Regression test: make sure empty nested hyperparameters don't disappear during sampling. + nestedParam := map[string]expconf.Hyperparameter{ + "empty": {RawNestedHyperparameter: &map[string]expconf.Hyperparameter{}}, + "full": {RawCategoricalHyperparameter: &catParam}, + } + result := getGridAxes([]string{"x"}, expconf.Hyperparameter{RawNestedHyperparameter: &nestedParam}) + assert.DeepEqual(t, result, []gridAxis{ + []axisValue{{Route: []string{"x", "empty"}, Value: map[string]interface{}{}}}, + []axisValue{ + {Route: []string{"x", "full"}, Value: 1}, + {Route: []string{"x", "full"}, Value: 2}, + {Route: []string{"x", "full"}, Value: 3}, + }, + }) } func TestGrid(t *testing.T) {