Skip to content

Commit

Permalink
Add priors to git repo and adjust container build (#251)
Browse files Browse the repository at this point in the history
* add priors

* more robust metadata file handling and save priors to metadata file

* correct path to metadata

* try changing the container file for fast builds when priors change

* use shorthand for local branch name extraction

* remove redundant code

* fix copy paste error

* try allowing new syntax

* switch from yaml to toml

* pre-commit

* correction for test job args
  • Loading branch information
damonbayer authored Dec 17, 2024
1 parent 2927b03 commit 9884204
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 16 deletions.
6 changes: 5 additions & 1 deletion Containerfile
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#syntax=docker/dockerfile:1.7-labs
ARG TAG=latest

FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG}
Expand All @@ -18,5 +19,8 @@ RUN Rscript -e "pak::pkg_install('cmu-delphi/epiprocess@main')"
RUN Rscript -e "pak::pkg_install('cmu-delphi/epipredict@main')"
RUN Rscript -e "pak::local_install('hewr')"

COPY . .
COPY --exclude=pipelines/priors . .
RUN pip install --root-user-action=ignore .

# Copy priors folder last
COPY pipelines/priors pipelines/priors
2 changes: 1 addition & 1 deletion pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def main(
"nssp-archival-vintages/gold "
"--param-data-dir params "
"--output-dir {output_dir} "
"--priors-path config/prod_priors.py "
"--priors-path pipelines/priors/prod_priors.py "
"--report-date {report_date} "
"--exclude-last-n-days {exclude_last_n_days} "
"--no-score "
Expand Down
4 changes: 2 additions & 2 deletions pipelines/batch/setup_test_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"--tag",
type=str,
help="The tag name to use for the container image version",
default=Path(Repository(os.getcwd()).head.name).stem,
default=Repository(os.getcwd()).head.shorthand,
)

args = parser.parse_args()
Expand Down Expand Up @@ -95,6 +95,6 @@
diseases=["COVID-19", "Influenza"],
container_image_name="pyrenew-hew",
container_image_version=tag,
excluded_locations=locs_to_exclude,
locations_exclude=locs_to_exclude,
test=True,
)
53 changes: 41 additions & 12 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import numpyro
import polars as pl
import yaml
import tomli_w
import tomllib
from prep_data import process_and_save_state
from pygit2 import Repository
from save_eval_data import save_eval_data
Expand All @@ -20,24 +21,52 @@


def record_git_info(model_run_dir: Path) -> None:
    """
    Record the current git branch and commit in the run's metadata file.

    Loads ``metadata.toml`` from ``model_run_dir`` when it already
    exists, sets (or overwrites) the ``branch_name`` and ``commit_sha``
    keys, and writes the merged mapping back, so keys already present
    in the file are preserved.

    Parameters
    ----------
    model_run_dir
        Directory of the model run. It is created, with parents, if
        it does not exist yet.

    Returns
    -------
    None
    """
    metadata_file = Path(model_run_dir, "metadata.toml")

    # Start from whatever is already recorded so this call only
    # adds or refreshes the git-related keys.
    if metadata_file.exists():
        with open(metadata_file, "rb") as file:
            metadata = tomllib.load(file)
    else:
        metadata = {}

    try:
        repo = Repository(os.getcwd())
        branch_name = repo.head.shorthand
        commit_sha = str(repo.head.target)
    except Exception:
        # Deliberate best-effort: if the working directory cannot be
        # read as a git repository, fall back to the environment
        # variables, and to "unknown" when those are unset as well.
        branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown")
        commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown")

    metadata.update(
        {
            "branch_name": branch_name,
            "commit_sha": commit_sha,
        }
    )

    metadata_file.parent.mkdir(parents=True, exist_ok=True)
    # tomli_w.dump writes bytes, hence the binary file mode.
    with open(metadata_file, "wb") as file:
        tomli_w.dump(metadata, file)


def copy_and_record_priors(priors_path: Path, model_run_dir: Path) -> None:
    """
    Copy the priors file into the run directory and record its origin.

    The file at ``priors_path`` is copied to ``priors.py`` inside
    ``model_run_dir``, and the ``priors_path`` key of the run's
    ``metadata.toml`` is set to the source path. Keys already present
    in the metadata file are preserved.

    Parameters
    ----------
    priors_path
        Path to the priors file to copy.
    model_run_dir
        Directory of the model run. It is created, with parents, if
        it does not exist yet.

    Returns
    -------
    None
    """
    run_dir = Path(model_run_dir)
    # Create the destination first so the copy below does not depend
    # on some earlier step having created the run directory.
    run_dir.mkdir(parents=True, exist_ok=True)

    metadata_file = Path(run_dir, "metadata.toml")
    shutil.copyfile(priors_path, Path(run_dir, "priors.py"))

    # Merge into any existing metadata rather than replacing it.
    if metadata_file.exists():
        with open(metadata_file, "rb") as file:
            metadata = tomllib.load(file)
    else:
        metadata = {}

    metadata.update({"priors_path": str(priors_path)})

    # tomli_w.dump writes bytes, hence the binary file mode.
    with open(metadata_file, "wb") as file:
        tomli_w.dump(metadata, file)


