Skip to content

Commit

Permalink
add eval code
Browse files Browse the repository at this point in the history
  • Loading branch information
kaitejohnson committed Apr 15, 2024
1 parent 5a03930 commit 5096303
Show file tree
Hide file tree
Showing 46 changed files with 3,929 additions and 0 deletions.
538 changes: 538 additions & 0 deletions _targets_eval.R

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions src/setup_eval.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# This is a script used to generate the global inputs needed to run the
# _targets_eval.R function. Editing the locations and forecast dates
# and model types tells the evaluation pipeline which combinations to
# pass to tar map to iterate over.
source(file.path("src", "write_eval_config.R"))
write_eval_config(
locations = c("MA", "WA"),
forecast_dates = c(
"2023-10-16", "2023-10-23", "2023-10-30",
"2023-11-06", "2023-11-13", "2023-11-20",
"2023-11-27", "2023-12-04", "2023-12-11",
"2023-12-18", "2023-12-25", "2024-01-01",
"2024-01-08", "2024-01-15", "2024-01-22",
"2024-01-29", "2024-02-05", "2024-02-12",
"2024-02-19"
),
scenarios = c("Status quo", "One site per jurisdiction"),
config_dir = file.path("input", "config", "eval"),
eval_date = "2024-03-25"
)
105 changes: 105 additions & 0 deletions src/write_eval_config.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#' Write evaluation config file
#'
#' @param locations locations to iterate through, for a full run this should
#' be all 50 states + PR
#' @param forecast_dates the forecast dates we want to run the model on
#' @param scenatios the scenarios (which will pertain to site ids) to
#' run the model on
#' @param config_dir the directory where we want to save the config file
#'
#' @return
#' @export
#'
#' @examples
write_eval_config <- function(locations, forecast_dates,
scenarios,
config_dir,
eval_date) {
# get a dataframe of all the iteration combinations for the wastewater mapping
df_ww <- expand.grid(
location = locations,
forecast_date = forecast_dates,
scenario = scenarios
)
# No scenarios for the hosp admissions only, so we just need all combos
# of locations and forecast dates
df_hosp <- expand.grid(
location = locations,
forecast_date = forecast_dates
)

# Specify other variables
ww_data_dir <- file.path("input", "ww_data", "monday_datasets")
scenario_dir <- file.path("input", "config", "eval", "scenarios")
hosp_data_dir <- file.path("input", "hosp_data", "vintage_datasets")
# stan_models_dir <- system.file("stan", package = "cfaforecastrenewalww") #nolint
stan_models_dir <- file.path("cfaforecastrenewalww", "inst", "stan")
init_dir <- file.path("input", "init_lists")

ww_data_mapping <- "Monday: Monday, Wednesday: Monday"
calibration_time <- 90
forecast_time <- 28

iter_warmup <- 750
iter_sampling <- 500
n_chains <- 4
n_parallel_chains <- 4
adapt_delta <- 0.95
max_treedepth <- 12
seed <- 123

init_fps <- c()
for (i in 1:n_chains) {
init_fps <- c(init_fps, file.path(init_dir, glue::glue("init_{i}.json")))
}

# Pre-specified delay distributions
generation_interval <- read.csv(here::here(
"input", "saved_pmfs",
"generation_interval.csv"
)) |>
dplyr::pull(probability_mass)
inf_to_hosp <- read.csv(here::here(
"input", "saved_pmfs",
"inf_to_hosp.csv"
)) |>
dplyr::pull(probability_mass)

config <- list(
location_ww = df_ww |> dplyr::pull(location) |> as.vector(),
forecast_date_ww = df_ww |> dplyr::pull(forecast_date) |> as.vector(),
scenario = df_ww |> dplyr::pull(scenario) |> as.vector(),
location_hosp = df_hosp |> dplyr::pull(location) |> as.vector(),
forecast_date_hosp = df_hosp |> dplyr::pull(forecast_date) |> as.vector(),
eval_date = eval_date,
ww_data_dir = ww_data_dir,
scenario_dir = scenario_dir,
hosp_data_dir = hosp_data_dir,
stan_models_dir = stan_models_dir,
init_dir = init_dir,
init_fps = init_fps,
calibration_time = calibration_time,
forecast_time = forecast_time,
ww_data_mapping = ww_data_mapping,
# MCMC settings
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
n_chains = n_chains,
n_parallel_chains = n_parallel_chains,
adapt_delta = adapt_delta,
max_treedepth = max_treedepth,
seed = seed,
# Input delay distributions
generation_interval = generation_interval,
infection_feedback_pmf = generation_interval,
inf_to_hosp = inf_to_hosp
)

cfaforecastrenewalww::create_dir(config_dir)
yaml::write_yaml(config, file = file.path(
config_dir,
glue::glue("eval_config.yaml")
))

return(config)
}
2 changes: 2 additions & 0 deletions wweval/.Rbuildignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
^wweval\.Rproj$
^\.Rproj\.user$
75 changes: 75 additions & 0 deletions wweval/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
Package: wweval
Title: Evaluation of wastewater informed COVID-19 hospital admissions forecasting
Version: 0.0.0.9000
Authors@R:
c(person(given = "Kaitlyn",
family = "Johnson",
role = c("aut", "cre"),
email = "[email protected]",
comment = c(ORCID = "0000-0001-8011-0012")),
person(given = "Sam",
family = "Abbott",
role = c("aut"),
email = "[email protected]",
comment = c(ORCID = "0000-0001-8057-8037")),
person(given = "Zachary",
family = "Susswein",
role = c("aut"),
email = "[email protected]"),
person(given = "Andrew",
family = "Magee",
role = c("aut"),
email = "[email protected]"),
person(given = "Dylan",
family = "Morris",
role = c("aut"),
email = "[email protected]",
comment = c(ORCID = "0000-0002-3655-406X")),
person(given = "Scott",
family = "Olesen",
role = c("aut"),
email = "[email protected]"),
person(given = "George",
family = "Vega Yon",
role = c("ctb"),
email = "[email protected]",
comment = c(ORCID = "0000-0002-3171-0844"))
)
Description: This package provides helper functions designed to
preprocess, postprocess, and analyze outputs from fitting a
semi-mechanistic renewal model with and without incorporating
wastewater.
License: Apache License (== 2.0)
URL: https://github.com/cdcgov/wastewater-informed-covid-forecasting/
BugReports: https://github.com/cdcgov/wastewater-informed-covid-forecasting/issues/
Depends:
R (>= 4.3.0)
Imports:
dplyr,
tidybayes,
lubridate,
ggplot2,
stats,
tidyr,
arrow,
glue,
jsonlite,
readr,
rlang,
cli,
covidcast,
data.table,
purrr,
scales,
scoringutils,
tibble,
cmdstanr (>= 0.7.1),
cfaforecastrenewalww
Additional_repositories: https://mc-stan.org/r-packages/
SystemRequirements: CmdStan (>=2.34.1)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
Config/Needs/check: rcmdcheck, testthat
RoxygenNote: 7.3.1
Remotes:
cfaforecastrenewalww=local::../cfaforecastrenewalww
65 changes: 65 additions & 0 deletions wweval/NAMESPACE
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Generated by roxygen2: do not edit by hand

