From a45085cc788ee756b9bec256a16eb25bf95e6301 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Mon, 16 Dec 2024 14:44:00 -0600 Subject: [PATCH 1/4] Track git commit id when fitting models (#242) * user git commit sha as build arg for container * change SHA reference * record git info * add branch name to container * extract branch name in correct github action * try again * minor tweaks * add debug step * more debugging * better handling when no git repo is present * Remove debug steps --- .github/workflows/containers.yaml | 5 ++++- Containerfile | 5 +++++ pipelines/forecast_state.py | 26 ++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/.github/workflows/containers.yaml b/.github/workflows/containers.yaml index 0ba0d5ba..bf91e07e 100644 --- a/.github/workflows/containers.yaml +++ b/.github/workflows/containers.yaml @@ -20,6 +20,7 @@ jobs: outputs: tag: ${{ steps.image-tag.outputs.tag }} commit-msg: ${{ steps.commit-message.outputs.message }} + branch: ${{ steps.branch-name.outputs.branch }} steps: @@ -115,7 +116,9 @@ jobs: with: push: true # This can be toggled manually for tweaking. 
tags: | - ${{ env.REGISTRY}}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }} + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }} file: ./Containerfile build-args: | TAG=${{ needs.build-dependencies-image.outputs.tag }} + GIT_COMMIT_SHA=${{ github.event.pull_request.head.sha || github.sha }} + GIT_BRANCH_NAME=${{ needs.build-dependencies-image.outputs.branch }} diff --git a/Containerfile b/Containerfile index f627bf32..20b98455 100644 --- a/Containerfile +++ b/Containerfile @@ -1,6 +1,11 @@ ARG TAG=latest FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG} +ARG GIT_COMMIT_SHA +ENV GIT_COMMIT_SHA=$GIT_COMMIT_SHA + +ARG GIT_BRANCH_NAME +ENV GIT_BRANCH_NAME=$GIT_BRANCH_NAME COPY ./hewr /pyrenew-hew/hewr diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index c5e8d445..c5c4442d 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -8,7 +8,9 @@ import numpyro import polars as pl +import yaml from prep_data import process_and_save_state +from pygit2 import Repository from save_eval_data import save_eval_data numpyro.set_host_device_count(4) @@ -17,6 +19,27 @@ from generate_predictive import generate_and_save_predictions # noqa +def record_git_info(model_run_dir: Path): + metadata_file = Path(model_run_dir, "metadata.yaml") + try: + repo = Repository(os.getcwd()) + branch_name = os.environ.get( + "GIT_BRANCH_NAME", Path(repo.head.name).stem + ) + commit_sha = os.environ.get("GIT_COMMIT_SHA", str(repo.head.target)) + except: + branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown") + commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown") + + metadata = { + "branch_name": branch_name, + "commit_sha": commit_sha, + } + + with open(metadata_file, "w") as file: + yaml.dump(metadata, file) + + def generate_epiweekly(model_run_dir: Path) -> None: result = subprocess.run( [ @@ -238,6 +261,9 @@ def main( os.makedirs(model_run_dir, exist_ok=True) + 
logger.info("Recording git info...") + record_git_info(model_run_dir) + logger.info(f"Using priors from {priors_path}...") shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) From 2927b03331dc071032e5ab7f752801cb16097b31 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 17 Dec 2024 08:52:46 -0500 Subject: [PATCH 2/4] Depend on `ensure_listlike` from `forecasttools` (#252) --- pipelines/collate_plots.py | 3 ++- pipelines/utils.py | 28 ---------------------------- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/pipelines/collate_plots.py b/pipelines/collate_plots.py index 3f396d87..deda97d7 100644 --- a/pipelines/collate_plots.py +++ b/pipelines/collate_plots.py @@ -5,8 +5,9 @@ import os from pathlib import Path +from forecasttools import ensure_listlike from pypdf import PdfWriter -from utils import ensure_listlike, get_all_forecast_dirs +from utils import get_all_forecast_dirs def merge_pdfs_and_save( diff --git a/pipelines/utils.py b/pipelines/utils.py index c18215cc..9504cd86 100644 --- a/pipelines/utils.py +++ b/pipelines/utils.py @@ -6,39 +6,11 @@ import datetime import os import re -from collections.abc import MutableSequence from pathlib import Path disease_map_lower_ = {"influenza": "Influenza", "covid-19": "COVID-19"} -def ensure_listlike(x): - """ - Ensure that an object either behaves like a - :class:`MutableSequence` and if not return a - one-item :class:`list` containing the object. - - Useful for handling list-of-strings inputs - alongside single strings. - - Based on this _`StackOverflow approach - `. - - Parameters - ---------- - x - The item to ensure is :class:`list`-like. - - Returns - ------- - MutableSequence - ``x`` if ``x`` is a :class:`MutableSequence` - otherwise ``[x]`` (i.e. a one-item list containing - ``x``. 
- """ - return x if isinstance(x, MutableSequence) else [x] - - def parse_model_batch_dir_name(model_batch_dir_name): """ Parse the name of a model batch directory, From 98842041f8df80741ddd8d7daf4f60ff02ae9f0b Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 17 Dec 2024 13:15:57 -0600 Subject: [PATCH 3/4] Add priors to git repo and adjust container build (#251) * add priors * more robust metadata file handling and save priors to metadata file * correct path to metadata * try changing the container file for fast builds when priors change * use shorthand for local branch name extraction * remove redundant code * fix copy paste error * try allowing new syntax * switch from yaml to toml * pre-commit * correction for test job args --- Containerfile | 6 +- pipelines/batch/setup_prod_job.py | 2 +- pipelines/batch/setup_test_prod_job.py | 4 +- pipelines/forecast_state.py | 53 +++++++++++---- pipelines/priors/eval_priors.py | 68 +++++++++++++++++++ .../priors/parameter_inference_priors.py | 68 +++++++++++++++++++ pipelines/priors/prod_priors.py | 68 +++++++++++++++++++ pyproject.toml | 1 + 8 files changed, 254 insertions(+), 16 deletions(-) create mode 100644 pipelines/priors/eval_priors.py create mode 100644 pipelines/priors/parameter_inference_priors.py create mode 100644 pipelines/priors/prod_priors.py diff --git a/Containerfile b/Containerfile index 20b98455..5b18f0c7 100644 --- a/Containerfile +++ b/Containerfile @@ -1,3 +1,4 @@ +#syntax=docker/dockerfile:1.7-labs ARG TAG=latest FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG} @@ -18,5 +19,8 @@ RUN Rscript -e "pak::pkg_install('cmu-delphi/epiprocess@main')" RUN Rscript -e "pak::pkg_install('cmu-delphi/epipredict@main')" RUN Rscript -e "pak::local_install('hewr')" -COPY . . +COPY --exclude=pipelines/priors . . RUN pip install --root-user-action=ignore . 
+ +# Copy priors folder last +COPY pipelines/priors pipelines/priors diff --git a/pipelines/batch/setup_prod_job.py b/pipelines/batch/setup_prod_job.py index 414d3d1f..2b8040f7 100644 --- a/pipelines/batch/setup_prod_job.py +++ b/pipelines/batch/setup_prod_job.py @@ -167,7 +167,7 @@ def main( "nssp-archival-vintages/gold " "--param-data-dir params " "--output-dir {output_dir} " - "--priors-path config/prod_priors.py " + "--priors-path pipelines/priors/prod_priors.py " "--report-date {report_date} " "--exclude-last-n-days {exclude_last_n_days} " "--no-score " diff --git a/pipelines/batch/setup_test_prod_job.py b/pipelines/batch/setup_test_prod_job.py index 821d3b4f..446bad63 100644 --- a/pipelines/batch/setup_test_prod_job.py +++ b/pipelines/batch/setup_test_prod_job.py @@ -21,7 +21,7 @@ "--tag", type=str, help="The tag name to use for the container image version", - default=Path(Repository(os.getcwd()).head.name).stem, + default=Repository(os.getcwd()).head.shorthand, ) args = parser.parse_args() @@ -95,6 +95,6 @@ diseases=["COVID-19", "Influenza"], container_image_name="pyrenew-hew", container_image_version=tag, - excluded_locations=locs_to_exclude, + locations_exclude=locs_to_exclude, test=True, ) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index c5c4442d..97f3946f 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -8,7 +8,8 @@ import numpyro import polars as pl -import yaml +import tomli_w +import tomllib from prep_data import process_and_save_state from pygit2 import Repository from save_eval_data import save_eval_data @@ -20,24 +21,52 @@ def record_git_info(model_run_dir: Path): - metadata_file = Path(model_run_dir, "metadata.yaml") + metadata_file = Path(model_run_dir, "metadata.toml") + + if metadata_file.exists(): + with open(metadata_file, "rb") as file: + metadata = tomllib.load(file) + else: + metadata = {} + try: repo = Repository(os.getcwd()) - branch_name = os.environ.get( - "GIT_BRANCH_NAME", 
Path(repo.head.name).stem - ) - commit_sha = os.environ.get("GIT_COMMIT_SHA", str(repo.head.target)) - except: + branch_name = repo.head.shorthand + commit_sha = str(repo.head.target) + except Exception as e: branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown") commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown") - metadata = { + new_metadata = { "branch_name": branch_name, "commit_sha": commit_sha, } - with open(metadata_file, "w") as file: - yaml.dump(metadata, file) + metadata.update(new_metadata) + + metadata_file.parent.mkdir(parents=True, exist_ok=True) + with open(metadata_file, "wb") as file: + tomli_w.dump(metadata, file) + + +def copy_and_record_priors(priors_path: Path, model_run_dir: Path): + metadata_file = Path(model_run_dir, "metadata.toml") + shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) + + if metadata_file.exists(): + with open(metadata_file, "rb") as file: + metadata = tomllib.load(file) + else: + metadata = {} + + new_metadata = { + "priors_path": str(priors_path), + } + + metadata.update(new_metadata) + + with open(metadata_file, "wb") as file: + tomli_w.dump(metadata, file) def generate_epiweekly(model_run_dir: Path) -> None: @@ -264,8 +293,8 @@ def main( logger.info("Recording git info...") record_git_info(model_run_dir) - logger.info(f"Using priors from {priors_path}...") - shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) + logger.info(f"Copying and recording priors from {priors_path}...") + copy_and_record_priors(priors_path, model_run_dir) logger.info(f"Processing {state}") process_and_save_state( diff --git a/pipelines/priors/eval_priors.py b/pipelines/priors/eval_priors.py new file mode 100644 index 00000000..2d5ef792 --- /dev/null +++ b/pipelines/priors/eval_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import 
DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1.2) +r_logsd = jnp.log(jnp.sqrt(2)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.1, 0.05, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 2)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(10), jnp.log(3)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? + +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. 
+phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1)) diff --git a/pipelines/priors/parameter_inference_priors.py b/pipelines/priors/parameter_inference_priors.py new file mode 100644 index 00000000..87bda09c --- /dev/null +++ b/pipelines/priors/parameter_inference_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1) +r_logsd = jnp.log(jnp.sqrt(3)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.15, 0.1, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(1), jnp.log(20)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? 
+ +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. +phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1.5)) diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py new file mode 100644 index 00000000..05c4a9f4 --- /dev/null +++ b/pipelines/priors/prod_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1.2) +r_logsd = jnp.log(jnp.sqrt(2)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.15, 0.05, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(50), jnp.log(1.5)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? 
+ +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. +phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1)) diff --git a/pyproject.toml b/pyproject.toml index e37307d1..e561d840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ pyarrow = "^18.0.0" pygit2 = "^1.16.0" azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"} forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"} +tomli-w = "^1.1.0" [tool.poetry.group.test] optional = true From 0d79bd009189e98df4279b5c1a15c352a6432a29 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Tue, 17 Dec 2024 20:55:20 +0000 Subject: [PATCH 4/4] Issue 216: end-to-end CI (#247) * add simple model for generating count data * pre-commit * Create generate_test_data.R * state-level data * Update generate_test_data.R * styler * fake parameter data generation * populate test_output with priors * Create test_forecast_state.sh * Update generate_test_data.R * catch different naming of disease * catch script error * generate latest comprehensive fake state data * Update test_forecast_state.sh * reduce generated facs * reduce number of samples * ignore test pipe output * Update .gitignore * modify to generate out-of-sample data * add explicit fails to pipeline run These will trigger CI fails to avoid false CI passes * style * Add CI to run 
pipeline test mode * add install hewr step * catch bad var * catch no interactive * add epipredict and epiprocess * install quarto * cleanup generate_test_data * pre-commit --------- Co-authored-by: Damon Bayer --- .github/workflows/pipeline-run-check.yaml | 39 +++ .gitignore | 3 + hewr/NAMESPACE | 1 + hewr/R/generate_exp_growth_process.R | 26 ++ hewr/man/generate_exp_growth_pois.Rd | 32 +++ .../test_generate_exp_growth_process.R | 30 ++ pipelines/generate_test_data.R | 269 ++++++++++++++++++ pipelines/tests/test_forecast_state.sh | 39 +++ pipelines/tests/test_output/priors.py | 68 +++++ 9 files changed, 507 insertions(+) create mode 100644 .github/workflows/pipeline-run-check.yaml create mode 100644 hewr/R/generate_exp_growth_process.R create mode 100644 hewr/man/generate_exp_growth_pois.Rd create mode 100644 hewr/tests/testthat/test_generate_exp_growth_process.R create mode 100644 pipelines/generate_test_data.R create mode 100644 pipelines/tests/test_forecast_state.sh create mode 100644 pipelines/tests/test_output/priors.py diff --git a/.github/workflows/pipeline-run-check.yaml b/.github/workflows/pipeline-run-check.yaml new file mode 100644 index 00000000..222c7d73 --- /dev/null +++ b/.github/workflows/pipeline-run-check.yaml @@ -0,0 +1,39 @@ +name: Pipeline Run Check + +on: + pull_request: + push: + branches: [main] + +jobs: + run-pipeline: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: "Set up R" + uses: r-lib/actions/setup-r@v2 + with: + r-version: "release" + use-public-rspm: true + - name: "Install poetry" + run: pip install poetry + - name: "Install pyrenew-hew" + run: poetry install + - name: "Set up Quarto" + uses: quarto-dev/quarto-actions/setup@v2 + - name: "Set up dependencies for hewr" + uses: r-lib/actions/setup-r-dependencies@v2 + with: + working-directory: hewr + - name: "Install extra pkgs" + run: | + pak::local_install("hewr", ask 
= FALSE) + pak::pkg_install("cmu-delphi/epipredict@main", ask = FALSE) + pak::pkg_install("cmu-delphi/epiprocess@main", ask = FALSE) + shell: Rscript {0} + - name: "Run pipeline" + run: poetry run bash pipelines/tests/test_forecast_state.sh pipelines/tests diff --git a/.gitignore b/.gitignore index e8d168e6..6ee6d2de 100644 --- a/.gitignore +++ b/.gitignore @@ -403,3 +403,6 @@ private_data/* # Test data exceptions to the general data exclusion !pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data/data.tsv !pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data/eval_data.tsv + +# Ignore test pipe output +pipelines/tests/private_data/* diff --git a/hewr/NAMESPACE b/hewr/NAMESPACE index 0e6f5e5e..2af66233 100644 --- a/hewr/NAMESPACE +++ b/hewr/NAMESPACE @@ -1,6 +1,7 @@ # Generated by roxygen2: do not edit by hand export(combine_training_and_eval_data) +export(generate_exp_growth_pois) export(get_all_model_batch_dirs) export(make_forecast_figure) export(parse_model_batch_dir_path) diff --git a/hewr/R/generate_exp_growth_process.R b/hewr/R/generate_exp_growth_process.R new file mode 100644 index 00000000..4370a7ba --- /dev/null +++ b/hewr/R/generate_exp_growth_process.R @@ -0,0 +1,26 @@ +#' Generate Exponential Growth Process with Poisson Noise +#' +#' This function generates a sequence of samples from an exponential growth +#' process through Poisson sampling: +#' ```math +#' \begin{aligned} +#' \( \lambda_t &= I_0 \exp(\sum_{t=1}^{t} r_t) \) \\ +#' I_t &\sim \text{Poisson}(\lambda_t). +#' ``` +#' @param rt A numeric vector of exponential growth rates. +#' @param initial A numeric value representing the initial value of the process. +#' +#' @return An integer vector of Poisson samples generated from the exponential +#' growth process. 
+#' +#' @examples +#' rt <- c(0.1, 0.2, 0.15) +#' initial <- 10 +#' generate_exp_growth_pois(rt, initial) +#' +#' @export +generate_exp_growth_pois <- function(rt, initial) { + means <- initial * exp(cumsum(rt)) + samples <- stats::rpois(length(means), lambda = means) + return(samples) +} diff --git a/hewr/man/generate_exp_growth_pois.Rd b/hewr/man/generate_exp_growth_pois.Rd new file mode 100644 index 00000000..213e5415 --- /dev/null +++ b/hewr/man/generate_exp_growth_pois.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generate_exp_growth_process.R +\name{generate_exp_growth_pois} +\alias{generate_exp_growth_pois} +\title{Generate Exponential Growth Process with Poisson Noise} +\usage{ +generate_exp_growth_pois(rt, initial) +} +\arguments{ +\item{rt}{A numeric vector of exponential growth rates.} + +\item{initial}{A numeric value representing the initial value of the process.} +} +\value{ +An integer vector of Poisson samples generated from the exponential +growth process. +} +\description{ +This function generates a sequence of samples from an exponential growth +process through Poisson sampling: + +\if{html}{\out{
}}\preformatted{\\begin\{aligned\} +\\( \\lambda_t &= I_0 \\exp(\\sum_\{t=1\}^\{t\} r_t) \\) \\\\ +I_t &\\sim \\text\{Poisson\}(\\lambda_t). +}\if{html}{\out{
}} +} +\examples{ +rt <- c(0.1, 0.2, 0.15) +initial <- 10 +generate_exp_growth_pois(rt, initial) + +} diff --git a/hewr/tests/testthat/test_generate_exp_growth_process.R b/hewr/tests/testthat/test_generate_exp_growth_process.R new file mode 100644 index 00000000..8204b07f --- /dev/null +++ b/hewr/tests/testthat/test_generate_exp_growth_process.R @@ -0,0 +1,30 @@ +test_that("generate_exp_growth_pois generates correct number of samples", { + rt <- c(0.1, 0.2, 0.15) + initial <- 10 + samples <- generate_exp_growth_pois(rt, initial) + expect_length(samples, length(rt)) +}) + +test_that("generate_exp_growth_pois returns a vector of integers", { + rt <- c(0.1, 0.2, 0.15) + initial <- 10 + samples <- generate_exp_growth_pois(rt, initial) + expect_type(samples, "integer") +}) + +test_that("generate_exp_growth_pois does not return implausible values", { + rt <- c(0.1, 0.2, 0.15) + initial <- 10 + analytic_av <- initial * exp(cumsum(rt)) + analytic_std <- sqrt(analytic_av) + samples <- generate_exp_growth_pois(rt, initial) + expect_true(all(samples <= initial + 10 * analytic_std)) + expect_true(all(samples >= initial - 10 * analytic_std)) +}) + +test_that("generate_exp_growth_pois handles empty growth rates", { + rt <- numeric(0) + initial <- 10 + samples <- generate_exp_growth_pois(rt, initial) + expect_equal(samples, numeric(0)) +}) diff --git a/pipelines/generate_test_data.R b/pipelines/generate_test_data.R new file mode 100644 index 00000000..07fe0577 --- /dev/null +++ b/pipelines/generate_test_data.R @@ -0,0 +1,269 @@ +script_packages <- c( + "argparser", + "arrow", + "dplyr", + "fs", + "hewr", + "lubridate", + "tidyr" +) + +# load in packages without messages +purrr::walk(script_packages, \(pkg) { + suppressPackageStartupMessages( + library(pkg, character.only = TRUE) + ) +}) + +# Set seed for reproducibility +set.seed(123) + +# Dict for converting to short names +disease_short_names <- list("COVID-19/Omicron" = "COVID-19") + + +#' Create Facility Test Data +#' +#' 
This function generates test data for a given facility over a specified +#' date range. The test data counts for the target disease are generated using +#' a cosine varying exponential growth process, with Poisson samples, while +#' counts for other diseases are generated using a mean-constant Poisson +#' process. +#' +#' @param facility A number representing the name of the facility. +#' @param start_reference A Date object representing the start date of the +#' reference period. +#' @param end_reference A Date object representing the end date of the reference +#' period. +#' @param initial A numeric value representing the initial expected count for +#' the target disease. Default is 10.0. +#' @param mean_other A numeric value representing the mean count for other +#' diseases. Default is 200.0. +#' @param target_disease A character string representing the name of the target +#' disease. Default is "COVID-19/Omicron". +#' +#' @return A tibble containing the generated test data with columns for +#' reference date, report date, geo type, geo value, as of date, run ID, +#' facility, disease, and value. 
+create_facility_test_data <- function( + facility, start_reference, end_reference, + initial = 10.0, mean_other = 200.0, target_disease = "COVID-19/Omicron") { + reference_dates <- seq(start_reference, end_reference, by = "day") + rt <- 0.25 * cos(2 * pi * as.numeric(difftime(reference_dates, + start_reference, + units = "days" + )) / 180) + yt <- generate_exp_growth_pois(rt, initial) + others <- generate_exp_growth_pois(0.0 * rt, mean_other) + target_fac_data <- tibble( + reference_date = reference_dates, + report_date = end_reference, + geo_type = "state", + geo_value = "CA", + asof = end_reference, + metric = "count_ed_visits", + run_id = 0, + facility = facility, + !!target_disease := yt, + Total = yt + others, + ) |> pivot_longer( + cols = c(all_of(target_disease), "Total"), + names_to = "disease", values_to = "value" + ) + return(target_fac_data) +} + +#' Generate Fake Facility Data +#' +#' This function generates fake facility data for a specified number of +#' facilities within a given date range and writes the data to a parquet file. +#' +#' @param private_data_dir A string specifying the directory where the generated +#' data will be saved. +#' @param n_facilities An integer specifying the number of facilities to +#' generate data for. Default is 3. +#' @param start_reference A Date object specifying the start date for the data +#' generation. Default is "2024-06-01". +#' @param end_reference A Date object specifying the end date for the data +#' generation. Default is "2024-12-25". +#' @param initial A numeric value specifying the initial value for the data +#' generation. Default is 10. +#' @param mean_other A numeric value specifying the mean value for other data +#' points. Default is 200. +#' @param target_disease A string specifying the target disease for the data +#' generation. Default is "COVID-19/Omicron". +#' +#' @return This function does not return a value. It writes the generated data +#' to a parquet file. 
+generate_fake_facility_data <- function( + private_data_dir = path(getwd()), n_facilities = 1, + start_reference = as.Date("2024-06-01"), + end_reference = as.Date("2024-12-25"), initial = 10, mean_other = 200, + target_disease = "COVID-19/Omicron") { + nssp_etl_gold_dir <- path(private_data_dir, "nssp_etl_gold") + dir_create(nssp_etl_gold_dir, recurse = TRUE) + + fac_data <- purrr::map(1:n_facilities, \(i) { + create_facility_test_data( + i, start_reference, end_reference, + initial, mean_other, target_disease + ) + }) |> + bind_rows() |> + write_parquet(path(nssp_etl_gold_dir, end_reference, ext = "parquet")) +} + +#' Generate State Level Data +#' +#' This function generates state-level test data for a specified disease over a +#' given time period. +#' +#' @param private_data_dir A string specifying the directory where the generated +#' data will be stored. +#' @param start_reference A Date object specifying the start date for the data +#' generation period. Default is "2024-06-01". +#' @param end_reference A Date object specifying the end date for the data +#' generation period. Default is "2024-12-25". +#' @param initial A numeric value specifying the initial value for the data +#' generation. Default is 10. +#' @param mean_other A numeric value specifying the mean value for other data +#' points. Default is 200. +#' @param target_disease A string specifying the target disease for the data +#' generation. Default is "COVID-19/Omicron". +#' +#' @return This function does not return a value. It writes the generated data +#' to a parquet file in the specified directory. 
+generate_fake_state_level_data <- function( + private_data_dir = path(getwd()), + start_reference = as.Date("2024-06-01"), + end_reference = as.Date("2024-12-25"), initial = 10, mean_other = 200, + target_disease = "COVID-19/Omicron", n_forecast_days = 28) { + gold_dir <- path(private_data_dir, "nssp_state_level_gold") + dir_create(gold_dir, recurse = TRUE) + + comp_dir <- path(private_data_dir, "nssp-archival-vintages") + dir_create(comp_dir, recurse = TRUE) + + state_data <- create_facility_test_data( + 1, start_reference, end_reference + n_forecast_days, + initial, mean_other, target_disease + ) |> + select(-facility, -run_id, -asof) + + # Write in-sample state-level data to gold directory + state_data |> + filter(reference_date <= end_reference) |> + mutate(any_update_this_day = TRUE) |> + write_parquet(path(gold_dir, end_reference, ext = "parquet")) + + # Write out-of-sample state-level data to comparison directory + state_data |> + filter(reference_date > end_reference) |> + write_parquet(path(comp_dir, "latest_comprehensive", + ext = "parquet" + )) +} + +#' Generate Fake Parameter Data +#' +#' This function generates fake parameter data for a specified disease and +#' saves it as a parquet file. +#' +#' The function creates a directory for storing the parameter estimates if it +#' does not already exist. It then generates a simple discretized exponential +#' distribution for the generation interval (gi_pmf) and a right truncation +#' probability mass function (rt_truncation_pmf). +#' +#' @param private_data_dir A string specifying the directory where the data will +#' be saved. +#' @param end_reference A Date object specifying the end reference date for the +#' data. Default is "2024-12-25". +#' @param target_disease A string specifying the target disease for the data. +#' Default is "COVID-19". 
+generate_fake_param_data <- function( + private_data_dir = path(getwd()), + end_reference = as.Date("2024-12-25"), target_disease = "COVID-19") { + prod_param_estimates_dir <- path(private_data_dir, "prod_param_estimates") + dir_create(prod_param_estimates_dir, recurse = TRUE) + + # Simple discretize exponential distribution + gi_pmf <- seq(0.5, 6.5) |> dexp() + gi_pmf <- gi_pmf / sum(gi_pmf) + delay_pmf <- seq(0.5, 10.5) |> dexp(rate = 1 / 2) + delay_pmf <- delay_pmf / sum(delay_pmf) + rt_truncation_pmf <- c(1, 0, 0, 0) + + gi_data <- tibble( + id = 0, + start_date = as.Date("2024-06-01"), + end_date = NA, + reference_date = end_reference, + disease = target_disease, + format = "PMF", + parameter = "generation_interval", + geo_value = NA, + value = list(gi_pmf) + ) + delay_data <- tibble( + id = 0, + start_date = as.Date("2024-06-01"), + end_date = NA, + reference_date = end_reference, + disease = target_disease, + format = "PMF", + parameter = "delay", + geo_value = NA, + value = list(delay_pmf) + ) + rt_trunc_data <- tibble( + id = 0, + start_date = as.Date("2024-06-01"), + end_date = NA, + reference_date = end_reference, + disease = target_disease, + format = "PMF", + parameter = "right_truncation", + geo_value = "CA", + value = list(rt_truncation_pmf) + ) + write_parquet( + bind_rows(gi_data, delay_data, rt_trunc_data), + path(prod_param_estimates_dir, "prod", ext = "parquet") + ) +} + +main <- function(private_data_dir, target_disease, n_forecast_days) { + short_target_disease <- disease_short_names[[target_disease]] + generate_fake_facility_data(private_data_dir, target_disease = target_disease) + generate_fake_state_level_data(private_data_dir, + target_disease = target_disease, n_forecast_days = n_forecast_days + ) + generate_fake_param_data(private_data_dir, + target_disease = short_target_disease + ) +} + +p <- arg_parser("Create epiweekly data") |> + add_argument( + "model_run_dir", + help = "Directory containing the model data and output." 
+  ) |>
+  add_argument(
+    "--target-disease",
+    type = "character",
+    default = "COVID-19/Omicron",
+    help = "Target disease for the data generation."
+  ) |>
+  add_argument(
+    "--n-forecast-days",
+    type = "integer",
+    default = 28,
+    help = "Number of days to forecast."
+  )
+
+argv <- parse_args(p)
+
+main(argv$model_run_dir,
+  target_disease = argv$target_disease,
+  n_forecast_days = argv$n_forecast_days
+)
diff --git a/pipelines/tests/test_forecast_state.sh b/pipelines/tests/test_forecast_state.sh
new file mode 100644
index 00000000..56a42e4f
--- /dev/null
+++ b/pipelines/tests/test_forecast_state.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Check if the base directory is provided as an argument
+if [ -z "$1" ]; then
+  echo "Usage: $0 " # NOTE(review): usage text appears truncated -- likely lost a "<base_dir>" placeholder; confirm
+  exit 1
+fi
+
+BASE_DIR="$1"
+echo "TEST-MODE: Running forecast_state.py in test mode with base directory $BASE_DIR"
+Rscript pipelines/generate_test_data.R "$BASE_DIR/private_data"
+if [ $? -ne 0 ]; then # abort early: the pipeline below needs the generated fixtures
+  echo "TEST-MODE FAIL: Generating test data failed"
+  exit 1
+else
+  echo "TEST-MODE: Finished generating test data"
+fi
+echo "TEST-MODE: Running forecasting pipeline"
+python pipelines/forecast_state.py \
+  --disease "COVID-19" \
+  --state "CA" \
+  --facility-level-nssp-data-dir "$BASE_DIR/private_data/nssp_etl_gold" \
+  --state-level-nssp-data-dir "$BASE_DIR/private_data/nssp_state_level_gold" \
+  --priors-path "$BASE_DIR/test_output/priors.py" \
+  --param-data-dir "$BASE_DIR/private_data/prod_param_estimates" \
+  --output-dir "$BASE_DIR/private_data" \
+  --n-training-days 90 \
+  --n-chains 1 \
+  --n-samples 500 \
+  --n-warmup 500 \
+  --score \
+  --eval-data-path "$BASE_DIR/private_data/nssp-archival-vintages"
+if [ $?
-ne 0 ]; then
+  echo "TEST-MODE FAIL: Forecasting/postprocessing/scoring pipeline failed"
+  exit 1
+else
+  echo "TEST-MODE: Finished forecasting/postprocessing/scoring pipeline"
+fi
+echo "TEST-MODE: All finished successfully"
diff --git a/pipelines/tests/test_output/priors.py b/pipelines/tests/test_output/priors.py
new file mode 100644
index 00000000..4f9d61ab
--- /dev/null
+++ b/pipelines/tests/test_output/priors.py
@@ -0,0 +1,68 @@
+import jax.numpy as jnp
+import numpyro.distributions as dist
+import pyrenew.transformation as transformation
+from numpyro.infer.reparam import LocScaleReparam
+from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
+
+i0_first_obs_n_rv = DistributionalVariable(
+    "i0_first_obs_n_rv",
+    dist.Beta(1, 10),  # mean 1/11: prior mass near zero for the initial infection fraction
+)
+
+initialization_rate_rv = DistributionalVariable(
+    "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0)  # non-centered parameterization of the growth rate
+)
+
+r_logmean = jnp.log(1)  # log(1) = 0, so the prior median of R is 1
+r_logsd = jnp.log(jnp.sqrt(2))  # one log-scale SD spans a factor of sqrt(2)
+
+log_r_mu_intercept_rv = DistributionalVariable(
+    "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
+)
+
+eta_sd_rv = DistributionalVariable(
+    "eta_sd", dist.TruncatedNormal(0.04, 0.02, low=0)  # truncated at 0: a standard deviation must be nonnegative
+)
+
+autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40))  # mean ~0.05: weak autoregression
+
+
+inf_feedback_strength_rv = TransformedVariable(
+    "inf_feedback",
+    DistributionalVariable(
+        "inf_feedback_raw",
+        dist.LogNormal(jnp.log(50), jnp.log(2)),  # positive raw strength, median 50
+    ),
+    transforms=transformation.AffineTransform(loc=0, scale=-1),  # scale=-1 negates the raw value
+)
+# Could be reparameterized?
+
+p_ed_visit_mean_rv = DistributionalVariable(
+    "p_ed_visit_mean",
+    dist.Normal(
+        transformation.SigmoidTransform().inv(0.005),  # center at logit(0.005)
+        0.3,
+    ),
+)  # logit scale
+
+
+p_ed_visit_w_sd_rv = DistributionalVariable(
+    "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)  # NOTE(review): site name has a doubled "_sd" -- confirm intended
+)
+
+
+autoreg_p_ed_visit_rv = DistributionalVariable(
+    "autoreg_p_ed_visit_rv", dist.Beta(1, 100)  # mean ~0.01: very weak autoregression
+)
+
+ed_visit_wday_effect_rv = TransformedVariable(
+    "ed_visit_wday_effect",
+    DistributionalVariable(
+        "ed_visit_wday_effect_raw",
+        dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])),  # symmetric concentration: near-uniform weekday weights
+    ),
+    transformation.AffineTransform(loc=0, scale=7),  # simplex sums to 1; scaling by 7 makes the effects average 1 per day
+)
+
+# Based on looking at some historical posteriors.
+phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1))