Skip to content

Commit

Permalink
Merge branch 'main' into dhm-predicted-actual
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Dec 26, 2024
2 parents 0076b4f + 93eff87 commit 8c4c67e
Show file tree
Hide file tree
Showing 19 changed files with 177 additions and 141 deletions.
48 changes: 48 additions & 0 deletions hewr/tests/testthat/helper.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
create_tidy_forecast_data <- function(
directory, filename, date_cols, disease_cols, n_draw) {
data <- tidyr::expand_grid(
date = date_cols,
disease = disease_cols,
.draw = 1:n_draw
) |>
dplyr::mutate(.value = sample(1:100, dplyr::n(), replace = TRUE))
if (length(disease_cols) == 1) {
data <- data |>
dplyr::rename(!!disease_cols := ".value") |>
dplyr::select(-disease)
}
arrow::write_parquet(data, fs::path(directory, filename))
}

create_observation_data <- function(
date_cols, location_cols) {
data <- tidyr::expand_grid(
reference_date = date_cols,
location = location_cols
) |>
dplyr::mutate(value = sample(1:100, dplyr::n(), replace = TRUE))
return(data)
}

create_hubverse_table <- function(
date_cols, horizon, location, output_type, output_type_id) {
data <- tidyr::expand_grid(
reference_date = date_cols,
horizon = horizon,
output_type_id = output_type_id,
location = location
) |>
dplyr::group_by(reference_date, horizon, location) |>
dplyr::mutate(
value = sort(
sample(1:100, dplyr::n(), replace = TRUE),
decreasing = FALSE
),
target = "wk inc covid prop ed visits",
output_type = "quantile",
target_end_date = reference_date + 7 * horizon
) |>
dplyr::ungroup()

return(data)
}
84 changes: 84 additions & 0 deletions hewr/tests/testthat/test_score_hubverse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
testthat::test_that("score_hubverse works as expected with valid inputs", {
forecast <- create_hubverse_table(
date_cols = seq(
lubridate::ymd("2023-11-01"), lubridate::ymd("2024-01-29"),
by = "day"
),
horizon = c(0, 1, 2),
location = c("loc1", "loc2"),
output_type = "quantile",
output_type_id = c(0.01, 0.025, seq(0.05, 0.95, 0.05), 0.975, 0.99)
)

observed <- create_observation_data(
date_cols = seq(
lubridate::ymd("2023-11-01"), lubridate::ymd("2024-01-29"),
by = "day"
),
location = c("loc1", "loc2")
)

scored <- score_hubverse(forecast, observed)
expect_setequal(forecast$location, scored$location)
expect_setequal(scored$horizon, c(0, 1))

scored_all_horizon <- score_hubverse(
forecast, observed,
horizons = c(0, 1, 2)
)
expect_setequal(forecast$location, scored_all_horizon$location)
expect_setequal(forecast$horizon, scored_all_horizon$horizon)
})


testthat::test_that("score_hubverse handles missing location data", {
forecast <- create_hubverse_table(
date_cols = seq(
lubridate::ymd("2024-11-01"), lubridate::ymd("2024-11-29"),
by = "day"
),
horizon = c(0, 1),
location = c("loc1", "loc2"),
output_type = "quantile",
output_type_id = seq(0.05, 0.95, 0.05)
)

observed <- create_observation_data(
date_cols = seq(
lubridate::ymd("2024-11-01"), lubridate::ymd("2024-11-29"),
by = "day"
),
location = c("loc1")
)

result <- score_hubverse(forecast, observed)
expect_false("loc2" %in% result$location)
expect_setequal(observed$location, result$location)
})


testthat::test_that("score_hubverse handles zero length forecast table", {
forecast <- tibble::tibble(
reference_date = as.Date(character(0)),
horizon = integer(0),
output_type_id = numeric(0),
location = character(0),
value = numeric(0),
target = character(0),
output_type = character(0),
target_end_date = as.Date(character(0))
)

observed <- create_observation_data(
date_cols = seq(
lubridate::ymd("2024-11-01"), lubridate::ymd("2024-11-02"),
by = "day"
),
location = c("loc1")
)

expect_error(
result <- score_hubverse(forecast, observed),
"Assertion on 'data' failed: Must have at least 1 rows, but has 0 rows."
)
})
24 changes: 4 additions & 20 deletions hewr/tests/testthat/test_to_epiweekly_quantile_table.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
create_forecast_data <- function(
directory, filename, date_cols, disease_cols, n_draw) {
data <- tidyr::expand_grid(
date = date_cols,
disease = disease_cols,
.draw = 1:n_draw
) |>
dplyr::mutate(.value = sample(1:100, dplyr::n(), replace = TRUE))
if (length(disease_cols) == 1) {
data <- data |>
dplyr::rename(!!disease_cols := ".value") |>
dplyr::select(-disease)
}
arrow::write_parquet(data, fs::path(directory, filename))
}