export(add_time_indexing)
export(clean_ww_data)
export(date_of_ww_data)
export(filter_sites_by_scenario)
export(get_full_scores)
export(get_hosp_data_sizes)
export(get_hosp_indices)
export(get_hosp_values)
export(get_input_hosp_data)
export(get_input_ww_data)
export(get_last_hosp_data_date)
export(get_model_draws_w_data)
export(get_model_path)
export(get_plot_hosp_data_comparison)
export(get_plot_ww_data_comparison)
export(get_stan_data_list)
export(get_state_level_quantiles)
export(get_subpop_data)
export(get_ww_data_indices)
export(get_ww_data_sizes)
export(get_ww_values)
export(make_df)
export(sample_model)
importFrom(arrow,read_ipc_stream)
importFrom(arrow,read_parquet)
importFrom(arrow,write_parquet)
importFrom(cmdstanr,cmdstan_model)
importFrom(dplyr,arrange)
importFrom(dplyr,as_tibble)
importFrom(dplyr,distinct)
importFrom(dplyr,filter)
importFrom(dplyr,group_by)
importFrom(dplyr,left_join)
importFrom(dplyr,mutate)
importFrom(dplyr,pull)
importFrom(dplyr,rename)
importFrom(dplyr,row_number)
importFrom(dplyr,select)
importFrom(dplyr,ungroup)
importFrom(ggplot2,facet_grid)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_bar)
importFrom(ggplot2,geom_hline)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_ribbon)
importFrom(ggplot2,geom_vline)
importFrom(ggplot2,ggplot)
importFrom(ggplot2,labs)
importFrom(ggplot2,scale_colour_discrete)
importFrom(ggplot2,scale_fill_discrete)
importFrom(ggplot2,scale_x_date)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,theme)
importFrom(glue,glue)
importFrom(jsonlite,fromJSON)
importFrom(lubridate,ymd)
importFrom(readr,read_csv)
importFrom(readr,write_csv)
importFrom(rlang,sym)
importFrom(tidybayes,spread_draws)
importFrom(tidyr,pivot_longer)
importFrom(tidyr,pivot_wider)
63 changes: 63 additions & 0 deletions wweval/R/fake_fitting.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
generate_data <- function(model_type, location, forecast_date,
input_ww_data,
input_hosp_data,
params,
n = 10) {
true_beta <- stats::rnorm(n = 1, mean = 0, sd = 1)
x <- seq(from = -1, to = 1, length.out = n)
y <- stats::rnorm(n, x * true_beta, 1)

return(list(n = n, x = x, y = y, true_beta = true_beta))
}


