Skip to content

Commit

Permalink
add tests for score_hubverse.R (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari authored Dec 23, 2024
1 parent 1d71aca commit e7758c9
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 20 deletions.
48 changes: 48 additions & 0 deletions hewr/tests/testthat/helper.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
create_tidy_forecast_data <- function(
directory, filename, date_cols, disease_cols, n_draw) {
data <- tidyr::expand_grid(
date = date_cols,
disease = disease_cols,
.draw = 1:n_draw
) |>
dplyr::mutate(.value = sample(1:100, dplyr::n(), replace = TRUE))
if (length(disease_cols) == 1) {
data <- data |>
dplyr::rename(!!disease_cols := ".value") |>
dplyr::select(-disease)
}
arrow::write_parquet(data, fs::path(directory, filename))
}

create_observation_data <- function(
date_cols, location_cols) {
data <- tidyr::expand_grid(
reference_date = date_cols,
location = location_cols
) |>
dplyr::mutate(value = sample(1:100, dplyr::n(), replace = TRUE))
return(data)
}

create_hubverse_table <- function(
date_cols, horizon, location, output_type, output_type_id) {
data <- tidyr::expand_grid(
reference_date = date_cols,
horizon = horizon,
output_type_id = output_type_id,
location = location
) |>
dplyr::group_by(reference_date, horizon, location) |>
dplyr::mutate(
value = sort(
sample(1:100, dplyr::n(), replace = TRUE),
decreasing = FALSE
),
target = "wk inc covid prop ed visits",
output_type = "quantile",
target_end_date = reference_date + 7 * horizon
) |>
dplyr::ungroup()

return(data)
}
84 changes: 84 additions & 0 deletions hewr/tests/testthat/test_score_hubverse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
testthat::test_that("score_hubverse works as expected with valid inputs", {
forecast <- create_hubverse_table(
date_cols = seq(
lubridate::ymd("2023-11-01"), lubridate::ymd("2024-01-29"),
by = "day"
),
horizon = c(0, 1, 2),
location = c("loc1", "loc2"),
output_type = "quantile",
output_type_id = c(0.01, 0.025, seq(0.05, 0.95, 0.05), 0.975, 0.99)
)

observed <- create_observation_data(
date_cols = seq(
lubridate::ymd("2023-11-01"), lubridate::ymd("2024-01-29"),
by = "day"
),
location = c("loc1", "loc2")
)

scored <- score_hubverse(forecast, observed)
expect_setequal(forecast$location, scored$location)
expect_setequal(scored$horizon, c(0, 1))

scored_all_horizon <- score_hubverse(
forecast, observed,
horizons = c(0, 1, 2)
)
expect_setequal(forecast$location, scored_all_horizon$location)
expect_setequal(forecast$horizon, scored_all_horizon$horizon)
})


testthat::test_that("score_hubverse handles missing location data", {
forecast <- create_hubverse_table(
date_cols = seq(
lubridate::ymd("2024-11-01"), lubridate::ymd("2024-11-29"),
by = "day"
),
horizon = c(0, 1),
location = c("loc1", "loc2"),
output_type = "quantile",
output_type_id = seq(0.05, 0.95, 0.05)
)

observed <- create_observation_data(
date_cols = seq(
lubridate::ymd("2024-11-01"), lubridate::ymd("2024-11-29"),
by = "day"
),
location = c("loc1")
)

result <- score_hubverse(forecast, observed)
expect_false("loc2" %in% result$location)
expect_setequal(observed$location, result$location)
})


testthat::test_that("score_hubverse handles zero length forecast table", {
forecast <- tibble::tibble(
reference_date = as.Date(character(0)),
horizon = integer(0),
output_type_id = numeric(0),
location = character(0),
value = numeric(0),
target = character(0),
output_type = character(0),
target_end_date = as.Date(character(0))
)

observed <- create_observation_data(
date_cols = seq(
lubridate::ymd("2024-11-01"), lubridate::ymd("2024-11-02"),
by = "day"
),
location = c("loc1")
)

expect_error(
result <- score_hubverse(forecast, observed),
"Assertion on 'data' failed: Must have at least 1 rows, but has 0 rows."
)
})
24 changes: 4 additions & 20 deletions hewr/tests/testthat/test_to_epiweekly_quantile_table.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
create_forecast_data <- function(
directory, filename, date_cols, disease_cols, n_draw) {
data <- tidyr::expand_grid(
date = date_cols,
disease = disease_cols,
.draw = 1:n_draw
) |>
dplyr::mutate(.value = sample(1:100, dplyr::n(), replace = TRUE))
if (length(disease_cols) == 1) {
data <- data |>
dplyr::rename(!!disease_cols := ".value") |>
dplyr::select(-disease)
}
arrow::write_parquet(data, fs::path(directory, filename))
}

# tests for `to_epiweekly_quantile`
test_that("to_epiweekly_quantiles works as expected", {
# create temporary directories and forecast files for tests
temp_dir <- withr::local_tempdir("CA")
fs::dir_create(fs::path(temp_dir, "pyrenew_e"))
fs::dir_create(fs::path(temp_dir, "timeseries_e"))

create_forecast_data(
create_tidy_forecast_data(
directory = fs::path(temp_dir, "pyrenew_e"),
filename = "forecast_samples.parquet",
date_cols = seq(
Expand All @@ -32,7 +16,7 @@ test_that("to_epiweekly_quantiles works as expected", {
n_draw = 20
)

create_forecast_data(
create_tidy_forecast_data(
directory = fs::path(temp_dir, "timeseries_e"),
filename = "epiweekly_other_ed_visits_forecast.parquet",
date_cols = seq(
Expand Down Expand Up @@ -67,7 +51,7 @@ test_that("to_epiweekly_quantiles calculates quantiles accurately", {
temp_dir <- withr::local_tempdir("test")
fs::dir_create(fs::path(temp_dir, "pyrenew_e"))

create_forecast_data(
create_tidy_forecast_data(
directory = fs::path(temp_dir, "pyrenew_e"),
filename = "forecast_samples.parquet",
date_cols = seq(
Expand Down Expand Up @@ -152,7 +136,7 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
loc_dir <- fs::path(temp_batch_dir, "model_runs", loc, "pyrenew_e")
fs::dir_create(loc_dir)

create_forecast_data(
create_tidy_forecast_data(
directory = loc_dir,
filename = "forecast_samples.parquet",
date_cols = seq(
Expand Down

0 comments on commit e7758c9

Please sign in to comment.