Skip to content

Commit

Permalink
test working
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer committed Dec 19, 2024
1 parent 3af7f1c commit 11c7edb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
1 change: 0 additions & 1 deletion pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ def main(
facility_level_nssp_data=facility_level_nssp_data,
state_level_nssp_data=state_level_nssp_data,
report_date=report_date,
state_level_report_date=state_report_date,
first_training_date=first_training_date,
last_training_date=last_training_date,
param_estimates=param_estimates,
Expand Down
69 changes: 63 additions & 6 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import forecasttools
import polars as pl
import polars.selectors as cs

_disease_map = {
"COVID-19": "COVID-19/Omicron",
Expand All @@ -26,7 +27,8 @@ def get_nhsn(
raise FileExistsError(f"Output file {output_file} already exists")
my_list = [
"Rscript",
"pull_nhsn.R", # relies on being in the correct working directory
"pipelines/pull_nhsn.R", # relies on being in the correct working directory
# probably better executed as a function in hewr and an inline R script here
"--start-date",
f"{start_date}",
"--end-date",
Expand All @@ -44,7 +46,9 @@ def get_nhsn(
)
if result.returncode != 0:
raise RuntimeError(f"pull_and_save_nhsn: {result.stderr}")
dat = pl.read_csv(output_file, separator="\t")
dat = pl.read_csv(output_file, separator="\t").with_columns(
weekendingdate=pl.col("weekendingdate").cast(pl.Date)
)
os.remove(output_file)
return dat

Expand Down Expand Up @@ -325,6 +329,20 @@ def process_and_save_state(
jurisdictions=nhsn_state_abb,
)

nssp_training_dates = (
nssp_training_data.get_column("date").unique().to_list()
)
nhsn_training_dates = (
nhsn_training_data.get_column("weekendingdate").unique().to_list()
)

nhsn_first_date_index = next(
i
for i, x in enumerate(nssp_training_dates)
if x == min(nhsn_training_dates)
)
nhsn_step_size = 7

train_disease_ed_visits = (
nssp_training_data.filter(pl.col("disease") == disease)
.get_column("ed_visits")
Expand All @@ -337,23 +355,62 @@ def process_and_save_state(
.to_list()
)

train_disease_hospital_admissions = nhsn_training_data.get_column(
"hospital_admissions"
).to_list()

data_for_model_fit = {
"inf_to_ed_pmf": delay_pmf,
"generation_interval_pmf": generation_interval_pmf,
"right_truncation_pmf": right_truncation_pmf,
"data_observed_disease_ed_visits": train_disease_ed_visits,
"data_observed_total_hospital_admissions": train_total_ed_visits,
"data_observed_disease_hospital_admissions": train_disease_hospital_admissions,
"nssp_training_dates": nssp_training_dates,
"nhsn_training_dates": nhsn_training_dates,
"nhsn_first_date_index": nhsn_first_date_index,
"nhsn_step_size": nhsn_step_size,
"state_pop": state_pop,
"right_truncation_offset": right_truncation_offset,
}
data_dir = Path(model_run_dir, "data")
os.makedirs(data_dir, exist_ok=True)

with open(Path(data_dir, "data_for_model_fit.json"), "w") as json_file:
json.dump(data_for_model_fit, json_file, default=str)

nssp_training_data_long = nssp_training_data.unpivot(
on="ed_visits",
index=cs.exclude("ed_visits"),
variable_name="value_type",
)

nhsn_training_data_long = (
nhsn_training_data.rename(
{"weekendingdate": "date", "jurisdiction": "geo_value"}
)
.unpivot(
on="hospital_admissions",
index=cs.exclude("hospital_admissions"),
variable_name="value_type",
)
.with_columns(
pl.lit(disease).alias("disease"),
pl.lit("train").alias("data_type"),
)
)

combined_training_dat = pl.concat(
[nssp_training_data_long, nhsn_training_data_long],
how="diagonal_relaxed",
).sort(["date", "geo_value", "value_type"])

if logger is not None:
logger.info(f"Saving {state_abb} to {data_dir}")
nssp_training_data.write_csv(Path(data_dir, "data.tsv"), separator="\t")

with open(Path(data_dir, "data_for_model_fit.json"), "w") as json_file:
json.dump(data_for_model_fit, json_file)

# post processing not yet updated for combined nhsn and nssp data
nssp_training_data.write_csv(Path(data_dir, "data.tsv"), separator="\t")
combined_training_dat.write_csv(
Path(data_dir, "combined_data.tsv"), separator="\t"
)
return None
4 changes: 2 additions & 2 deletions pipelines/pull_nhsn.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ dat <- pull_nhsn(
jurisdictions = jurisdictions,
) |>
mutate(weekendingdate = as_date(weekendingdate)) |>
rename(nhsn_admissions = !!unname(columns)) |>
mutate(nhsn_admissions = as.numeric(nhsn_admissions))
rename(hospital_admissions = !!unname(columns)) |>
mutate(hospital_admissions = as.numeric(hospital_admissions))

write_tsv(dat, output_file)

0 comments on commit 11c7edb

Please sign in to comment.