diff --git a/DESCRIPTION b/DESCRIPTION index c580b872..0dab118f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.1.2 +Version: 0.1.3 Authors@R: c( person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -49,6 +49,7 @@ Imports: workflows (>= 1.0.0) Suggests: data.table, + epidatasets, epidatr (>= 1.0.0), fs, grf, diff --git a/NEWS.md b/NEWS.md index 3e4e964b..e080d1aa 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat ## bugfixes - shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag` +- Quantiles produced by `grf` were sometimes out of order. # epipredict 0.1 diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 4755083c..7ec16882 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -48,7 +48,7 @@ #' @export #' @examples #' library(dplyr) -#' jhu <- cases_deaths_subset %>% +#' jhu <- epidatasets::cases_deaths_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% #' select(geo_value, time_value, cases) #' diff --git a/R/make_grf_quantiles.R b/R/make_grf_quantiles.R index 253ea1ac..2903c93a 100644 --- a/R/make_grf_quantiles.R +++ b/R/make_grf_quantiles.R @@ -165,7 +165,7 @@ make_grf_quantiles <- function() { # turn the predictions into a tibble with a dist_quantiles column process_qrf_preds <- function(x, object) { - quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig + quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig %>% sort() x <- x$predictions out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) out <- dist_quantiles(out, list(quantile_levels)) diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 3d3e6529..7bdb0092 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -45,7 +45,7 @@ #' @export #' @examples #' library(dplyr) -#' jhu <- cases_deaths_subset %>% +#' jhu <- epidatasets::cases_deaths_subset %>% #' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% #' select(geo_value, time_value, cases) #' diff --git a/man/layer_population_scaling.Rd b/man/layer_population_scaling.Rd index 25311669..3a22dbee 100644 --- a/man/layer_population_scaling.Rd +++ b/man/layer_population_scaling.Rd @@ -75,7 +75,7 @@ argument is a common \emph{divisor} of the selected variables. } \examples{ library(dplyr) -jhu <- cases_deaths_subset \%>\% +jhu <- epidatasets::cases_deaths_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% select(geo_value, time_value, cases) diff --git a/man/step_population_scaling.Rd b/man/step_population_scaling.Rd index 004c2c82..6a0c5dea 100644 --- a/man/step_population_scaling.Rd +++ b/man/step_population_scaling.Rd @@ -88,7 +88,7 @@ argument is a common \emph{multiplier} of the selected variables. } \examples{ library(dplyr) -jhu <- cases_deaths_subset \%>\% +jhu <- epidatasets::cases_deaths_subset \%>\% filter(time_value > "2021-11-01", geo_value \%in\% c("ca", "ny")) \%>\% select(geo_value, time_value, cases) diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index ae045460..f3e7e573 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -154,7 +154,7 @@ 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, - 0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343, + 0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343000000001, 0.0847941, 0.0966258, 0.103199, 0.1097722, 0.1216039, 0.1314637, 0.1419808, 0.15414125, 0.17090286, 0.189932235, 0.22848398, 0.30542311, 0.40216399, 0.512353658), quantile_levels = c(0.01, 0.025, 0.05, @@ -267,7 +267,7 @@ 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.114729892920429, 0.227785958288583, 0.282278878729037, 0.320407599201492, 0.350577823459785, - 0.37665230304923, 0.39981364198757, 0.4218461, 0.444009706175862, + 0.376652303049231, 0.39981364198757, 0.4218461, 0.444009706175862, 0.466962725214852, 0.493098379685547, 0.523708407392674, 0.562100740111401, 0.619050517814778, 0.754868363055733, 1.1177263295869, 1.76277018354499, 2.37278671910076, 2.9651652434047), quantile_levels = c(0.01, @@ -314,7 +314,7 @@ 0.144337973117581, 0.250292371898569, 0.367310419323293, 0.44444044802193, 0.506592035751958, 0.558428768125431, 0.602035095628756, 0.64112383905529, 0.674354964141041, 0.703707875219752, 0.7319844, 0.760702196782168, - 0.78975826405844, 0.823427572594726, 0.860294897090771, 0.904032120658957, + 0.789758264058441, 0.823427572594726, 0.860294897090771, 0.904032120658957, 0.955736581115011, 1.0165945004053, 1.09529786576616, 1.21614421175967, 1.32331604019295, 1.45293812780298), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, @@ -351,7 +351,7 @@ 0.0497573816250162, 0.081255049503995, 0.108502307388674, 0.132961558931189, 0.156011650575706, 0.177125892134071, 0.1975426, 0.217737120618906, 0.239458499211792, 0.263562581820818, 0.289525383565136, 0.31824420000725, - 0.35141305194052, 0.393862560773808, 0.453538799225292, 0.558631806850418, + 0.351413051940519, 0.393862560773808, 0.453538799225292, 0.558631806850418, 0.657452391363313, 0.767918764883928), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99 @@ -412,20 +412,19 @@ 0.0766736159703596, 0.0942284381264812, 0.11050757203172, 0.125214601455714, 0.1393442, 0.15359732398729, 0.168500447692877, 0.184551468093631, 0.202926420944109, 0.22476606802393, 0.253070223293233, - 0.29122995395109, 0.341963643747938, 0.419747975311502, 0.495994046054689, - 0.5748791770223), quantile_levels = c(0.01, 0.025, 0.05, - 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, - 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", - "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0, 0, 0, 0, 0, 0.00603076915889168, 0.0356039073625737, - 0.0609470811194113, 0.0833232869645198, 0.103265350891109, - 0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668, - 0.196866917086671, 0.219796529731897, 0.247137200365254, - 0.280371254591746, 0.320842872758278, 0.374783454750148, - 0.461368597638526, 0.539683256474915, 0.632562403391324), - quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, - 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, - 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", + 0.291229953951089, 0.341963643747938, 0.419747975311502, + 0.495994046054689, 0.5748791770223), quantile_levels = c(0.01, + 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, + 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, + 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", + "vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.00603076915889168, + 0.0356039073625737, 0.0609470811194113, 0.0833232869645198, 0.103265350891109, + 0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668, + 0.196866917086671, 0.219796529731897, 0.247137200365254, 0.280371254591746, + 0.320842872758278, 0.374783454750148, 0.461368597638526, 0.539683256474915, + 0.632562403391324), quantile_levels = c(0.01, 0.025, 0.05, 0.1, + 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, + 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( values = c(0, 0, 0, 0, 0, 0, 0.018869505399304, 0.0471517885822858, 0.0732707765908659, 0.0969223475714758, 0.118188509171441, @@ -650,7 +649,7 @@ 0.0562218087603375, 0.0890356919950198, 0.118731362266373, 0.146216910144001, 0.172533896645116, 0.1975426, 0.223021121504065, 0.249412654553045, 0.277680444480195, 0.308522683806638, 0.342270845449704, 0.382702709814398, - 0.433443929063141, 0.501610622734127, 0.61417580106326, 0.715138862353848, + 0.433443929063141, 0.501610622734127, 0.614175801063261, 0.715138862353848, 0.833535553075286), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", @@ -826,8 +825,8 @@ 0.147940700281253, 0.185518687303273, 0.220197034594646, 0.2521005, 0.282477641919719, 0.3121244, 0.3414694, 0.371435390499905, 0.402230766363414, 0.436173824348844, 0.474579164424894, - 0.519690345185252, 0.57667375206677, 0.655151246845668, 0.78520792902029, - 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01, + 0.519690345185252, 0.576673752066771, 0.655151246845668, + 0.78520792902029, 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", @@ -1008,14 +1007,14 @@ --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645, - 0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list( - structure(list(values = c(0.0961118191398634, 0.202494988128882 + ), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645, + 0.470345577837143, 0.725986105412007, 0.212686665274007), .pred_distn = structure(list( + structure(list(values = c(0.0961118191398633, 0.202494988128882 ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.0865730800114383, 0.192956249000457), quantile_levels = c(0.05, + values = c(0.0865730800114382, 0.192956249000457), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", - "vctrs_vctr")), structure(list(values = c(0.279994736572136, + "vctrs_vctr")), structure(list(values = c(0.279994736572135, 0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05, @@ -1034,7 +1033,7 @@ --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598, + ), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979, 0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list( structure(list(values = c(0.136509784083987, 0.469979623951498 ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", @@ -1049,7 +1048,7 @@ "vctrs_vctr")), structure(list(values = c(0.628067077067884, 0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05, + values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, @@ -1060,7 +1059,7 @@ --- structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx" - ), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598, + ), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979, 0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list( structure(list(values = c(0.136509784083987, 0.469979623951498 ), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", @@ -1075,7 +1074,7 @@ "vctrs_vctr")), structure(list(values = c(0.628067077067884, 0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list( - values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05, + values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd", "vctrs_vctr"))), class = c("distribution", "vctrs_vctr", "list")), forecast_date = structure(c(18997, 18997, 18997, 18997, diff --git a/tests/testthat/test-arx_forecaster.R b/tests/testthat/test-arx_forecaster.R index 61016b63..0f2b9bd1 100644 --- a/tests/testthat/test-arx_forecaster.R +++ b/tests/testthat/test-arx_forecaster.R @@ -1,4 +1,4 @@ -train_data <- cases_deaths_subset +train_data <- epidatasets::cases_deaths_subset test_that("arx_forecaster warns if forecast date beyond the implicit one", { bad_date <- max(train_data$time_value) + 300 expect_warning( diff --git a/tests/testthat/test-check-training-set.R b/tests/testthat/test-check-training-set.R index ad891dd4..071a5ccf 100644 --- a/tests/testthat/test-check-training-set.R +++ b/tests/testthat/test-check-training-set.R @@ -1,5 +1,5 @@ test_that("training set validation works", { - template <- cases_deaths_subset[1, ] + template <- epidatasets::cases_deaths_subset[1, ] rec <- list(template = template) t1 <- template diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index e6971bc2..a1ccba4a 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -90,7 +90,7 @@ test_that("Number of columns and column names returned correctly, Upper and lowe ## Postprocessing test_that("Postprocessing workflow works and values correct", { - jhu <- cases_deaths_subset %>% + jhu <- epidatasets::cases_deaths_subset %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, cases) diff --git a/tests/testthat/test-snapshots.R b/tests/testthat/test-snapshots.R index 6ded0d89..9766618b 100644 --- a/tests/testthat/test-snapshots.R +++ b/tests/testthat/test-snapshots.R @@ -1,4 +1,4 @@ -train_data <- cases_deaths_subset +train_data <- epidatasets::cases_deaths_subset expect_snapshot_tibble <- function(x) { expect_snapshot_value(x, style = "deparse", cran = FALSE) } diff --git a/tests/testthat/test-target_date_bug.R b/tests/testthat/test-target_date_bug.R index 02a82526..8be26641 100644 --- a/tests/testthat/test-target_date_bug.R +++ b/tests/testthat/test-target_date_bug.R @@ -2,7 +2,7 @@ # https://github.com/cmu-delphi/epipredict/issues/290 library(dplyr) -train <- cases_deaths_subset |> +train <- epidatasets::cases_deaths_subset |> filter(time_value >= as.Date("2021-10-01")) |> select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av) ngeos <- n_distinct(train$geo_value)