diff --git a/NEWS.md b/NEWS.md index 2cc4948b..e9d5eb2e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,7 +4,7 @@ * The package will now log a backtrace for errors and warnings that occur during tuning. When a tuning process encounters issues, see the new `trace` column in the `collect_notes(.Last.tune.result)` output to find precisely where the error occurred (#873). -* When automatic grids are used, `dials::grid_space_filling()` is now used (instead of `dials::grid_latin_hypercube()`). Overall, the new function produces optimized designs (not depending on random numbers). When using Bayesian models, we will use a Latin Hypercube since we produce 5,000 candidates, which is too slow to do with pre-optimized designs. +* When automatic grids are used, `dials::grid_space_filling()` is now used (instead of `dials::grid_latin_hypercube()`). There is an exception: when `grid = 1`, `dials::grid_random()` is used to avoid warnings. Overall, the new function produces optimized designs (not depending on random numbers). When using Bayesian models, we will use a Latin Hypercube since we produce 5,000 candidates, which is too slow to do with pre-optimized designs. (#962) # tune 1.2.1 diff --git a/R/tune_bayes.R b/R/tune_bayes.R index 8a580114..257f493d 100644 --- a/R/tune_bayes.R +++ b/R/tune_bayes.R @@ -531,7 +531,12 @@ create_initial_set <- function(param, n = NULL, checks) { if (any(checks == "bayes")) { check_bayes_initial_size(nrow(param), n) } - dials::grid_space_filling(param, size = n) + if (n == 1) { + res <- dials::grid_random(param, size = n) + } else { + res <- dials::grid_space_filling(param, size = n) + } + res } check_iter <- function(iter, call) { diff --git a/tests/testthat/test-bayes.R b/tests/testthat/test-bayes.R index 59bd5ff3..f62169d6 100644 --- a/tests/testthat/test-bayes.R +++ b/tests/testthat/test-bayes.R @@ -611,3 +611,19 @@ test_that("tune_bayes() output for `iter` edge cases (#721)", { tune_bayes(wf, boots, iter = NULL) ) }) + +test_that("1-point grid (#962)", { + skip_if_not_installed("dials", minimum_version = "1.3.0") + + expect_silent({ + set.seed(1) + grid <- tune:::create_initial_set( + dials::parameters(dials::penalty(), dials::deg_free()), + n = 1, + checks = "none" + ) + }) + expect_equal(nrow(grid), 1L) +}) + +