diff --git a/pipelines/batch/setup_prod_job.py b/pipelines/batch/setup_prod_job.py index 73c6c4ef..414d3d1f 100644 --- a/pipelines/batch/setup_prod_job.py +++ b/pipelines/batch/setup_prod_job.py @@ -22,7 +22,10 @@ def main( output_subdir: str | Path = "./", container_image_name: str = "pyrenew-hew", container_image_version: str = "latest", - excluded_locations: list[str] = [ + n_training_days: int = 90, + exclude_last_n_days: int = 1, + locations_include: list[str] = None, + locations_exclude: list[str] = [ "AS", "GU", "MO", @@ -62,9 +65,28 @@ def main( container_image_version Version of the container to use. Default 'latest'. - excluded_locations + n_training_days + Number of training days of data to use for model fitting. + Default 90. + + exclude_last_n_days + Number of days of available data to exclude from fitting. + Default 1. Note that we start the lookback for the + ``n_training_days`` of data after these exclusions, + so there will always be ``n_training_days`` of observations + for fitting; ``exclude_last_n_days`` determines where + the date range of observations starts and ends. + + locations_include + List of two-letter USPS location abbreviations for locations + to include in the job (unless explicitly excluded by + --locations-exclude). If ``None``, use all available + not-explicitly-excluded locations. Default ``None``. + + locations_exclude List of two letter USPS location abbreviations to - exclude from the job. Defaults to locations for which + exclude from the job. If ``None``, do not exclude any + locations. Defaults to a list of locations for which we typically do not have available NSSP ED visit data: ``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``. @@ -137,7 +159,7 @@ def main( "python pipelines/forecast_state.py " "--disease {disease} " "--state {state} " - "--n-training-days 90 " + "--n-training-days {n_training_days} " "--n-warmup {n_warmup} " "--n-samples {n_samples} " "--facility-level-nssp-data-dir nssp-etl/gold " @@ -147,7 +169,7 @@ def main( "--output-dir {output_dir} " "--priors-path config/prod_priors.py " "--report-date {report_date} " - "--exclude-last-n-days 1 " + "--exclude-last-n-days {exclude_last_n_days} " "--no-score " "--eval-data-path " "nssp-archival-vintages/latest_comprehensive.parquet" @@ -159,10 +181,16 @@ def main( "https://www2.census.gov/geo/docs/reference/state.txt", separator="|" ) + loc_abbrs = locations.get_column("STUSAB").to_list() + ["US"] + if locations_include is None: + locations_include = loc_abbrs + if locations_exclude is None: + locations_exclude = [] + all_locations = [ loc - for loc in locations.get_column("STUSAB").to_list() + ["US"] - if loc not in excluded_locations + for loc in loc_abbrs + if loc not in locations_exclude and loc in locations_include ] for disease, state in itertools.product(disease_list, all_locations): @@ -174,6 +202,8 @@ def main( report_date="latest", n_warmup=n_warmup, n_samples=n_samples, + n_training_days=n_training_days, + exclude_last_n_days=exclude_last_n_days, output_dir=str(Path("output", output_subdir)), ), container_settings=container_settings, @@ -226,7 +256,40 @@ def main( ) parser.add_argument( - "--excluded-locations", + "--n-training-days", + type=int, + help=( + "Number of 'training days' of observed data " + "to use for model fitting." + ), + default=90, +) + +parser.add_argument( + "--exclude-last-n-days", + type=int, + help=( + "Number of days to drop from the end of the timeseries " + "of observed data when constructing the training data." + ), + default=1, +) + +parser.add_argument( + "--locations-include", + type=str, + help=( + "Two-letter USPS location abbreviations to " + "include in the job, as a whitespace-separated " + "string. If not set, include all ", + "available locations except any explicitly excluded " + "via --locations-exclude.", + ), + default=None, +) + +parser.add_argument( + "--locations-exclude", type=str, help=( "Two-letter USPS location abbreviations to " @@ -242,5 +305,7 @@ def main( if __name__ == "__main__": args = parser.parse_args() args.diseases = args.diseases.split() - args.excluded_locations = args.excluded_locations.split() + if args.locations_include is not None: + args.locations_include = args.locations_include.split() + args.locations_exclude = args.locations_exclude.split() main(**vars(args))