From 98842041f8df80741ddd8d7daf4f60ff02ae9f0b Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Tue, 17 Dec 2024 13:15:57 -0600 Subject: [PATCH] Add priors to git repo and adjust container build (#251) * add priors * more robust metadata file handling and save priors to metadata file * correct path to metadata * try changing the container file for fast builds when priors change * use shorthand for local branch name extraction * remove redundant code * fix copy paste error * try allowing new syntax * switch from yaml to toml * pre-commit * correction for test job args --- Containerfile | 6 +- pipelines/batch/setup_prod_job.py | 2 +- pipelines/batch/setup_test_prod_job.py | 4 +- pipelines/forecast_state.py | 53 +++++++++++---- pipelines/priors/eval_priors.py | 68 +++++++++++++++++++ .../priors/parameter_inference_priors.py | 68 +++++++++++++++++++ pipelines/priors/prod_priors.py | 68 +++++++++++++++++++ pyproject.toml | 1 + 8 files changed, 254 insertions(+), 16 deletions(-) create mode 100644 pipelines/priors/eval_priors.py create mode 100644 pipelines/priors/parameter_inference_priors.py create mode 100644 pipelines/priors/prod_priors.py diff --git a/Containerfile b/Containerfile index 20b98455..5b18f0c7 100644 --- a/Containerfile +++ b/Containerfile @@ -1,3 +1,4 @@ +#syntax=docker/dockerfile:1.7-labs ARG TAG=latest FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG} @@ -18,5 +19,8 @@ RUN Rscript -e "pak::pkg_install('cmu-delphi/epiprocess@main')" RUN Rscript -e "pak::pkg_install('cmu-delphi/epipredict@main')" RUN Rscript -e "pak::local_install('hewr')" -COPY . . +COPY --exclude=pipelines/priors . . RUN pip install --root-user-action=ignore . + +# Copy priors folder last +COPY pipelines/priors pipelines/priors diff --git a/pipelines/batch/setup_prod_job.py b/pipelines/batch/setup_prod_job.py index 414d3d1f..2b8040f7 100644 --- a/pipelines/batch/setup_prod_job.py +++ b/pipelines/batch/setup_prod_job.py @@ -167,7 +167,7 @@ def main( "nssp-archival-vintages/gold " "--param-data-dir params " "--output-dir {output_dir} " - "--priors-path config/prod_priors.py " + "--priors-path pipelines/priors/prod_priors.py " "--report-date {report_date} " "--exclude-last-n-days {exclude_last_n_days} " "--no-score " diff --git a/pipelines/batch/setup_test_prod_job.py b/pipelines/batch/setup_test_prod_job.py index 821d3b4f..446bad63 100644 --- a/pipelines/batch/setup_test_prod_job.py +++ b/pipelines/batch/setup_test_prod_job.py @@ -21,7 +21,7 @@ "--tag", type=str, help="The tag name to use for the container image version", - default=Path(Repository(os.getcwd()).head.name).stem, + default=Repository(os.getcwd()).head.shorthand, ) args = parser.parse_args() @@ -95,6 +95,6 @@ diseases=["COVID-19", "Influenza"], container_image_name="pyrenew-hew", container_image_version=tag, - excluded_locations=locs_to_exclude, + locations_exclude=locs_to_exclude, test=True, ) diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index c5c4442d..97f3946f 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -8,7 +8,8 @@ import numpyro import polars as pl -import yaml +import tomli_w +import tomllib from prep_data import process_and_save_state from pygit2 import Repository from save_eval_data import save_eval_data @@ -20,24 +21,52 @@ def record_git_info(model_run_dir: Path): - metadata_file = Path(model_run_dir, "metadata.yaml") + metadata_file = Path(model_run_dir, "metadata.toml") + + if metadata_file.exists(): + with open(metadata_file, "rb") as file: + metadata = tomllib.load(file) + else: + metadata = {} + try: repo = Repository(os.getcwd()) - branch_name = os.environ.get( - "GIT_BRANCH_NAME", Path(repo.head.name).stem - ) - commit_sha = os.environ.get("GIT_COMMIT_SHA", str(repo.head.target)) - except: + branch_name = repo.head.shorthand + commit_sha = str(repo.head.target) + except Exception as e: branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown") commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown") - metadata = { + new_metadata = { "branch_name": branch_name, "commit_sha": commit_sha, } - with open(metadata_file, "w") as file: - yaml.dump(metadata, file) + metadata.update(new_metadata) + + metadata_file.parent.mkdir(parents=True, exist_ok=True) + with open(metadata_file, "wb") as file: + tomli_w.dump(metadata, file) + + +def copy_and_record_priors(priors_path: Path, model_run_dir: Path): + metadata_file = Path(model_run_dir, "metadata.toml") + shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) + + if metadata_file.exists(): + with open(metadata_file, "rb") as file: + metadata = tomllib.load(file) + else: + metadata = {} + + new_metadata = { + "priors_path": str(priors_path), + } + + metadata.update(new_metadata) + + with open(metadata_file, "wb") as file: + tomli_w.dump(metadata, file) def generate_epiweekly(model_run_dir: Path) -> None: @@ -264,8 +293,8 @@ def main( logger.info("Recording git info...") record_git_info(model_run_dir) - logger.info(f"Using priors from {priors_path}...") - shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) + logger.info(f"Copying and recording priors from {priors_path}...") + copy_and_record_priors(priors_path, model_run_dir) logger.info(f"Processing {state}") process_and_save_state( diff --git a/pipelines/priors/eval_priors.py b/pipelines/priors/eval_priors.py new file mode 100644 index 00000000..2d5ef792 --- /dev/null +++ b/pipelines/priors/eval_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1.2) +r_logsd = jnp.log(jnp.sqrt(2)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.1, 0.05, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 2)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(10), jnp.log(3)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? + +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. +phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1)) diff --git a/pipelines/priors/parameter_inference_priors.py b/pipelines/priors/parameter_inference_priors.py new file mode 100644 index 00000000..87bda09c --- /dev/null +++ b/pipelines/priors/parameter_inference_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1) +r_logsd = jnp.log(jnp.sqrt(3)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.15, 0.1, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(1), jnp.log(20)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? + +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. +phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1.5)) diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py new file mode 100644 index 00000000..05c4a9f4 --- /dev/null +++ b/pipelines/priors/prod_priors.py @@ -0,0 +1,68 @@ +import jax.numpy as jnp +import numpyro.distributions as dist +import pyrenew.transformation as transformation +from numpyro.infer.reparam import LocScaleReparam +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable + +i0_first_obs_n_rv = DistributionalVariable( + "i0_first_obs_n_rv", + dist.Beta(1, 10), +) + +initialization_rate_rv = DistributionalVariable( + "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0) +) + +r_logmean = jnp.log(1.2) +r_logsd = jnp.log(jnp.sqrt(2)) + +log_r_mu_intercept_rv = DistributionalVariable( + "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) +) + +eta_sd_rv = DistributionalVariable( + "eta_sd", dist.TruncatedNormal(0.15, 0.05, low=0) +) + +autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40)) + + +inf_feedback_strength_rv = TransformedVariable( + "inf_feedback", + DistributionalVariable( + "inf_feedback_raw", + dist.LogNormal(jnp.log(50), jnp.log(1.5)), + ), + transforms=transformation.AffineTransform(loc=0, scale=-1), +) +# Could be reparameterized? + +p_ed_visit_mean_rv = DistributionalVariable( + "p_ed_visit_mean", + dist.Normal( + transformation.SigmoidTransform().inv(0.005), + 0.3, + ), +) # logit scale + + +p_ed_visit_w_sd_rv = DistributionalVariable( + "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0) +) + + +autoreg_p_ed_visit_rv = DistributionalVariable( + "autoreg_p_ed_visit_rv", dist.Beta(1, 100) +) + +ed_visit_wday_effect_rv = TransformedVariable( + "ed_visit_wday_effect", + DistributionalVariable( + "ed_visit_wday_effect_raw", + dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])), + ), + transformation.AffineTransform(loc=0, scale=7), +) + +# Based on looking at some historical posteriors. +phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1)) diff --git a/pyproject.toml b/pyproject.toml index e37307d1..e561d840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ pyarrow = "^18.0.0" pygit2 = "^1.16.0" azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"} forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"} +tomli-w = "^1.1.0" [tool.poetry.group.test] optional = true