Skip to content

Commit

Permalink
update codebase to reflect changes from 2024-03-25 run
Browse files Browse the repository at this point in the history
  • Loading branch information
kaitejohnson committed Mar 25, 2024
1 parent fdb92f6 commit 6f666cc
Show file tree
Hide file tree
Showing 23 changed files with 552 additions and 198 deletions.
102 changes: 90 additions & 12 deletions _targets.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ list(
command = c(),
deployment = "main"
),
# Target `exclude_states`: empty here (`c()` evaluates to NULL). Its value is
# passed as the `exclude_states` argument of get_loc_model_map(),
# get_location_notes_table() and get_metadata_yaml() further down this list.
# NOTE(review): presumably a character vector of locations to drop from the
# outputs, in the same format as hosp_only_states -- confirm.
tar_target(
name = exclude_states,
command = c(),
deployment = "main"
),
# To enforce a global change, run tar_destroy() so that run_id is
# regenerated
tar_target(
Expand All @@ -112,7 +117,7 @@ list(
command = write_config(
save_config = TRUE,
location = NULL,
prod_run = TRUE,
prod_run = FALSE,
run_id = run_id,
ww_geo_type = "state",
date_run = date_run,
Expand All @@ -132,7 +137,7 @@ list(
command = write_config(
save_config = TRUE,
location = "US",
prod_run = TRUE,
prod_run = FALSE,
run_id = run_id,
ww_geo_type = "state",
date_run = date_run,
Expand All @@ -152,7 +157,7 @@ list(
command = write_config(
save_config = TRUE,
location = NULL,
prod_run = TRUE,
prod_run = FALSE,
run_id = run_id,
date_run = date_run,
model_type = "site-level infection dynamics",
Expand Down Expand Up @@ -250,10 +255,10 @@ list(
),


# # Fit the model ------------------------------------------------------------
# # get a stacked long dataframe containing the quantiles(estimated
# # from all draws) and 100 samples of the draws from the posterior for the
# # generated quantities and the parameters
## Fit the model ------------------------------------------------------------
# Get a stacked long data frame containing the quantiles (estimated
# from all draws) and 100 samples of the posterior draws for the
# generated quantities and the parameters
tar_target(
name = df_of_filepaths_id,
command = do.call(
Expand Down Expand Up @@ -292,6 +297,18 @@ list(
iteration = "list",
deployment = "main"
),
# Draws plot of "pred_hosp" for each element of grouped_df_id, requested on
# the log scale (log_scale = TRUE) and without the calibration data
# (show_calibration_data = FALSE). figure_output_subdirectory is passed as
# the third positional argument of get_plot_draws().
tar_target(
name = plot_single_location_hosp_draws_log_id,
command = get_plot_draws(grouped_df_id,
"pred_hosp",
figure_output_subdirectory,
log_scale = TRUE,
show_calibration_data = FALSE
),
# Dynamic branching: one branch per element of grouped_df_id, with the
# branches split and aggregated as a list
pattern = map(grouped_df_id),
iteration = "list",
deployment = "main"
),
tar_target(
name = plot_single_location_comb_quantiles_id,
command = get_combo_quantile_plot(
Expand Down Expand Up @@ -333,8 +350,8 @@ list(
deployment = "main"
),
tar_target(
name = plot_rt_site_level,
command = get_rt_site_level(
name = plot_rt_subpop_level,
command = get_rt_subpop_level(
grouped_df_id,
figure_output_subdirectory
),
Expand Down Expand Up @@ -470,6 +487,18 @@ list(
iteration = "list",
deployment = "main"
),
# Draws plot of "pred_hosp" for each element of grouped_df, requested on the
# log scale (log_scale = TRUE) and without the calibration data
# (show_calibration_data = FALSE).
# NOTE(review): the `_ho` suffix presumably marks the hospital-admissions-only
# model outputs -- confirm against where grouped_df is defined.
tar_target(
name = plot_single_location_hosp_draws_log_ho,
command = get_plot_draws(grouped_df,
"pred_hosp",
figure_output_subdirectory,
log_scale = TRUE,
show_calibration_data = FALSE
),
# Dynamic branching: one branch per element of grouped_df, with the branches
# split and aggregated as a list
pattern = map(grouped_df),
iteration = "list",
deployment = "main"
),
tar_target(
name = plot_rt_ho,
command = get_rt_from_draws(
Expand Down Expand Up @@ -617,6 +646,18 @@ list(
iteration = "list",
deployment = "main"
),
# Draws plot of "pred_hosp" for each element of grouped_df_sa, requested on
# the log scale (log_scale = TRUE) and without the calibration data
# (show_calibration_data = FALSE).
# NOTE(review): the meaning of the `_sa` suffix is not visible here -- confirm
# against where grouped_df_sa is defined.
tar_target(
name = plot_single_location_hosp_draws_log_sa,
command = get_plot_draws(grouped_df_sa,
"pred_hosp",
figure_output_subdirectory,
log_scale = TRUE,
show_calibration_data = FALSE
),
# Dynamic branching: one branch per element of grouped_df_sa, with the
# branches split and aggregated as a list
pattern = map(grouped_df_sa),
iteration = "list",
deployment = "main"
),
tar_target(
name = plot_single_location_ww_draws_sa,
command = get_plot_draws(df_of_filepaths_us,
Expand Down Expand Up @@ -687,6 +728,20 @@ list(
iteration = "list",
deployment = "main"
),
# Draws plot of "pred_hosp" for each element of grouped_df_comb, grouped by
# "model_type", requested on the log scale and without the calibration data.
tar_target(
name = plot_mult_models_log,
command = get_plot_draws(
grouped_df_comb,
"pred_hosp",
grouping_var = "model_type",
# NOTE(review): unnamed argument after a named one. R matches it to the next
# formal of get_plot_draws() that is still unmatched, so this only reaches the
# output-directory parameter if that is the third such formal -- confirm, or
# pass it by name.
figure_output_subdirectory,
log_scale = TRUE,
show_calibration_data = FALSE
),
# Dynamic branching: one branch per element of grouped_df_comb, with the
# branches split and aggregated as a list
pattern = map(grouped_df_comb),
iteration = "list",
deployment = "main"
),
tar_target(
name = rt_box_plot_id,
command = get_rt_boxplot_across_states(
Expand Down Expand Up @@ -782,6 +837,24 @@ list(
),
deployment = "main"
),
# Calls save_to_pdf() on the log-scale multi-model plots (plot_mult_models_log).
# do.call() receives one argument list built with c(): the plots, the output
# type label, a PDF path under "<pdf_output_subdirectory>/internal", the
# submitting model name, a 3-row by 1-column layout, and finally the contents
# of config_vars_id.
# NOTE(review): assumes config_vars_id is a named list whose elements match
# save_to_pdf() parameters (or are absorbed by `...`) -- confirm.
tar_target(
name = pdf_of_forecast_comparisons_log,
command = do.call(
save_to_pdf,
c(list(list_of_plots = plot_mult_models_log),
type_of_output = "forecasts_from_mult_model_types_log",
pdf_file_path = file.path(
pdf_output_subdirectory,
"internal"
),
model_name = config_vars_id$submitting_model_name,
n_row = 3,
n_col = 1,
config_vars_id
)
),
deployment = "main"
),
# Get model run diagnostics for submission -------------------------------------
tar_target(
name = full_diagnostics_df,
Expand All @@ -798,8 +871,9 @@ list(
tar_target(
name = loc_model_map_submission,
command = get_loc_model_map(
df_of_filepaths_id,
hosp_only_states
df_of_filepaths = df_of_filepaths_id,
hosp_only_states = hosp_only_states,
exclude_states = exclude_states
),
deployment = "main"
),
Expand Down Expand Up @@ -844,6 +918,7 @@ list(
command = get_location_notes_table(
full_diagnostics_df,
hosp_only_states,
exclude_states,
output_dir = repo_file_path,
prod_run = config_vars_id$prod_run
),
Expand All @@ -854,6 +929,7 @@ list(
command = get_metadata_yaml(
data_diagnostics_df = full_diagnostics_df,
hosp_only_states = hosp_only_states,
exclude_states = exclude_states,
output_dir = repo_file_path,
prod_run = config_vars_id$prod_run
)
Expand Down Expand Up @@ -933,6 +1009,7 @@ list(
command = get_loc_model_map(
df_of_filepaths_ho |> dplyr::filter(location != "US"),
hosp_only_states,
exclude_states = c(),
us_model_type = "hospital admissions only"
),
deployment = "main"
Expand Down Expand Up @@ -1002,7 +1079,8 @@ list(
tar_target(
name = loc_model_map_ww,
command = get_loc_model_map(df_of_filepaths_id,
hosp_only_states = c()
hosp_only_states = c(),
exclude_states = c()
),
deployment = "main"
),
Expand Down
3 changes: 2 additions & 1 deletion cfaforecastrenewalww/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ export(get_regions_for_mapping)
export(get_relative_forecast_dir)
export(get_rt_boxplot_across_states)
export(get_rt_from_draws)
export(get_rt_site_level)
export(get_rt_subpop_level)
export(get_scores)
export(get_secret)
export(get_site_county_map)
Expand Down Expand Up @@ -108,6 +108,7 @@ export(site_level_inf_inits)
export(soql_is_in)
export(standardize)
export(state_agg_inits)
export(subsample_sites)
export(summarize_scores)
export(to_simplex)
export(trajectories_to_quantiles)
Expand Down
92 changes: 76 additions & 16 deletions cfaforecastrenewalww/R/fit_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -582,18 +582,78 @@ fit_site_level_model <- function(train_data,
ww = NA,
site = NA,
site_index = NA,
subpop_index = NA,
lab = NA,
lab_wwtp_unique_id = NA,
subpop = NA,
below_LOD = NA,
flag_as_ww_outlier = NA,
lod_sewage = NA,
ww_pop = NA
ww_pop = NA,
subpop_size = NA
)

# This should be at the lab site level
# From train_data, identify whether we need to fit an additional subpop.
# pop_ww holds the average wastewater catchment population of each site,
# ordered by site_index (rows without a site_index are dropped).
pop_ww <- train_data %>%
  select(site_index, ww_pop) %>%
  filter(!is.na(site_index)) %>%
  group_by(site_index) %>%
  summarise(pop_avg = mean(ww_pop)) %>%
  # Ascending site_index. The former second sort key, the literal string
  # "desc", was a constant and never affected the order.
  arrange(site_index) %>%
  pull(pop_avg)

# Distinct value(s) of the `pop` column, i.e. the state population that the
# wastewater site populations are compared against.
pop <- train_data %>%
  select(pop) %>%
  unique() %>%
  pull(pop)

# TRUE when the state population is at least the sum of the site populations.
# The comparison already yields a logical, so no ifelse() wrapper is needed.
# NOTE(review): if `pop` ever holds more than one distinct value (or NA), the
# isTRUE() check below is FALSE and no auxiliary subpopulation is added.
add_auxiliary_site <- pop >= sum(pop_ww)

# From the input data, which labs align with which sites and
# which population sizes (one row per distinct lab-site combination)
site_map_raw <- train_data %>%
  select(
    lab_site_index, lab_wwtp_unique_id, site_index,
    lab, site, ww_pop
  ) %>%
  distinct() %>%
  filter(!is.na(site_index))

# Create a new variable called subpop, link it to site indexes.
# site_map_raw has one row per lab-site, so a site measured by several labs
# appears several times; distinct() collapses those repeats so each site
# contributes one subpopulation row and the joins on subpop_index / site_index
# do not multiply rows.
# NOTE(review): a site whose ww_pop takes several values in train_data still
# yields several rows here -- confirm ww_pop is constant within a site.
site_subpop_map <- site_map_raw %>%
  mutate(subpop = paste0("Site: ", site)) %>%
  select(site_index, subpop, ww_pop) %>%
  distinct() %>%
  rename(
    subpop_index = site_index,
    subpop_size = ww_pop
  )

# Link subpops to the rest of the site and lab metadata. The join key
# subpop_index is dropped from the right-hand table by the join, so it is
# recreated from site_index afterwards.
site_map <- site_map_raw %>%
  left_join(site_subpop_map, by = c("site_index" = "subpop_index")) %>%
  mutate(subpop_index = site_index)

# If we add an auxiliary site, this means we need to track an additional
# subpopulation whose population is the difference between the state pop
# and the sum of the wastewater site populations
if (isTRUE(add_auxiliary_site)) {
  extra_subpop <- data.frame(
    # Next free index after the largest site index
    subpop_index = max(site_map$site_index) + 1,
    subpop = "Pop not on wastewater",
    subpop_size = pop - sum(pop_ww)
  )

  subpop_map <- rbind(site_subpop_map, extra_subpop)
} else {
  subpop_map <- site_subpop_map
}




# This should be at the lab site level, but you want to map to
# sites and subpops
gen_quants_draws_w_ww <- gen_quants_draws %>%
filter(name == "pred_ww") %>%
select(-site_index) %>%
select(-subpop_index) %>%
mutate(
include_ww = include_ww,
forecast_date = forecast_date,
Expand All @@ -603,13 +663,7 @@ fit_site_level_model <- function(train_data,
) %>%
left_join(date_df, by = "t") %>%
ungroup() %>%
left_join(
train_data %>%
select(
lab_site_index, lab_wwtp_unique_id, site_index,
lab, site
) %>%
distinct(),
left_join(site_map,
by = c("lab_site_index")
) %>%
left_join(
Expand All @@ -626,12 +680,15 @@ fit_site_level_model <- function(train_data,
ungroup() %>%
select(
lab_site_index, date, ww,
below_LOD, flag_as_ww_outlier, lod_sewage, ww_pop
below_LOD, flag_as_ww_outlier, lod_sewage
) %>%
filter(!is.na(ww)) %>%
unique(),
by = c("date", "lab_site_index")
) %>%
mutate(
subpop_index = site_index
) %>%
select(colnames(gen_quants_draws_non_ww))

if (model_type != "site-level infection dynamics") {
Expand All @@ -652,12 +709,15 @@ fit_site_level_model <- function(train_data,
) %>%
left_join(date_df, by = "t") %>%
ungroup() %>%
left_join(
train_data %>%
select(site_index, site, ww_pop) %>%
distinct(),
by = c("site_index")
left_join(subpop_map,
by = c("subpop_index")
) %>%
left_join(site_map %>%
select(
subpop_index,
site_index, site, ww_pop
) %>%
distinct(), by = "subpop_index") %>%
mutate(
lab_site_index = NA,
lab_wwtp_unique_id = NA,
Expand Down
Loading

0 comments on commit 6f666cc

Please sign in to comment.