# tests for `to_epiweekly_quantile`
test_that("to_epiweekly_quantiles works as expected", {
# create temporary directories and forecast files for tests
temp_dir <- withr::local_tempdir("CA")
fs::dir_create(fs::path(temp_dir, "pyrenew_e"))
fs::dir_create(fs::path(temp_dir, "timeseries_e"))

create_forecast_data(
create_tidy_forecast_data(
directory = fs::path(temp_dir, "pyrenew_e"),
filename = "forecast_samples.parquet",
date_cols = seq(
Expand All @@ -32,7 +16,7 @@ test_that("to_epiweekly_quantiles works as expected", {
n_draw = 20
)

create_forecast_data(
create_tidy_forecast_data(
directory = fs::path(temp_dir, "timeseries_e"),
filename = "epiweekly_other_ed_visits_forecast.parquet",
date_cols = seq(
Expand Down Expand Up @@ -67,7 +51,7 @@ test_that("to_epiweekly_quantiles calculates quantiles accurately", {
temp_dir <- withr::local_tempdir("test")
fs::dir_create(fs::path(temp_dir, "pyrenew_e"))

create_forecast_data(
create_tidy_forecast_data(
directory = fs::path(temp_dir, "pyrenew_e"),
filename = "forecast_samples.parquet",
date_cols = seq(
Expand Down Expand Up @@ -152,7 +136,7 @@ test_that("to_epiweekly_quantile_table handles multiple locations", {
loc_dir <- fs::path(temp_batch_dir, "model_runs", loc, "pyrenew_e")
fs::dir_create(loc_dir)

create_forecast_data(
create_tidy_forecast_data(
directory = loc_dir,
filename = "forecast_samples.parquet",
date_cols = seq(
Expand Down
1 change: 1 addition & 0 deletions pipelines/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Pipelines
File renamed without changes.
68 changes: 0 additions & 68 deletions pipelines/default_priors.py

This file was deleted.

File renamed without changes.
4 changes: 3 additions & 1 deletion pipelines/fit_model.py → pipelines/fit_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import jax
import numpy as np
from build_model import build_model_from_dir
from build_pyrenew_model import (
build_model_from_dir,
)


def fit_and_save_model(
Expand Down
24 changes: 13 additions & 11 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import tomli_w
import tomllib
from prep_data import process_and_save_state
from prep_eval_data import save_eval_data
from pygit2 import Repository
from save_eval_data import save_eval_data

numpyro.set_host_device_count(4)

from fit_model import fit_and_save_model # noqa
from generate_predictive import generate_and_save_predictions # noqa
from fit_pyrenew_model import fit_and_save_model # noqa
from generate_predictive import (
generate_and_save_predictions,
) # noqa


def record_git_info(model_run_dir: Path):
Expand Down Expand Up @@ -125,13 +127,13 @@ def convert_inferencedata_to_parquet(
return None


def postprocess_forecast(
def plot_and_save_state_forecast(
model_run_dir: Path, pyrenew_model_name: str, timeseries_model_name: str
) -> None:
result = subprocess.run(
[
"Rscript",
"pipelines/postprocess_state_forecast.R",
"pipelines/plot_and_save_state_forecast.R",
f"{model_run_dir}",
"--pyrenew-model-name",
f"{pyrenew_model_name}",
Expand All @@ -141,7 +143,7 @@ def postprocess_forecast(
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"postprocess_forecast: {result.stderr}")
raise RuntimeError(f"plot_and_save_state_forecast: {result.stderr}")
return None


Expand All @@ -159,17 +161,17 @@ def score_forecast(model_run_dir: Path) -> None:
return None


def render_webpage(model_run_dir: Path) -> None:
def render_diagnostic_report(model_run_dir: Path) -> None:
result = subprocess.run(
[
"Rscript",
"pipelines/render_webpage.R",
"pipelines/diagnostic_report/render_diagnostic_report.R",
f"{model_run_dir}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"render_webpage: {result.stderr}")
raise RuntimeError(f"render_diagnostic_report: {result.stderr}")
return None


Expand Down Expand Up @@ -363,11 +365,11 @@ def main(
logger.info("Conversion complete.")

logger.info("Postprocessing forecast...")
postprocess_forecast(model_run_dir, "pyrenew_e", "timeseries_e")
plot_and_save_state_forecast(model_run_dir, "pyrenew_e", "timeseries_e")
logger.info("Postprocessing complete.")

logger.info("Rendering webpage...")
render_webpage(model_run_dir)
render_diagnostic_report(model_run_dir)
logger.info("Rendering complete.")

if score:
Expand Down
4 changes: 3 additions & 1 deletion pipelines/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from pathlib import Path

import arviz as az
from build_model import build_model_from_dir
from build_pyrenew_model import (
build_model_from_dir,
)


def generate_and_save_predictions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

import epiweeks
import polars as pl
from prep_data import aggregate_facility_level_nssp_to_state, get_state_pop_df
from prep_data import (
aggregate_facility_level_nssp_to_state,
get_state_pop_df,
)


def save_observed_data_tables(
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 8c4c67e

Please sign in to comment.