Skip to content

Commit

Permalink
Revert "fixing quantile sorting problems, adding epidatasets"
Browse files Browse the repository at this point in the history
This reverts commit cfed37a.
dsweber2 committed Oct 24, 2024
1 parent cfed37a commit febc4dc
Showing 14 changed files with 42 additions and 52 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -49,7 +49,6 @@ Imports:
workflows (>= 1.0.0)
Suggests:
data.table,
epidatasets,
epidatr (>= 1.0.0),
fs,
grf,
2 changes: 1 addition & 1 deletion R/layer_population_scaling.R
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@
#' @export
#' @examples
#' library(dplyr)
#' jhu <- epidatasets::cases_deaths_subset %>%
#' jhu <- cases_deaths_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#' select(geo_value, time_value, cases)
#'
2 changes: 1 addition & 1 deletion R/make_grf_quantiles.R
Original file line number Diff line number Diff line change
@@ -165,7 +165,7 @@ make_grf_quantiles <- function() {

# turn the predictions into a tibble with a dist_quantiles column
process_qrf_preds <- function(x, object) {
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig %>% sort()
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig
x <- x$predictions
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
out <- dist_quantiles(out, list(quantile_levels))
2 changes: 1 addition & 1 deletion R/step_population_scaling.R
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@
#' @export
#' @examples
#' library(dplyr)
#' jhu <- epidatasets::cases_deaths_subset %>%
#' jhu <- cases_deaths_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#' select(geo_value, time_value, cases)
#'
2 changes: 1 addition & 1 deletion man/layer_population_scaling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/step_population_scaling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 30 additions & 29 deletions tests/testthat/_snaps/snapshots.md
Original file line number Diff line number Diff line change
@@ -154,7 +154,7 @@
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85,
0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default",
"vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0,
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343000000001,
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343,
0.0847941, 0.0966258, 0.103199, 0.1097722, 0.1216039, 0.1314637,
0.1419808, 0.15414125, 0.17090286, 0.189932235, 0.22848398, 0.30542311,
0.40216399, 0.512353658), quantile_levels = c(0.01, 0.025, 0.05,
@@ -267,7 +267,7 @@
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.114729892920429,
0.227785958288583, 0.282278878729037, 0.320407599201492, 0.350577823459785,
0.376652303049231, 0.39981364198757, 0.4218461, 0.444009706175862,
0.37665230304923, 0.39981364198757, 0.4218461, 0.444009706175862,
0.466962725214852, 0.493098379685547, 0.523708407392674, 0.562100740111401,
0.619050517814778, 0.754868363055733, 1.1177263295869, 1.76277018354499,
2.37278671910076, 2.9651652434047), quantile_levels = c(0.01,
@@ -314,7 +314,7 @@
0.144337973117581, 0.250292371898569, 0.367310419323293, 0.44444044802193,
0.506592035751958, 0.558428768125431, 0.602035095628756, 0.64112383905529,
0.674354964141041, 0.703707875219752, 0.7319844, 0.760702196782168,
0.789758264058441, 0.823427572594726, 0.860294897090771, 0.904032120658957,
0.78975826405844, 0.823427572594726, 0.860294897090771, 0.904032120658957,
0.955736581115011, 1.0165945004053, 1.09529786576616, 1.21614421175967,
1.32331604019295, 1.45293812780298), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
@@ -351,7 +351,7 @@
0.0497573816250162, 0.081255049503995, 0.108502307388674, 0.132961558931189,
0.156011650575706, 0.177125892134071, 0.1975426, 0.217737120618906,
0.239458499211792, 0.263562581820818, 0.289525383565136, 0.31824420000725,
0.351413051940519, 0.393862560773808, 0.453538799225292, 0.558631806850418,
0.35141305194052, 0.393862560773808, 0.453538799225292, 0.558631806850418,
0.657452391363313, 0.767918764883928), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99
@@ -412,19 +412,20 @@
0.0766736159703596, 0.0942284381264812, 0.11050757203172,
0.125214601455714, 0.1393442, 0.15359732398729, 0.168500447692877,
0.184551468093631, 0.202926420944109, 0.22476606802393, 0.253070223293233,
0.291229953951089, 0.341963643747938, 0.419747975311502,
0.495994046054689, 0.5748791770223), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.00603076915889168,
0.0356039073625737, 0.0609470811194113, 0.0833232869645198, 0.103265350891109,
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
0.196866917086671, 0.219796529731897, 0.247137200365254, 0.280371254591746,
0.320842872758278, 0.374783454750148, 0.461368597638526, 0.539683256474915,
0.632562403391324), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
0.29122995395109, 0.341963643747938, 0.419747975311502, 0.495994046054689,
0.5748791770223), quantile_levels = c(0.01, 0.025, 0.05,
0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0, 0, 0, 0, 0, 0.00603076915889168, 0.0356039073625737,
0.0609470811194113, 0.0833232869645198, 0.103265350891109,
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
0.196866917086671, 0.219796529731897, 0.247137200365254,
0.280371254591746, 0.320842872758278, 0.374783454750148,
0.461368597638526, 0.539683256474915, 0.632562403391324),
quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25,
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8,
0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0, 0, 0, 0, 0, 0, 0.018869505399304, 0.0471517885822858,
0.0732707765908659, 0.0969223475714758, 0.118188509171441,
@@ -649,7 +650,7 @@
0.0562218087603375, 0.0890356919950198, 0.118731362266373, 0.146216910144001,
0.172533896645116, 0.1975426, 0.223021121504065, 0.249412654553045,
0.277680444480195, 0.308522683806638, 0.342270845449704, 0.382702709814398,
0.433443929063141, 0.501610622734127, 0.614175801063261, 0.715138862353848,
0.433443929063141, 0.501610622734127, 0.61417580106326, 0.715138862353848,
0.833535553075286), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
@@ -825,8 +826,8 @@
0.147940700281253, 0.185518687303273, 0.220197034594646,
0.2521005, 0.282477641919719, 0.3121244, 0.3414694, 0.371435390499905,
0.402230766363414, 0.436173824348844, 0.474579164424894,
0.519690345185252, 0.576673752066771, 0.655151246845668,
0.78520792902029, 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
0.519690345185252, 0.57667375206677, 0.655151246845668, 0.78520792902029,
0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
@@ -1007,14 +1008,14 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645,
0.470345577837143, 0.725986105412007, 0.212686665274007), .pred_distn = structure(list(
structure(list(values = c(0.0961118191398633, 0.202494988128882
), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645,
0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list(
structure(list(values = c(0.0961118191398634, 0.202494988128882
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.0865730800114382, 0.192956249000457), quantile_levels = c(0.05,
values = c(0.0865730800114383, 0.192956249000457), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.279994736572135,
"vctrs_vctr")), structure(list(values = c(0.279994736572136,
0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05,
@@ -1033,7 +1034,7 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
@@ -1048,7 +1049,7 @@
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
@@ -1059,7 +1060,7 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
@@ -1074,7 +1075,7 @@
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/step_epi_slide.md
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
r %>% step_epi_slide(value, .f = mean, .window_size = c(3L, 6L))
Condition
Error in `epiprocess:::validate_slide_window_arg()`:
! Slide function expected `.window_size` to be a non-null, scalar integer >= 1.
! Slide function expected `.window_size` to be a length-1 difftime with units in days or non-negative integer or Inf.

---

@@ -60,7 +60,7 @@
r %>% step_epi_slide(value, .f = mean, .window_size = 1.5)
Condition
Error in `epiprocess:::validate_slide_window_arg()`:
! Slide function expected `.window_size` to be a difftime with units in days or non-negative integer or Inf.
! Slide function expected `.window_size` to be a length-1 difftime with units in days or non-negative integer or Inf.

---

2 changes: 1 addition & 1 deletion tests/testthat/test-arx_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
train_data <- epidatasets::cases_deaths_subset
train_data <- cases_deaths_subset
test_that("arx_forecaster warns if forecast date beyond the implicit one", {
bad_date <- max(train_data$time_value) + 300
expect_warning(
2 changes: 1 addition & 1 deletion tests/testthat/test-check-training-set.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("training set validation works", {
template <- epidatasets::cases_deaths_subset[1, ]
template <- cases_deaths_subset[1, ]
rec <- list(template = template)
t1 <- template

10 changes: 0 additions & 10 deletions tests/testthat/test-grf_quantiles.R
Original file line number Diff line number Diff line change
@@ -50,13 +50,3 @@ test_that("quantile_rand_forest handles allows setting the trees and mtry", {
expect_identical(pars$quantiles.orig, manual$quantiles.orig)
expect_identical(pars$`_num_trees`, manual$`_num_trees`)
})

test_that("quantile_rand_forest predicts reasonable quantiles", {
spec <- rand_forest(mode = "regression") %>%
set_engine("grf_quantiles", quantiles = c(.2, .5, .8))
expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
# swapping around the probabilities, because somehow this happens in practice,
# but I'm not sure how to reproduce
out$fit$quantiles.orig <- c(0.5,0.9, 0.1)
expect_no_error(predict(out, tib))
})
2 changes: 1 addition & 1 deletion tests/testthat/test-population_scaling.R
Original file line number Diff line number Diff line change
@@ -90,7 +90,7 @@ test_that("Number of columns and column names returned correctly, Upper and lowe

## Postprocessing
test_that("Postprocessing workflow works and values correct", {
jhu <- epidatasets::cases_deaths_subset %>%
jhu <- cases_deaths_subset %>%
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
dplyr::select(geo_value, time_value, cases)

2 changes: 1 addition & 1 deletion tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
train_data <- epidatasets::cases_deaths_subset
train_data <- cases_deaths_subset
expect_snapshot_tibble <- function(x) {
expect_snapshot_value(x, style = "deparse", cran = FALSE)
}
2 changes: 1 addition & 1 deletion tests/testthat/test-target_date_bug.R
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
# https://github.com/cmu-delphi/epipredict/issues/290

library(dplyr)
train <- epidatasets::cases_deaths_subset |>
train <- cases_deaths_subset |>
filter(time_value >= as.Date("2021-10-01")) |>
select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av)
ngeos <- n_distinct(train$geo_value)

0 comments on commit febc4dc

Please sign in to comment.