orig_sample_model <- function(standata, compiled_model, init_lists,
iter_warmup = 250,
iter_sampling = 250,
max_treedepth = 12,
adapt_delta = 0.95,
n_chains = 4,
seed = 123) {
fit_model <- function(compiled_model, standata) {
fit <- compiled_model$sample(
data = standata,
init = init_lists,
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
max_treedepth = max_treedepth,
adapt_delta = adapt_delta,
num_chains = n_chains,
seed = seed
)
print(fit)
return(fit)
}

safe_fit_model <- purrr::safely(fit_model)
fit <- safe_fit_model(compiled_model, standata)
obj <- fit$result
return(obj)
}




get_draws_for_key_pars <- function(fit, pars) {
raw_draws <- tidybayes::spread_draws(fit, !!!syms(pars))
return(raw_draws)
}

get_summary <- function(fit) {
summary <- fit$summary()
return(summary)
}

get_model_diagnostics <- function(fit) {
diagnostics <- fit$sampler_diagnostics(format = "df")
return(diagnostics)
}

get_draws <- function(fit) {
draws <- fit$draws()
return(draws)
}
37 changes: 37 additions & 0 deletions wweval/R/filepath_mapping.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#' Get model path
#'
#' @param model_type string specifying the model to be run, options are either
#' 'hosp' for the hospital admissions only model or
#' 'ww' for the site-level infection dyanmics model using wastewater
#' @param stan_models_dir directory where stan files are located
#'
#' @return string indicating path to correct stan file
#' @export
#'
#' @examples model_path <- get_model_path("hosp", system.file("stan",
#' package = "cfaforecastrenewalww"
#' ))
get_model_path <- function(model_type, stan_models_dir) {
stopifnot("Model type is empty" = !is.null(model_type))
model_file_name <- if (model_type == "ww") {
"renewal_ww_hosp_site_level_inf_dynamics"
} else if (model_type == "hosp") {
"renewal_ww_hosp"
} else {
NULL
}

stopifnot("Model type is not specified properly" = !is.null(model_file_name))
fp <- file.path(stan_models_dir, paste0(model_file_name, ".stan"))
return(fp)
}

get_fake_model_path <- function(model_type, path_to_stan_models) {
model_file_name <- if (model_type == "modx") {
"x"
} else {
"y"
}
fp <- file.path(path_to_stan_models, paste0(model_file_name, ".stan"))
return(fp)
}
Loading

0 comments on commit 5096303

Please sign in to comment.