generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5a03930
commit 5096303
Showing
46 changed files
with
3,929 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# This is a script used to generate the global inputs needed to run the | ||
# _targets_eval.R function. Editing the locations and forecast dates | ||
# and model types tells the evaluation pipeline which combinations to | ||
# pass to tar map to iterate over. | ||
source(file.path("src", "write_eval_config.R")) | ||
write_eval_config( | ||
locations = c("MA", "WA"), | ||
forecast_dates = c( | ||
"2023-10-16", "2023-10-23", "2023-10-30", | ||
"2023-11-06", "2023-11-13", "2023-11-20", | ||
"2023-11-27", "2023-12-04", "2023-12-11", | ||
"2023-12-18", "2023-12-25", "2024-01-01", | ||
"2024-01-08", "2024-01-15", "2024-01-22", | ||
"2024-01-29", "2024-02-05", "2024-02-12", | ||
"2024-02-19" | ||
), | ||
scenarios = c("Status quo", "One site per jurisdiction"), | ||
config_dir = file.path("input", "config", "eval"), | ||
eval_date = "2024-03-25" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#' Write evaluation config file | ||
#' | ||
#' @param locations locations to iterate through, for a full run this should | ||
#' be all 50 states + PR | ||
#' @param forecast_dates the forecast dates we want to run the model on | ||
#' @param scenatios the scenarios (which will pertain to site ids) to | ||
#' run the model on | ||
#' @param config_dir the directory where we want to save the config file | ||
#' | ||
#' @return | ||
#' @export | ||
#' | ||
#' @examples | ||
write_eval_config <- function(locations, forecast_dates, | ||
scenarios, | ||
config_dir, | ||
eval_date) { | ||
# get a dataframe of all the iteration combinations for the wastewater mapping | ||
df_ww <- expand.grid( | ||
location = locations, | ||
forecast_date = forecast_dates, | ||
scenario = scenarios | ||
) | ||
# No scenarios for the hosp admissions only, so we just need all combos | ||
# of locations and forecast dates | ||
df_hosp <- expand.grid( | ||
location = locations, | ||
forecast_date = forecast_dates | ||
) | ||
|
||
# Specify other variables | ||
ww_data_dir <- file.path("input", "ww_data", "monday_datasets") | ||
scenario_dir <- file.path("input", "config", "eval", "scenarios") | ||
hosp_data_dir <- file.path("input", "hosp_data", "vintage_datasets") | ||
# stan_models_dir <- system.file("stan", package = "cfaforecastrenewalww") #nolint | ||
stan_models_dir <- file.path("cfaforecastrenewalww", "inst", "stan") | ||
init_dir <- file.path("input", "init_lists") | ||
|
||
ww_data_mapping <- "Monday: Monday, Wednesday: Monday" | ||
calibration_time <- 90 | ||
forecast_time <- 28 | ||
|
||
iter_warmup <- 750 | ||
iter_sampling <- 500 | ||
n_chains <- 4 | ||
n_parallel_chains <- 4 | ||
adapt_delta <- 0.95 | ||
max_treedepth <- 12 | ||
seed <- 123 | ||
|
||
init_fps <- c() | ||
for (i in 1:n_chains) { | ||
init_fps <- c(init_fps, file.path(init_dir, glue::glue("init_{i}.json"))) | ||
} | ||
|
||
# Pre-specified delay distributions | ||
generation_interval <- read.csv(here::here( | ||
"input", "saved_pmfs", | ||
"generation_interval.csv" | ||
)) |> | ||
dplyr::pull(probability_mass) | ||
inf_to_hosp <- read.csv(here::here( | ||
"input", "saved_pmfs", | ||
"inf_to_hosp.csv" | ||
)) |> | ||
dplyr::pull(probability_mass) | ||
|
||
config <- list( | ||
location_ww = df_ww |> dplyr::pull(location) |> as.vector(), | ||
forecast_date_ww = df_ww |> dplyr::pull(forecast_date) |> as.vector(), | ||
scenario = df_ww |> dplyr::pull(scenario) |> as.vector(), | ||
location_hosp = df_hosp |> dplyr::pull(location) |> as.vector(), | ||
forecast_date_hosp = df_hosp |> dplyr::pull(forecast_date) |> as.vector(), | ||
eval_date = eval_date, | ||
ww_data_dir = ww_data_dir, | ||
scenario_dir = scenario_dir, | ||
hosp_data_dir = hosp_data_dir, | ||
stan_models_dir = stan_models_dir, | ||
init_dir = init_dir, | ||
init_fps = init_fps, | ||
calibration_time = calibration_time, | ||
forecast_time = forecast_time, | ||
ww_data_mapping = ww_data_mapping, | ||
# MCMC settings | ||
iter_warmup = iter_warmup, | ||
iter_sampling = iter_sampling, | ||
n_chains = n_chains, | ||
n_parallel_chains = n_parallel_chains, | ||
adapt_delta = adapt_delta, | ||
max_treedepth = max_treedepth, | ||
seed = seed, | ||
# Input delay distributions | ||
generation_interval = generation_interval, | ||
infection_feedback_pmf = generation_interval, | ||
inf_to_hosp = inf_to_hosp | ||
) | ||
|
||
cfaforecastrenewalww::create_dir(config_dir) | ||
yaml::write_yaml(config, file = file.path( | ||
config_dir, | ||
glue::glue("eval_config.yaml") | ||
)) | ||
|
||
return(config) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
^wweval\.Rproj$ | ||
^\.Rproj\.user$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
Package: wweval | ||
Title: Evaluation of wastewater informed COVID-19 hospital admissions forecasting | ||
Version: 0.0.0.9000 | ||
Authors@R: | ||
c(person(given = "Kaitlyn", | ||
family = "Johnson", | ||
role = c("aut", "cre"), | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0001-8011-0012")), | ||
person(given = "Sam", | ||
family = "Abbott", | ||
role = c("aut"), | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0001-8057-8037")), | ||
person(given = "Zachary", | ||
family = "Susswein", | ||
role = c("aut"), | ||
email = "[email protected]"), | ||
person(given = "Andrew", | ||
family = "Magee", | ||
role = c("aut"), | ||
email = "[email protected]"), | ||
person(given = "Dylan", | ||
family = "Morris", | ||
role = c("aut"), | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0002-3655-406X")), | ||
person(given = "Scott", | ||
family = "Olesen", | ||
role = c("aut"), | ||
email = "[email protected]"), | ||
person(given = "George", | ||
family = "Vega Yon", | ||
role = c("ctb"), | ||
email = "[email protected]", | ||
comment = c(ORCID = "0000-0002-3171-0844")) | ||
) | ||
Description: This package provides helper functions designed to | ||
preprocess, postprocess, and analyze outputs from fitting a | ||
semi-mechanistic renewal model with and without incorporating | ||
wastewater. | ||
License: Apache License (== 2.0) | ||
URL: https://github.com/cdcgov/wastewater-informed-covid-forecasting/ | ||
BugReports: https://github.com/cdcgov/wastewater-informed-covid-forecasting/issues/ | ||
Depends: | ||
R (>= 4.3.0) | ||
Imports: | ||
dplyr, | ||
tidybayes, | ||
lubridate, | ||
ggplot2, | ||
stats, | ||
tidyr, | ||
arrow, | ||
glue, | ||
jsonlite, | ||
readr, | ||
rlang, | ||
cli, | ||
covidcast, | ||
data.table, | ||
purrr, | ||
scales, | ||
scoringutils, | ||
tibble, | ||
cmdstanr (>= 0.7.1), | ||
cfaforecastrenewalww | ||
Additional_repositories: https://mc-stan.org/r-packages/ | ||
SystemRequirements: CmdStan (>=2.34.1) | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
Config/Needs/check: rcmdcheck, testthat | ||
RoxygenNote: 7.3.1 | ||
Remotes: | ||
cfaforecastrenewalww=local::../cfaforecastrenewalww |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(add_time_indexing) | ||
export(clean_ww_data) | ||
export(date_of_ww_data) | ||
export(filter_sites_by_scenario) | ||
export(get_full_scores) | ||
export(get_hosp_data_sizes) | ||
export(get_hosp_indices) | ||
export(get_hosp_values) | ||
export(get_input_hosp_data) | ||
export(get_input_ww_data) | ||
export(get_last_hosp_data_date) | ||
export(get_model_draws_w_data) | ||
export(get_model_path) | ||
export(get_plot_hosp_data_comparison) | ||
export(get_plot_ww_data_comparison) | ||
export(get_stan_data_list) | ||
export(get_state_level_quantiles) | ||
export(get_subpop_data) | ||
export(get_ww_data_indices) | ||
export(get_ww_data_sizes) | ||
export(get_ww_values) | ||
export(make_df) | ||
export(sample_model) | ||
importFrom(arrow,read_ipc_stream) | ||
importFrom(arrow,read_parquet) | ||
importFrom(arrow,write_parquet) | ||
importFrom(cmdstanr,cmdstan_model) | ||
importFrom(dplyr,arrange) | ||
importFrom(dplyr,as_tibble) | ||
importFrom(dplyr,distinct) | ||
importFrom(dplyr,filter) | ||
importFrom(dplyr,group_by) | ||
importFrom(dplyr,left_join) | ||
importFrom(dplyr,mutate) | ||
importFrom(dplyr,pull) | ||
importFrom(dplyr,rename) | ||
importFrom(dplyr,row_number) | ||
importFrom(dplyr,select) | ||
importFrom(dplyr,ungroup) | ||
importFrom(ggplot2,facet_grid) | ||
importFrom(ggplot2,facet_wrap) | ||
importFrom(ggplot2,geom_bar) | ||
importFrom(ggplot2,geom_hline) | ||
importFrom(ggplot2,geom_line) | ||
importFrom(ggplot2,geom_point) | ||
importFrom(ggplot2,geom_ribbon) | ||
importFrom(ggplot2,geom_vline) | ||
importFrom(ggplot2,ggplot) | ||
importFrom(ggplot2,labs) | ||
importFrom(ggplot2,scale_colour_discrete) | ||
importFrom(ggplot2,scale_fill_discrete) | ||
importFrom(ggplot2,scale_x_date) | ||
importFrom(ggplot2,scale_y_continuous) | ||
importFrom(ggplot2,theme) | ||
importFrom(glue,glue) | ||
importFrom(jsonlite,fromJSON) | ||
importFrom(lubridate,ymd) | ||
importFrom(readr,read_csv) | ||
importFrom(readr,write_csv) | ||
importFrom(rlang,sym) | ||
importFrom(tidybayes,spread_draws) | ||
importFrom(tidyr,pivot_longer) | ||
importFrom(tidyr,pivot_wider) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
generate_data <- function(model_type, location, forecast_date, | ||
input_ww_data, | ||
input_hosp_data, | ||
params, | ||
n = 10) { | ||
true_beta <- stats::rnorm(n = 1, mean = 0, sd = 1) | ||
x <- seq(from = -1, to = 1, length.out = n) | ||
y <- stats::rnorm(n, x * true_beta, 1) | ||
|
||
return(list(n = n, x = x, y = y, true_beta = true_beta)) | ||
} | ||
|
||
|
||
orig_sample_model <- function(standata, compiled_model, init_lists, | ||
iter_warmup = 250, | ||
iter_sampling = 250, | ||
max_treedepth = 12, | ||
adapt_delta = 0.95, | ||
n_chains = 4, | ||
seed = 123) { | ||
fit_model <- function(compiled_model, standata) { | ||
fit <- compiled_model$sample( | ||
data = standata, | ||
init = init_lists, | ||
iter_warmup = iter_warmup, | ||
iter_sampling = iter_sampling, | ||
max_treedepth = max_treedepth, | ||
adapt_delta = adapt_delta, | ||
num_chains = n_chains, | ||
seed = seed | ||
) | ||
print(fit) | ||
return(fit) | ||
} | ||
|
||
safe_fit_model <- purrr::safely(fit_model) | ||
fit <- safe_fit_model(compiled_model, standata) | ||
obj <- fit$result | ||
return(obj) | ||
} | ||
|
||
|
||
|
||
|
||
get_draws_for_key_pars <- function(fit, pars) { | ||
raw_draws <- tidybayes::spread_draws(fit, !!!syms(pars)) | ||
return(raw_draws) | ||
} | ||
|
||
get_summary <- function(fit) { | ||
summary <- fit$summary() | ||
return(summary) | ||
} | ||
|
||
get_model_diagnostics <- function(fit) { | ||
diagnostics <- fit$sampler_diagnostics(format = "df") | ||
return(diagnostics) | ||
} | ||
|
||
get_draws <- function(fit) { | ||
draws <- fit$draws() | ||
return(draws) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#' Get model path | ||
#' | ||
#' @param model_type string specifying the model to be run, options are either | ||
#' 'hosp' for the hospital admissions only model or | ||
#' 'ww' for the site-level infection dyanmics model using wastewater | ||
#' @param stan_models_dir directory where stan files are located | ||
#' | ||
#' @return string indicating path to correct stan file | ||
#' @export | ||
#' | ||
#' @examples model_path <- get_model_path("hosp", system.file("stan", | ||
#' package = "cfaforecastrenewalww" | ||
#' )) | ||
get_model_path <- function(model_type, stan_models_dir) { | ||
stopifnot("Model type is empty" = !is.null(model_type)) | ||
model_file_name <- if (model_type == "ww") { | ||
"renewal_ww_hosp_site_level_inf_dynamics" | ||
} else if (model_type == "hosp") { | ||
"renewal_ww_hosp" | ||
} else { | ||
NULL | ||
} | ||
|
||
stopifnot("Model type is not specified properly" = !is.null(model_file_name)) | ||
fp <- file.path(stan_models_dir, paste0(model_file_name, ".stan")) | ||
return(fp) | ||
} | ||
|
||
get_fake_model_path <- function(model_type, path_to_stan_models) { | ||
model_file_name <- if (model_type == "modx") { | ||
"x" | ||
} else { | ||
"y" | ||
} | ||
fp <- file.path(path_to_stan_models, paste0(model_file_name, ".stan")) | ||
return(fp) | ||
} |
Oops, something went wrong.