def generate_epiweekly(model_run_dir: Path) -> None:
Expand Down Expand Up @@ -264,8 +293,8 @@ def main(
logger.info("Recording git info...")
record_git_info(model_run_dir)

logger.info(f"Using priors from {priors_path}...")
shutil.copyfile(priors_path, Path(model_run_dir, "priors.py"))
logger.info(f"Copying and recording priors from {priors_path}...")
copy_and_record_priors(priors_path, model_run_dir)

logger.info(f"Processing {state}")
process_and_save_state(
Expand Down
68 changes: 68 additions & 0 deletions pipelines/priors/eval_priors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

# Prior definitions from pipelines/priors/eval_priors.py.
# Each module-level ``*_rv`` name below is a pyrenew random variable;
# the first argument of each constructor is its sample-site name.

# Beta(1, 10) prior (mean 1/11).
i0_first_obs_n_rv = DistributionalVariable(
    "i0_first_obs_n_rv",
    dist.Beta(1, 10),
)

# Normal(0, 0.01) prior, sampled through numpyro's LocScaleReparam
# with centered=0.
initialization_rate_rv = DistributionalVariable(
    "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0)
)

# Location and scale, on the log scale, of the Normal prior below.
r_logmean = jnp.log(1.2)
r_logsd = jnp.log(jnp.sqrt(2))

log_r_mu_intercept_rv = DistributionalVariable(
    "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)

# Normal(0.1, 0.05) truncated below at 0.
eta_sd_rv = DistributionalVariable(
    "eta_sd", dist.TruncatedNormal(0.1, 0.05, low=0)
)

autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 2))


# The raw draw is LogNormal (positive); the affine transform with
# scale=-1 and loc=0 negates it, so "inf_feedback" is negative.
inf_feedback_strength_rv = TransformedVariable(
    "inf_feedback",
    DistributionalVariable(
        "inf_feedback_raw",
        dist.LogNormal(jnp.log(10), jnp.log(3)),
    ),
    transforms=transformation.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?

# Normal prior centred at the inverse-sigmoid (logit) of 0.005.
p_ed_visit_mean_rv = DistributionalVariable(
    "p_ed_visit_mean",
    dist.Normal(
        transformation.SigmoidTransform().inv(0.005),
        0.3,
    ),
)  # logit scale


# NOTE(review): the site name ends in a doubled "_sd_sd" while the
# variable is ``p_ed_visit_w_sd_rv`` -- confirm this is intended before
# renaming, since anything keyed on the site name would change.
p_ed_visit_w_sd_rv = DistributionalVariable(
    "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)
)


autoreg_p_ed_visit_rv = DistributionalVariable(
    "autoreg_p_ed_visit_rv", dist.Beta(1, 100)
)

# A 7-component Dirichlet draw (components sum to 1) scaled by 7, so
# the transformed components sum to 7.
ed_visit_wday_effect_rv = TransformedVariable(
    "ed_visit_wday_effect",
    DistributionalVariable(
        "ed_visit_wday_effect_raw",
        dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])),
    ),
    transformation.AffineTransform(loc=0, scale=7),
)

# Based on looking at some historical posteriors.
phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1))
68 changes: 68 additions & 0 deletions pipelines/priors/parameter_inference_priors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

# Prior definitions from pipelines/priors/parameter_inference_priors.py.
# Each module-level ``*_rv`` name below is a pyrenew random variable;
# the first argument of each constructor is its sample-site name.

# Beta(1, 10) prior (mean 1/11).
i0_first_obs_n_rv = DistributionalVariable(
    "i0_first_obs_n_rv",
    dist.Beta(1, 10),
)

# Normal(0, 0.01) prior, sampled through numpyro's LocScaleReparam
# with centered=0.
initialization_rate_rv = DistributionalVariable(
    "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0)
)

# Location and scale, on the log scale, of the Normal prior below.
# jnp.log(1) is 0, so that prior is centred at 0.
r_logmean = jnp.log(1)
r_logsd = jnp.log(jnp.sqrt(3))

log_r_mu_intercept_rv = DistributionalVariable(
    "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)

# Normal(0.15, 0.1) truncated below at 0.
eta_sd_rv = DistributionalVariable(
    "eta_sd", dist.TruncatedNormal(0.15, 0.1, low=0)
)

autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40))


