Skip to content

Commit

Permalink
Merge branch 'dev' into grf-arx-hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
dajmcdon authored Oct 26, 2024
2 parents b929e47 + 2126db5 commit 202202c
Show file tree
Hide file tree
Showing 13 changed files with 42 additions and 41 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.2
Version: 0.1.3
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down Expand Up @@ -49,6 +49,7 @@ Imports:
workflows (>= 1.0.0)
Suggests:
data.table,
epidatasets,
epidatr (>= 1.0.0),
fs,
grf,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat

## bugfixes
- shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag`
- Quantiles produced by `grf` were sometimes out of order.

# epipredict 0.1

Expand Down
2 changes: 1 addition & 1 deletion R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#' @export
#' @examples
#' library(dplyr)
#' jhu <- cases_deaths_subset %>%
#' jhu <- epidatasets::cases_deaths_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#' select(geo_value, time_value, cases)
#'
Expand Down
2 changes: 1 addition & 1 deletion R/make_grf_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ make_grf_quantiles <- function() {

# turn the predictions into a tibble with a dist_quantiles column
process_qrf_preds <- function(x, object) {
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig %>% sort()
x <- x$predictions
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
out <- dist_quantiles(out, list(quantile_levels))
Expand Down
2 changes: 1 addition & 1 deletion R/step_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
#' @export
#' @examples
#' library(dplyr)
#' jhu <- cases_deaths_subset %>%
#' jhu <- epidatasets::cases_deaths_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#' select(geo_value, time_value, cases)
#'
Expand Down
2 changes: 1 addition & 1 deletion man/layer_population_scaling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/step_population_scaling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 29 additions & 30 deletions tests/testthat/_snaps/snapshots.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85,
0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default",
"vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0,
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343,
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343000000001,
0.0847941, 0.0966258, 0.103199, 0.1097722, 0.1216039, 0.1314637,
0.1419808, 0.15414125, 0.17090286, 0.189932235, 0.22848398, 0.30542311,
0.40216399, 0.512353658), quantile_levels = c(0.01, 0.025, 0.05,
Expand Down Expand Up @@ -267,7 +267,7 @@
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.114729892920429,
0.227785958288583, 0.282278878729037, 0.320407599201492, 0.350577823459785,
0.37665230304923, 0.39981364198757, 0.4218461, 0.444009706175862,
0.376652303049231, 0.39981364198757, 0.4218461, 0.444009706175862,
0.466962725214852, 0.493098379685547, 0.523708407392674, 0.562100740111401,
0.619050517814778, 0.754868363055733, 1.1177263295869, 1.76277018354499,
2.37278671910076, 2.9651652434047), quantile_levels = c(0.01,
Expand Down Expand Up @@ -314,7 +314,7 @@
0.144337973117581, 0.250292371898569, 0.367310419323293, 0.44444044802193,
0.506592035751958, 0.558428768125431, 0.602035095628756, 0.64112383905529,
0.674354964141041, 0.703707875219752, 0.7319844, 0.760702196782168,
0.78975826405844, 0.823427572594726, 0.860294897090771, 0.904032120658957,
0.789758264058441, 0.823427572594726, 0.860294897090771, 0.904032120658957,
0.955736581115011, 1.0165945004053, 1.09529786576616, 1.21614421175967,
1.32331604019295, 1.45293812780298), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
Expand Down Expand Up @@ -351,7 +351,7 @@
0.0497573816250162, 0.081255049503995, 0.108502307388674, 0.132961558931189,
0.156011650575706, 0.177125892134071, 0.1975426, 0.217737120618906,
0.239458499211792, 0.263562581820818, 0.289525383565136, 0.31824420000725,
0.35141305194052, 0.393862560773808, 0.453538799225292, 0.558631806850418,
0.351413051940519, 0.393862560773808, 0.453538799225292, 0.558631806850418,
0.657452391363313, 0.767918764883928), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99
Expand Down Expand Up @@ -412,20 +412,19 @@
0.0766736159703596, 0.0942284381264812, 0.11050757203172,
0.125214601455714, 0.1393442, 0.15359732398729, 0.168500447692877,
0.184551468093631, 0.202926420944109, 0.22476606802393, 0.253070223293233,
0.29122995395109, 0.341963643747938, 0.419747975311502, 0.495994046054689,
0.5748791770223), quantile_levels = c(0.01, 0.025, 0.05,
0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0, 0, 0, 0, 0, 0.00603076915889168, 0.0356039073625737,
0.0609470811194113, 0.0833232869645198, 0.103265350891109,
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
0.196866917086671, 0.219796529731897, 0.247137200365254,
0.280371254591746, 0.320842872758278, 0.374783454750148,
0.461368597638526, 0.539683256474915, 0.632562403391324),
quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25,
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8,
0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
0.291229953951089, 0.341963643747938, 0.419747975311502,
0.495994046054689, 0.5748791770223), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.00603076915889168,
0.0356039073625737, 0.0609470811194113, 0.0833232869645198, 0.103265350891109,
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
0.196866917086671, 0.219796529731897, 0.247137200365254, 0.280371254591746,
0.320842872758278, 0.374783454750148, 0.461368597638526, 0.539683256474915,
0.632562403391324), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0, 0, 0, 0, 0, 0, 0.018869505399304, 0.0471517885822858,
0.0732707765908659, 0.0969223475714758, 0.118188509171441,
Expand Down Expand Up @@ -650,7 +649,7 @@
0.0562218087603375, 0.0890356919950198, 0.118731362266373, 0.146216910144001,
0.172533896645116, 0.1975426, 0.223021121504065, 0.249412654553045,
0.277680444480195, 0.308522683806638, 0.342270845449704, 0.382702709814398,
0.433443929063141, 0.501610622734127, 0.61417580106326, 0.715138862353848,
0.433443929063141, 0.501610622734127, 0.614175801063261, 0.715138862353848,
0.833535553075286), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
Expand Down Expand Up @@ -826,8 +825,8 @@
0.147940700281253, 0.185518687303273, 0.220197034594646,
0.2521005, 0.282477641919719, 0.3121244, 0.3414694, 0.371435390499905,
0.402230766363414, 0.436173824348844, 0.474579164424894,
0.519690345185252, 0.57667375206677, 0.655151246845668, 0.78520792902029,
0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
0.519690345185252, 0.576673752066771, 0.655151246845668,
0.78520792902029, 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
Expand Down Expand Up @@ -1008,14 +1007,14 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645,
0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list(
structure(list(values = c(0.0961118191398634, 0.202494988128882
), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645,
0.470345577837143, 0.725986105412007, 0.212686665274007), .pred_distn = structure(list(
structure(list(values = c(0.0961118191398633, 0.202494988128882
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.0865730800114383, 0.192956249000457), quantile_levels = c(0.05,
values = c(0.0865730800114382, 0.192956249000457), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.279994736572136,
"vctrs_vctr")), structure(list(values = c(0.279994736572135,
0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05,
Expand All @@ -1034,7 +1033,7 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
Expand All @@ -1049,7 +1048,7 @@
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
Expand All @@ -1060,7 +1059,7 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
Expand All @@ -1075,7 +1074,7 @@
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-arx_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
train_data <- cases_deaths_subset
train_data <- epidatasets::cases_deaths_subset
test_that("arx_forecaster warns if forecast date beyond the implicit one", {
bad_date <- max(train_data$time_value) + 300
expect_warning(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-check-training-set.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("training set validation works", {
template <- cases_deaths_subset[1, ]
template <- epidatasets::cases_deaths_subset[1, ]
rec <- list(template = template)
t1 <- template

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ test_that("Number of columns and column names returned correctly, Upper and lowe

## Postprocessing
test_that("Postprocessing workflow works and values correct", {
jhu <- cases_deaths_subset %>%
jhu <- epidatasets::cases_deaths_subset %>%
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
dplyr::select(geo_value, time_value, cases)

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
train_data <- cases_deaths_subset
train_data <- epidatasets::cases_deaths_subset
expect_snapshot_tibble <- function(x) {
expect_snapshot_value(x, style = "deparse", cran = FALSE)
}
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-target_date_bug.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# https://github.com/cmu-delphi/epipredict/issues/290

library(dplyr)
train <- cases_deaths_subset |>
train <- epidatasets::cases_deaths_subset |>
filter(time_value >= as.Date("2021-10-01")) |>
select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av)
ngeos <- n_distinct(train$geo_value)
Expand Down

0 comments on commit 202202c

Please sign in to comment.