Skip to content

Commit

Permalink
Make setup_prod_job.py more configurable (#248)
Browse files Browse the repository at this point in the history
Make setup_prod_job more configurable
  • Loading branch information
dylanhmorris authored Dec 16, 2024
1 parent 8250114 commit a69fa0c
Showing 1 changed file with 74 additions and 9 deletions.
83 changes: 74 additions & 9 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def main(
output_subdir: str | Path = "./",
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
excluded_locations: list[str] = [
n_training_days: int = 90,
exclude_last_n_days: int = 1,
locations_include: list[str] = None,
locations_exclude: list[str] = [
"AS",
"GU",
"MO",
Expand Down Expand Up @@ -62,9 +65,28 @@ def main(
container_image_version
Version of the container to use. Default 'latest'.
excluded_locations
n_training_days
Number of training days of data to use for model fitting.
Default 90.
exclude_last_n_days
Number of days of available data to exclude from fitting.
Default 1. Note that we start the lookback for the
``n_training_days`` of data after these exclusions,
so there will always be ``n_training_days`` of observations
for fitting; ``exclude_last_n_days`` determines where
the date range of observations starts and ends.
locations_include
List of two-letter USPS location abbreviations for locations
to include in the job (unless explicitly excluded by
--locations-exclude). If ``None``, use all available
not-explicitly-excluded locations. Default ``None``.
locations_exclude
List of two letter USPS location abbreviations to
exclude from the job. Defaults to locations for which
exclude from the job. If ``None``, do not exclude any
locations. Defaults to a list of locations for which
we typically do not have available NSSP ED visit data:
``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``.
Expand Down Expand Up @@ -137,7 +159,7 @@ def main(
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days 90 "
"--n-training-days {n_training_days} "
"--n-warmup {n_warmup} "
"--n-samples {n_samples} "
"--facility-level-nssp-data-dir nssp-etl/gold "
Expand All @@ -147,7 +169,7 @@ def main(
"--output-dir {output_dir} "
"--priors-path config/prod_priors.py "
"--report-date {report_date} "
"--exclude-last-n-days 1 "
"--exclude-last-n-days {exclude_last_n_days} "
"--no-score "
"--eval-data-path "
"nssp-archival-vintages/latest_comprehensive.parquet"
Expand All @@ -159,10 +181,16 @@ def main(
"https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
)

loc_abbrs = locations.get_column("STUSAB").to_list() + ["US"]
if locations_include is None:
locations_include = loc_abbrs
if locations_exclude is None:
locations_exclude = []

all_locations = [
loc
for loc in locations.get_column("STUSAB").to_list() + ["US"]
if loc not in excluded_locations
for loc in loc_abbrs
if loc not in locations_exclude and loc in locations_include
]

for disease, state in itertools.product(disease_list, all_locations):
Expand All @@ -174,6 +202,8 @@ def main(
report_date="latest",
n_warmup=n_warmup,
n_samples=n_samples,
n_training_days=n_training_days,
exclude_last_n_days=exclude_last_n_days,
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
Expand Down Expand Up @@ -226,7 +256,40 @@ def main(
)

parser.add_argument(
"--excluded-locations",
"--n-training-days",
type=int,
help=(
"Number of 'training days' of observed data "
"to use for model fitting."
),
default=90,
)

parser.add_argument(
"--exclude-last-n-days",
type=int,
help=(
"Number of days to drop from the end of the timeseries "
"of observed data when constructing the training data."
),
default=1,
)

parser.add_argument(
"--locations-include",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"include in the job, as a whitespace-separated "
"string. If not set, include all ",
"available locations except any explicitly excluded "
"via --locations-exclude.",
),
default=None,
)

parser.add_argument(
"--locations-exclude",
type=str,
help=(
"Two-letter USPS location abbreviations to "
Expand All @@ -242,5 +305,7 @@ def main(
if __name__ == "__main__":
args = parser.parse_args()
args.diseases = args.diseases.split()
args.excluded_locations = args.excluded_locations.split()
if args.locations_include is not None:
args.locations_include = args.locations_include.split()
args.locations_exclude = args.locations_exclude.split()
main(**vars(args))

0 comments on commit a69fa0c

Please sign in to comment.