Skip to content

Commit

Permalink
Use forecasttools in score_hubverse, add explicit include functionali…
Browse files Browse the repository at this point in the history
…ty to prod job setup
  • Loading branch information
dylanhmorris committed Dec 16, 2024
1 parent 1db06a6 commit 844c2ad
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
43 changes: 36 additions & 7 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def main(
output_subdir: str | Path = "./",
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
excluded_locations: list[str] = [
locations_include: list[str] = None,
locations_exclude: list[str] = [
"AS",
"GU",
"MO",
Expand Down Expand Up @@ -62,9 +63,16 @@ def main(
container_image_version
Version of the container to use. Default 'latest'.
excluded_locations
locations_include
List of two-letter USPS location abbreviations for locations
to include in the job (unless explicitly excluded by
--locations-exclude). If ``None``, use all available
not-explicitly-excluded locations. Default ``None``.
locations_exclude
List of two letter USPS location abbreviations to
exclude from the job. Defaults to locations for which
exclude from the job. If ``None``, do not exclude any
locations. Defaults to a list of locations for which
we typically do not have available NSSP ED visit data:
``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``.
Expand Down Expand Up @@ -159,10 +167,16 @@ def main(
"https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
)

loc_abbrs = locations.get_column("STUSAB").to_list() + ["US"]
if locations_include is None:
locations_include = loc_abbrs
if locations_exclude is None:
locations_exclude = []

all_locations = [
loc
for loc in locations.get_column("STUSAB").to_list() + ["US"]
if loc not in excluded_locations
for loc in loc_abbrs
if loc not in locations_exclude and loc in locations_include
]

for disease, state in itertools.product(disease_list, all_locations):
Expand Down Expand Up @@ -226,7 +240,20 @@ def main(
)

parser.add_argument(
"--excluded-locations",
"--locations-include",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"include in the job, as a whitespace-separated "
"string. If not set, include all ",
"available locations except any explicitly excluded "
"via --locations-exclude.",
),
default=None,
)

parser.add_argument(
"--locations-exclude",
type=str,
help=(
"Two-letter USPS location abbreviations to "
Expand All @@ -242,5 +269,7 @@ def main(
if __name__ == "__main__":
args = parser.parse_args()
args.diseases = args.diseases.split()
args.excluded_locations = args.excluded_locations.split()
if args.locations_include is not None:
args.locations_include = args.locations_include.split()
args.locations_exclude = args.locations_exclude.split()
main(**vars(args))
2 changes: 2 additions & 0 deletions pipelines/score_hubverse.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ score_and_save <- function(observed_data_path,
purrr::pmap(read_and_prep_for_scoring) |>
dplyr::bind_rows()

message("Finished reading in forecasts and preparing for scoring.")
message("Scoring forecasts...")
full_scores <- hewr::score_hewr(
full_scoreable_table
)
Expand Down

0 comments on commit 844c2ad

Please sign in to comment.