# The raw draw is LogNormal (positive); the affine transform with
# scale=-1 and loc=0 negates it, so "inf_feedback" is negative.
inf_feedback_strength_rv = TransformedVariable(
    "inf_feedback",
    DistributionalVariable(
        "inf_feedback_raw",
        dist.LogNormal(jnp.log(1), jnp.log(20)),
    ),
    transforms=transformation.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?

# Normal prior centred at the inverse-sigmoid (logit) of 0.005.
p_ed_visit_mean_rv = DistributionalVariable(
    "p_ed_visit_mean",
    dist.Normal(
        transformation.SigmoidTransform().inv(0.005),
        0.3,
    ),
)  # logit scale


# NOTE(review): the site name ends in a doubled "_sd_sd" while the
# variable is ``p_ed_visit_w_sd_rv`` -- confirm this is intended before
# renaming, since anything keyed on the site name would change.
p_ed_visit_w_sd_rv = DistributionalVariable(
    "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)
)


autoreg_p_ed_visit_rv = DistributionalVariable(
    "autoreg_p_ed_visit_rv", dist.Beta(1, 100)
)

# A 7-component Dirichlet draw (components sum to 1) scaled by 7, so
# the transformed components sum to 7.
ed_visit_wday_effect_rv = TransformedVariable(
    "ed_visit_wday_effect",
    DistributionalVariable(
        "ed_visit_wday_effect_raw",
        dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])),
    ),
    transformation.AffineTransform(loc=0, scale=7),
)

# Based on looking at some historical posteriors.
phi_rv = DistributionalVariable("phi", dist.LogNormal(6, 1.5))
68 changes: 68 additions & 0 deletions pipelines/priors/prod_priors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

# Prior definitions from pipelines/priors/prod_priors.py.
# Each module-level ``*_rv`` name below is a pyrenew random variable;
# the first argument of each constructor is its sample-site name.

# Beta(1, 10) prior (mean 1/11).
i0_first_obs_n_rv = DistributionalVariable(
    "i0_first_obs_n_rv",
    dist.Beta(1, 10),
)

# Normal(0, 0.01) prior, sampled through numpyro's LocScaleReparam
# with centered=0.
initialization_rate_rv = DistributionalVariable(
    "rate", dist.Normal(0, 0.01), reparam=LocScaleReparam(0)
)

# Location and scale, on the log scale, of the Normal prior below.
r_logmean = jnp.log(1.2)
r_logsd = jnp.log(jnp.sqrt(2))

log_r_mu_intercept_rv = DistributionalVariable(
    "log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)

# Normal(0.15, 0.05) truncated below at 0.
eta_sd_rv = DistributionalVariable(
    "eta_sd", dist.TruncatedNormal(0.15, 0.05, low=0)
)

autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40))


# The raw draw is LogNormal (positive); the affine transform with
# scale=-1 and loc=0 negates it, so "inf_feedback" is negative.
inf_feedback_strength_rv = TransformedVariable(
    "inf_feedback",
    DistributionalVariable(
        "inf_feedback_raw",
        dist.LogNormal(jnp.log(50), jnp.log(1.5)),
    ),
    transforms=transformation.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?

# Normal prior centred at the inverse-sigmoid (logit) of 0.005.
p_ed_visit_mean_rv = DistributionalVariable(
    "p_ed_visit_mean",
    dist.Normal(
        transformation.SigmoidTransform().inv(0.005),
        0.3,
    ),
)  # logit scale


# NOTE(review): the site name ends in a doubled "_sd_sd" while the
# variable is ``p_ed_visit_w_sd_rv`` -- confirm this is intended before
# renaming, since anything keyed on the site name would change.
p_ed_visit_w_sd_rv = DistributionalVariable(
    "p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)
)


autoreg_p_ed_visit_rv = DistributionalVariable(
    "autoreg_p_ed_visit_rv", dist.Beta(1, 100)
)

# A 7-component Dirichlet draw (components sum to 1) scaled by 7, so
# the transformed components sum to 7.
ed_visit_wday_effect_rv = TransformedVariable(
    "ed_visit_wday_effect",
    DistributionalVariable(
        "ed_visit_wday_effect_raw",
        dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])),
    ),
    transformation.AffineTransform(loc=0, scale=7),
)

# Based on looking at some historical posteriors.
phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pyarrow = "^18.0.0"
pygit2 = "^1.16.0"
azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"}
forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"}
tomli-w = "^1.1.0"

[tool.poetry.group.test]
optional = true
Expand Down

0 comments on commit 9884204

Please sign in to comment.