Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP exp gamma implementation #89

Draft
wants to merge 2 commits into
base: prod
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,26 @@ repos:
hooks:
- id: style-files
- id: lintr
#####

########
# Python
- repo: https://github.com/psf/black
rev: 23.10.0
hooks:
- id: black
args: ['--line-length', '79']
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: ['--profile', 'black',
'--line-length', '79']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.0
hooks:
- id: ruff

#########
# Secrets
- repo: https://github.com/Yelp/detect-secrets
rev: v1.4.0
Expand Down
6 changes: 4 additions & 2 deletions _targets_eval.R
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,8 @@ downstream_targets <- list(
name = plot_summarized_scores_w_data,
command = get_plot_scores_w_data(
grouped_submission_scores,
eval_hosp_data
eval_hosp_data,
eval_config$figure_dir
),
pattern = map(grouped_submission_scores),
iteration = "list",
Expand Down Expand Up @@ -900,7 +901,8 @@ downstream_targets <- list(
name = plot_quantile_comparison,
command = get_plot_quantile_comparison(
all_hosp_quantiles,
eval_hosp_data
eval_hosp_data,
eval_config$figure_dir
),
pattern = map(all_hosp_quantiles),
iteration = "list"
Expand Down
165 changes: 135 additions & 30 deletions _targets_eval_postprocessing.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,15 @@ combined_targets <- list(
## Flags------------------------------------------------------------------
tar_target(
name = all_flags_ww,
command = combine_outputs(
output_type = "flags",
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_dates,
locations = eval_config$location_ww,
eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
command =
combine_outputs(
output_type = "flags",
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
tar_target(
name = all_flags_hosp,
Expand All @@ -161,6 +162,25 @@ combined_targets <- list(
model_type = "hosp"
)
),
tar_target(
name = all_flags,
command = dplyr::bind_rows(all_flags_ww, all_flags_hosp)
),
tar_target(
name = convergence_df_ww,
command = get_convergence_df(
all_flags_ww,
default_scenario = "status_quo"
) |>
dplyr::rename(any_flags_ww = any_flags)
),
tar_target(
name = convergence_df_hosp,
command = get_convergence_df(all_flags_hosp,
default_scenario = "no_wastewater"
) |>
dplyr::rename(any_flags_hosp = any_flags)
),
### Scores from quantiles-------------------------------------------------
tar_target(
name = all_ww_scores_quantiles,
Expand All @@ -169,7 +189,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
Expand All @@ -180,7 +200,7 @@ combined_targets <- list(
scenarios = "no_wastewater",
forecast_dates = eval_config$forecast_date_hosp,
locations = eval_config$location_hosp,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "hosp"
)
),
Expand All @@ -192,7 +212,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
Expand All @@ -203,7 +223,7 @@ combined_targets <- list(
scenarios = "no_wastewater",
forecast_dates = eval_config$forecast_date_hosp,
locations = eval_config$location_hosp,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "hosp"
)
),
Expand All @@ -214,7 +234,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
Expand All @@ -226,7 +246,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
Expand All @@ -237,15 +257,110 @@ combined_targets <- list(
scenarios = "no_wastewater",
forecast_dates = eval_config$forecast_date_hosp,
locations = eval_config$location_hosp,
eval_output_subdir = file.path("output", "eval"),
eval_output_subdir = eval_config$output_dir,
model_type = "hosp"
)
)
)

# Head-to-head comparison targets-------------------------------------------
# This set of targets will be conditioned on the presence of sufficient
# wastewater, whereas the below targets assume that for every location and
# forecast date we had to submit a forecast, and so we used the hospital
# admissions only model if wastewater was missing.
# These are only relevant for the status quo scenario
head_to_head_targets <- list(
  # Wastewater-model quantiles restricted to the status quo scenario
  tar_target(
    name = all_ww_quantiles_sq,
    command = all_ww_quantiles |>
      dplyr::filter(scenario == "status_quo")
  ),
  # Get a table of locations and forecast dates with sufficient wastewater
  # NOTE(review): this is built from `all_ww_quantiles` (all scenarios), not
  # from `all_ww_quantiles_sq` defined just above -- confirm that is intended.
  tar_target(
    name = table_of_loc_dates_w_ww,
    command = get_table_sufficient_ww(all_ww_quantiles)
  ),
  # Get a table indicating whether there are locations and forecast dates with
  # convergence issues. The hospital-admissions table is the left-hand side of
  # the join, so every location/forecast date present in `convergence_df_hosp`
  # is kept and the wastewater columns are attached where a match exists.
  tar_target(
    name = convergence_df,
    command = convergence_df_hosp |>
      dplyr::left_join(
        convergence_df_ww,
        by = c("location", "forecast_date")
      )
  ),
  # Get the full set of quantiles, filtered down to only states and
  # forecast dates with sufficient wastewater for both ww model and hosp only
  # model. Then join the convergence df
  tar_target(
    name = hosp_quantiles_filtered,
    command = dplyr::bind_rows(
      all_ww_hosp_quantiles,
      all_hosp_model_quantiles
    ) |>
      dplyr::left_join(
        table_of_loc_dates_w_ww,
        by = c("location", "forecast_date")
      ) |>
      dplyr::filter(
        ww_sufficient # filters to location forecast dates with sufficient ww
      ) |>
      dplyr::left_join(
        convergence_df,
        by = c(
          "location",
          "forecast_date"
        )
      )
  ),
  # Do the same thing for the sampled scores, combining ww and hosp under
  # the status quo scenario, filtering to the locations and forecast dates
  # with sufficient wastewater, and then joining the convergence flags
  tar_target(
    name = scores_filtered,
    command = dplyr::bind_rows(
      all_hosp_scores,
      all_ww_scores |>
        dplyr::filter(scenario == "status_quo")
    ) |>
      dplyr::left_join(
        table_of_loc_dates_w_ww,
        by = c("location", "forecast_date")
      ) |>
      dplyr::filter(ww_sufficient) |>
      dplyr::left_join(
        convergence_df,
        by = c(
          "location",
          "forecast_date"
        )
      )
  ),
  # Repeat for the quantile-based scores
  tar_target(
    name = scores_quantiles_filtered,
    command = dplyr::bind_rows(
      all_hosp_scores_quantiles,
      all_ww_scores_quantiles |>
        dplyr::filter(scenario == "status_quo")
    ) |>
      dplyr::left_join(
        table_of_loc_dates_w_ww,
        by = c("location", "forecast_date")
      ) |>
      # Filter on the logical column itself, as the two targets above do.
      # `isTRUE()` is a scalar test: applied to a whole column it returns a
      # single FALSE whenever there is more than one row, which `filter()`
      # recycles and thereby drops every row. `filter()` already discards
      # rows where the condition is NA, so no extra NA guard is needed.
      dplyr::filter(ww_sufficient) |>
      dplyr::left_join(
        convergence_df,
        by = c(
          "location",
          "forecast_date"
        )
      )
  )
)


# Head-to-head-scenario targets------------------------------------------------
head_to_head_scenario_targets <- list(
# Scenario targets------------------------------------------------
scenario_targets <- list(
tar_target(
name = all_raw_scores,
command = data.table::as.data.table(
Expand All @@ -262,18 +377,7 @@ head_to_head_scenario_targets <- list(
name = all_errors,
command = dplyr::bind_rows(all_hosp_errors, all_ww_errors)
),
tar_target(
name = all_flags,
command = dplyr::bind_rows(all_flags_ww, all_flags_hosp)
),
tar_target(
name = nonconverge_df,
command = all_flags |> distinct(
location, forecast_date, model, scenario
) |>
dplyr::mutate(convergence = FALSE) # Add a column that indicates did not pass
# convergence. We will use this in the head-to-head comparison to stratify by convergence
),


## Raw scores-----------------------------------------
# These are the scores from each scenario and location without buffering
Expand Down Expand Up @@ -653,7 +757,8 @@ hub_comparison_plots <- list(
list(
upstream_targets,
combined_targets,
head_to_head_scenario_targets,
head_to_head_targets,
scenario_targets,
hub_targets,
hub_comparison_plots
)
5 changes: 5 additions & 0 deletions cfaforecastrenewalww/R/get_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,11 @@ site_level_inf_inits <- function(train_data, params, stan_data) {
autoreg_rt = abs(rnorm(1, autoreg_rt_a / (autoreg_rt_a + autoreg_rt_b), 0.05)),
log_r_mu_intercept = rnorm(1, convert_to_logmean(1, 0.1), convert_to_logsd(1, 0.1)),
error_site = matrix(rnorm(n_subpops * n_weeks, mean = 0, sd = 0.1), n_subpops, n_weeks),
zeta_bar = abs(matrix(
rnorm(n_subpops * (uot + ot + ht), mean = 0, sd = 0.1),
n_subpops, (uot + ot + ht)
)),
cv = abs(rnorm(1, 0.025, 0.005)),
autoreg_rt_site = abs(rnorm(1, 0.5, 0.05)),
autoreg_p_hosp = abs(rnorm(1, 1 / 100, 0.001)),
sigma_rt = abs(rnorm(1, 0, 0.01)),
Expand Down
Loading
Loading