Skip to content

Commit

Permalink
Create score_location.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Dec 19, 2024
1 parent 82460d4 commit 5db8bca
Showing 1 changed file with 231 additions and 0 deletions.
231 changes: 231 additions & 0 deletions pipelines/score_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import argparse
import logging
import subprocess
from datetime import datetime, timedelta
from pathlib import Path

Check warning on line 5 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L1-L5

Added lines #L1 - L5 were not covered by tests

import numpyro
from save_eval_data import save_eval_data
from utils import parse_model_batch_dir_name

Check warning on line 9 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L7-L9

Added lines #L7 - L9 were not covered by tests

numpyro.set_host_device_count(4)

Check warning on line 11 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L11

Added line #L11 was not covered by tests

from fit_model import fit_and_save_model # noqa
from generate_predictive import generate_and_save_predictions # noqa

Check warning on line 14 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L13-L14

Added lines #L13 - L14 were not covered by tests


def generate_epiweekly(model_run_dir: Path) -> None:
result = subprocess.run(

Check warning on line 18 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L17-L18

Added lines #L17 - L18 were not covered by tests
[
"Rscript",
"pipelines/generate_epiweekly.R",
f"{model_run_dir}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"generate_epiweekly: {result.stderr}")
return None

Check warning on line 28 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L26-L28

Added lines #L26 - L28 were not covered by tests


def timeseries_forecasts(

Check warning on line 31 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L31

Added line #L31 was not covered by tests
model_run_dir: Path, model_name: str, n_forecast_days: int, n_samples: int
) -> None:
result = subprocess.run(

Check warning on line 34 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L34

Added line #L34 was not covered by tests
[
"Rscript",
"pipelines/timeseries_forecasts.R",
f"{model_run_dir}",
"--model-name",
f"{model_name}",
"--n-forecast-days",
f"{n_forecast_days}",
"--n-samples",
f"{n_samples}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"timeseries_forecasts: {result.stderr}")
return None

Check warning on line 50 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L48-L50

Added lines #L48 - L50 were not covered by tests


def convert_inferencedata_to_parquet(

Check warning on line 53 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L53

Added line #L53 was not covered by tests
model_run_dir: Path, model_name: str
) -> None:
result = subprocess.run(

Check warning on line 56 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L56

Added line #L56 was not covered by tests
[
"Rscript",
"pipelines/convert_inferencedata_to_parquet.R",
f"{model_run_dir}",
"--model-name",
f"{model_name}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(

Check warning on line 67 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L66-L67

Added lines #L66 - L67 were not covered by tests
f"convert_inferencedata_to_parquet: {result.stderr}"
)
return None

Check warning on line 70 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L70

Added line #L70 was not covered by tests


def postprocess_forecast(

Check warning on line 73 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L73

Added line #L73 was not covered by tests
model_run_dir: Path, pyrenew_model_name: str, timeseries_model_name: str
) -> None:
result = subprocess.run(

Check warning on line 76 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L76

Added line #L76 was not covered by tests
[
"Rscript",
"pipelines/postprocess_state_forecast.R",
f"{model_run_dir}",
"--pyrenew-model-name",
f"{pyrenew_model_name}",
"--timeseries-model-name",
f"{timeseries_model_name}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"postprocess_forecast: {result.stderr}")
return None

Check warning on line 90 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L88-L90

Added lines #L88 - L90 were not covered by tests


def score_forecast(model_run_dir: Path) -> None:
result = subprocess.run(

Check warning on line 94 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L93-L94

Added lines #L93 - L94 were not covered by tests
[
"Rscript",
"pipelines/score_forecast.R",
f"{model_run_dir}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"score_forecast: {result.stderr}")
return None

Check warning on line 104 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L102-L104

Added lines #L102 - L104 were not covered by tests


def render_webpage(model_run_dir: Path) -> None:
result = subprocess.run(

Check warning on line 108 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L107-L108

Added lines #L107 - L108 were not covered by tests
[
"Rscript",
"pipelines/render_webpage.R",
f"{model_run_dir}",
],
capture_output=True,
)
if result.returncode != 0:
raise RuntimeError(f"render_webpage: {result.stderr}")
return None

Check warning on line 118 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L116-L118

Added lines #L116 - L118 were not covered by tests


def get_available_reports(

Check warning on line 121 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L121

Added line #L121 was not covered by tests
data_dir: str | Path, glob_pattern: str = "*.parquet"
):
return [

Check warning on line 124 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L124

Added line #L124 was not covered by tests
datetime.strptime(f.stem, "%Y-%m-%d").date()
for f in Path(data_dir).glob(glob_pattern)
]


def main(state, model_batch_dir_path: Path, eval_data_path: Path):
model_batch_dir_path = Path(model_batch_dir_path)
eval_data_path = Path(eval_data_path)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Check warning on line 134 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L130-L134

Added lines #L130 - L134 were not covered by tests

batch_info = parse_model_batch_dir_name(model_batch_dir_path.name)

Check warning on line 136 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L136

Added line #L136 was not covered by tests

logger.info("Getting eval data...")
save_eval_data(

Check warning on line 139 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L138-L139

Added lines #L138 - L139 were not covered by tests
state=state,
latest_comprehensive_path=eval_data_path,
output_data_dir=Path(model_run_dir, "data", "eval"),
last_eval_date=(
batch_info["report_date"] + timedelta(days=n_forecast_days)
),
**batch_info,
)

logger.info("Generating epiweekly datasets from daily datasets...")
generate_epiweekly(model_run_dir)

Check warning on line 150 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L149-L150

Added lines #L149 - L150 were not covered by tests

logger.info("Data preparation complete.")

Check warning on line 152 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L152

Added line #L152 was not covered by tests

logger.info("Fitting model")
fit_and_save_model(

Check warning on line 155 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L154-L155

Added lines #L154 - L155 were not covered by tests
model_run_dir,
"pyrenew_e",
n_warmup=n_warmup,
n_samples=n_samples,
n_chains=n_chains,
)
logger.info("Model fitting complete")

Check warning on line 162 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L162

Added line #L162 was not covered by tests

logger.info("Performing posterior prediction / forecasting...")

Check warning on line 164 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L164

Added line #L164 was not covered by tests

n_days_past_last_training = n_forecast_days + exclude_last_n_days
generate_and_save_predictions(

Check warning on line 167 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L166-L167

Added lines #L166 - L167 were not covered by tests
model_run_dir, "pyrenew_e", n_days_past_last_training
)

logger.info(

Check warning on line 171 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L171

Added line #L171 was not covered by tests
"Performing baseline forecasting and non-target pathogen "
"forecasting..."
)
n_denominator_samples = n_samples * n_chains
timeseries_forecasts(

Check warning on line 176 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L175-L176

Added lines #L175 - L176 were not covered by tests
model_run_dir,
"timeseries_e",
n_days_past_last_training,
n_denominator_samples,
)
logger.info("All forecasting complete.")

Check warning on line 182 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L182

Added line #L182 was not covered by tests

logger.info("Converting inferencedata to parquet...")
convert_inferencedata_to_parquet(model_run_dir, "pyrenew_e")
logger.info("Conversion complete.")

Check warning on line 186 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L184-L186

Added lines #L184 - L186 were not covered by tests

logger.info("Postprocessing forecast...")
postprocess_forecast(model_run_dir, "pyrenew_e", "timeseries_e")
logger.info("Postprocessing complete.")

Check warning on line 190 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L188-L190

Added lines #L188 - L190 were not covered by tests

logger.info("Rendering webpage...")
render_webpage(model_run_dir)
logger.info("Rendering complete.")

Check warning on line 194 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L192-L194

Added lines #L192 - L194 were not covered by tests

if score:
logger.info("Scoring forecast...")
score_forecast(model_run_dir)

Check warning on line 198 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L196-L198

Added lines #L196 - L198 were not covered by tests

logger.info(

Check warning on line 200 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L200

Added line #L200 was not covered by tests
"Single state pipeline complete "
f"for state {state} with "
f"report date {report_date}."
)
return None

Check warning on line 205 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L205

Added line #L205 was not covered by tests


if __name__ == "__main__":
parser = argparse.ArgumentParser(

Check warning on line 209 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L208-L209

Added lines #L208 - L209 were not covered by tests
description="Create fit data for disease modeling."
)
parser.add_argument(

Check warning on line 212 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L212

Added line #L212 was not covered by tests
"--disease",
type=str,
required=True,
help="Disease to model (e.g., COVID-19, Influenza, RSV).",
)

parser.add_argument(

Check warning on line 219 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L219

Added line #L219 was not covered by tests
"--state",
type=str,
required=True,
help=(
"Two letter abbreviation for the state to fit"
"(e.g. 'AK', 'AL', 'AZ', etc.)."
),
)

args = parser.parse_args()
numpyro.set_host_device_count(args.n_chains)
main(**vars(args))

Check warning on line 231 in pipelines/score_location.py

View check run for this annotation

Codecov / codecov/patch

pipelines/score_location.py#L229-L231

Added lines #L229 - L231 were not covered by tests

0 comments on commit 5db8bca

Please sign in to comment.