diff --git a/_targets_eval.R b/_targets_eval.R new file mode 100644 index 00000000..746d2b66 --- /dev/null +++ b/_targets_eval.R @@ -0,0 +1,538 @@ +# This defines the evaluation pipeline + +# Load packages required to define the pipeline: +library(targets) +library(stantargets) +library(tarchetypes) # tar_render() calls +library(crew) # To run in parallel +library(lubridate) +library(purrr, quietly = TRUE) + + +# Run with a crew controller +# See: https://books.ropensci.org/targets/crew.html +# This defaults to 8 cores but you likely want to use +# floor(/number used for MCMC) +# Currently on model fitting is run in parallel due to an assumption +# that IO costs outweighs any benefit for other tasks. +# This may not be the case. To change allocation alter the deployment +# argument of a target. +controller <- crew_controller_local( + workers = 8, + seconds_idle = 600 +) + +# Set target options: +tar_option_set( + workspace_on_error = TRUE, + packages = c("cfaforecastrenewalww", "wweval"), + # Run with a pre-specified crew controller + controller = controller, + # Setup storage on workers vs on the main node. + # This will only work on workers that have access to the data + # See https://books.ropensci.org/targets/performance.html#worker-storage + memory = "transient", + garbage_collection = TRUE, + storage = "worker", + retrieval = "worker", + format = "rds", # default storage format + error = NULL # tells errored targets to return NULL rather than + # have whole pipeline fail + # Set other options as needed. +) + +setup_interactive_dev_run <- function() { + list.files(file.path("wweval", "R"), full.names = TRUE) |> + purrr::walk(source) + tar_option_set( + packages = c( + "cmdstanr", + "tibble", + "ggplot2", + "dplyr", + "lubridate", + "cmdstanr", + "tidybayes", + "cfaforecastrenewalww", + "wweval" + ) + ) +} + +setup_interactive_dev_run() + +cfaforecastrenewalww::setup_secrets("secrets.yaml") + +# Need to specify the evaluation variable combinations outside of targets +eval_config <- yaml::read_yaml(file.path( + "input", "config", + "eval", "eval_config.yaml" +)) +# Get global parameter values +params <- cfaforecastrenewalww::get_params(file.path( + "input", "params.toml" +)) + + + +# Set up some global targets + +# Get the evaluation data from the specified evaluation date ---------------- +upstream_targets <- list( + tar_target( + name = eval_hosp_data, + command = get_input_hosp_data( + forecast_date = eval_config$eval_date, + location = unique(eval_config$location_ww), + hosp_data_dir = eval_config$hosp_data_dir, + calibration_time = 365 # Grab sufficient data for eval + ) + ), + tar_target( + name = eval_ww_data, + command = get_input_ww_data( + forecast_date = eval_config$eval_date, + location = unique(eval_config$location_ww), + scenario = "Status quo", + scenario_dir = eval_config$scenario_dir, + ww_data_dir = eval_config$ww_data_dir, + calibration_time = 365, # Grab sufficient data for eval + last_hosp_data_date = eval_config$eval_date, + ww_data_mapping = eval_config$ww_data_mapping + ) + ) +) + +# Iterate over forecast dates, locations, and scenarios separately for the two models +# For each iteration (forecast date, location): +# - load in and clean the ww and hosp dataset for each forecast date and location +# - based on the models data requirements, preprocess data and generate parameters +# - generate stan data +# - fit the model +# - extraction posterior samples for generated quantities and parameters of interest +# - join posterior draws with input data +# - summarize the generated quantities to only those we're scoring +# - score the generated quantities against evaluation data (most recent data) +# - save posterior draws and diagnostics + +# Wastewater model fitting loop----------------------------------------------- +mapped_ww <- tar_map( + values = list( + location = eval_config$location_ww, + forecast_date = eval_config$forecast_date_ww, + scenario = eval_config$scenario + ), + tar_target( + name = stan_model_path_target, + command = get_model_path( + model_type = "ww", + stan_models_dir = eval_config$stan_models_dir + ), + format = "file", + priority = 1 + ), + tar_target( + name = input_hosp_data, + command = get_input_hosp_data(forecast_date, location, + hosp_data_dir = eval_config$hosp_data_dir, + calibration_time = eval_config$calibration_time + ), + deployment = "main", + priority = 1 + ), + tar_target( + name = last_hosp_data_date, + command = get_last_hosp_data_date(input_hosp_data), + deployment = "main", + priority = 1 + ), + tar_target(input_ww_data, + command = get_input_ww_data(forecast_date, + location, + scenario, + scenario_dir = eval_config$scenario_dir, + ww_data_dir = eval_config$ww_data_dir, + calibration_time = eval_config$calibration_time, + last_hosp_data_date = last_hosp_data_date, + ww_data_mapping = eval_config$ww_data_mapping + ), + deployment = "main", + priority = 1 + ), + ## Get the stan data for this location, forecast_date, and scenario ---------- + tar_target( + name = standata, + command = get_stan_data_list( + model_type = "ww", + forecast_date, eval_config$forecast_time, + input_ww_data, input_hosp_data, + generation_interval = eval_config$generation_interval, + inf_to_hosp = eval_config$inf_to_hosp, + infection_feedback_pmf = eval_config$infection_feedback_pmf, + params + ), + deployment = "main", + priority = 1 + ), + ## Model fitting ---------------------------------------------------------- + tar_target( + name = init_lists, + command = get_inits( + model_type = "ww", standata, params, + n_chains = eval_config$n_chains + ), + deployment = "main", + priority = 1 + ), + tar_target( + name = compiled_model, + command = compile_model( + model_filepath = stan_model_path_target, + include_paths = eval_config$stan_models_dir + ), + deployment = "main", + priority = 1 + ), + tar_target( + name = ww_fit_obj, + command = sample_model(standata, compiled_model, init_lists, + iter_warmup = 250, # eval_config$iter_warmup, + iter_sampling = 25, # eval_config$iter_sampling, + adapt_delta = eval_config$adapt_delta, + n_chains = eval_config$n_chains, + max_treedepth = eval_config$max_treedepth, + seed = eval_config$seed + ), + deployment = "worker", + priority = 1 + ), + ## Post-processing--------------------------------------------------------- + tar_target( + name = ww_raw_draws, + command = ww_fit_obj$draws, + deployment = "main", + priority = 1 + ), + tar_target( + name = ww_diagnostics, + command = ww_fit_obj$diagnostics, + deployment = "main", + priority = 1 + ), + tar_target( + name = ww_diagnostic_summary, + command = ww_fit_obj$summary_diagnostics, + deployment = "main", + priority = 1 + ), + + # Get evaluation data from hospital admissions and wastewater + # Join draws with data + tar_target( + name = hosp_draws, + command = get_model_draws_w_data( + model_output = "hosp", + model_type = "ww", + draws = ww_raw_draws, + forecast_date = forecast_date, + scenario = scenario, + location = location, + input_data = input_hosp_data, + eval_data = eval_hosp_data, + last_hosp_data_date = last_hosp_data_date, + ot = eval_config$calibration_time, + forecast_time = eval_config$forecast_time + ), + deployment = "main", + priority = 1 + ), + tar_target( + name = ww_draws, + command = get_model_draws_w_data( + model_output = "ww", + model_type = "ww", + draws = ww_raw_draws, + forecast_date = forecast_date, + scenario = scenario, + location = location, + input_data = input_ww_data, + eval_data = eval_ww_data, + last_hosp_data_date = last_hosp_data_date, + ot = eval_config$calibration_time, + forecast_time = eval_config$forecast_time + ), + deployment = "main", + priority = 1 + ), + tar_target( + name = full_hosp_quantiles, + command = get_state_level_quantiles( + draws = hosp_draws + ), + deployment = "main", + priority = 1 + ), + tar_target( + name = hosp_quantiles, + command = full_hosp_quantiles |> + dplyr::filter(period != "calibration"), + deployment = "main", + priority = 1 + ), + ### Plot the draw comparison------------------------------------- + tar_target( + name = plot_hosp_draws, + command = get_plot_hosp_data_comparison( + hosp_draws, + location, + model_type = "ww" + ) + ), + tar_target( + name = plot_ww_draws, + command = get_plot_ww_data_comparison( + ww_draws, + location, + model_type = "ww" + ) + ), + + ## Score hospital admissions forecasts---------------------------------- + tar_target( + name = hosp_scores, + command = get_full_scores(hosp_draws, scenario) + ) + # Get a subset of samples for plotting + # Get a subset of quantiles for plotting +) # end tar map + +# Hospital admissions model fitting loop----------------------------------------------- +mapped_hosp <- tar_map( + values = list( + location = eval_config$location_hosp, + forecast_date = eval_config$forecast_date_hosp + ), + tar_target( + name = stan_model_path_target, + command = get_model_path( + model_type = "hosp", + stan_models_dir = eval_config$stan_models_dir + ), + format = "file", + deployment = "main" + ), + tar_target( + name = input_hosp_data, + command = get_input_hosp_data(forecast_date, location, + hosp_data_dir = eval_config$hosp_data_dir, + calibration_time = eval_config$calibration_time + ), + deployment = "main" + ), + tar_target( + name = last_hosp_data_date, + command = get_last_hosp_data_date(input_hosp_data), + deployment = "main" + ), + ## Get the stan data for this location, forecast_date, and scenario ---------- + tar_target( + name = standata, + command = get_stan_data_list( + model_type = "hosp", + forecast_date, eval_config$forecast_time, + input_ww_data = NA, + input_hosp_data = input_hosp_data, + generation_interval = eval_config$generation_interval, + inf_to_hosp = eval_config$inf_to_hosp, + infection_feedback_pmf = eval_config$infection_feedback_pmf, + params + ), + deployment = "main" + ), + ## Model fitting----------------------------------------------------------- + tar_target( + name = init_lists, + command = get_inits( + model_type = "hosp", + standata, params, + n_chains = eval_config$n_chains + ), + deployment = "main" + ), + tar_target( + name = compiled_model, + command = compile_model( + model_filepath = stan_model_path_target, + include_paths = eval_config$stan_models_dir + ), + deployment = "main" + ), + tar_target( + name = hosp_fit_obj, + command = sample_model(standata, compiled_model, init_lists, + iter_warmup = 250, # eval_config$iter_warmup, + iter_sampling = 25, # eval_config$iter_sampling, + adapt_delta = eval_config$adapt_delta, + max_treedepth = eval_config$max_treedepth, + seed = eval_config$seed + ), + deployment = "worker" + ), + ## Post-processing--------------------------------------------------------- + tar_target( + name = hosp_raw_draws, + command = hosp_fit_obj$draws, + deployment = "main" + ), + tar_target( + name = hosp_diagnostics, + command = hosp_fit_obj$diagnostics, + deployment = "main" + ), + tar_target( + name = hosp_diagnostic_summary, + command = hosp_fit_obj$summary_diagnostics, + deployment = "main" + ), + + # Get evaluation data from hospital admissions and wastewater + # Join draws with data + tar_target( + name = hosp_model_hosp_draws, + command = get_model_draws_w_data( + model_output = "hosp", + model_type = "hosp", + draws = hosp_raw_draws, + forecast_date = forecast_date, + scenario = "No wastewater", + location = location, + input_data = input_hosp_data, + eval_data = eval_hosp_data, + last_hosp_data_date = last_hosp_data_date, + ot = eval_config$calibration_time, + forecast_time = eval_config$forecast_time + ), + deployment = "main" + ), + tar_target( + name = full_hosp_model_quantiles, + command = get_state_level_quantiles( + draws = hosp_model_hosp_draws + ), + deployment = "main" + ), + tar_target( + name = hosp_model_quantiles, + command = full_hosp_model_quantiles |> + dplyr::filter(period != "calibration"), + deployment = "main" + ), + ### Plot the draw comparison------------------------------------- + tar_target( + name = plot_hosp_draws_hosp_model, + command = get_plot_hosp_data_comparison( + hosp_model_hosp_draws, + location, + model_type = "hosp" + ), + deployment = "main" + ), + ## Score the hospital admissions only model------------------------- + tar_target( + name = hosp_scores, + command = get_full_scores(hosp_model_hosp_draws, + scenario = "No wastewater" + ), + deployment = "main" + ) +) # end tar map + + +# Summarize the scores and outputs across groups------------------------------ +combined_ww_scores <- tar_combine( + name = all_ww_scores, + mapped_ww$hosp_scores, + command = dplyr::bind_rows(!!!.x, .id = "method") +) +combined_hosp_scores <- tar_combine( + name = all_hosp_scores, + mapped_hosp$hosp_scores, + command = dplyr::bind_rows(!!!.x, .id = "method") +) +combined_ww_hosp_quantiles <- tar_combine( + name = all_ww_hosp_quantiles, + mapped_ww$hosp_quantiles, + command = dplyr::bind_rows(!!!.x, .id = "method") +) +combined_hosp_model_quantiles <- tar_combine( + name = all_hosp_model_quantiles, + mapped_hosp$hosp_model_quantiles, + command = dplyr::bind_rows(!!!.x, .id = "method") +) + + +downstream_targets <- list( + tar_target( + name = all_scores, + command = rbind(all_hosp_scores, all_ww_scores) + ), + tar_target( + name = plot_scores, + command = get_plot_raw_scores(all_scores) + ), + tar_target( + name = summarized_scores, + command = scoringutils::summarize_scores(all_scores, + by = c( + "scenario", + "period", + "forecast_date", + "location" + ) + ) |> + dplyr::group_by(location) |> + targets::tar_group() + ), + tar_target( + name = grouped_all_scores, + command = all_scores |> + dplyr::group_by(location) |> + targets::tar_group(), + iteration = "group" + ), + tar_target( + name = plot_summarized_scores, + command = get_plot_summarized_scores(grouped_all_scores), + pattern = map(grouped_all_scores), + iteration = "list" + ), + tar_target( + name = all_hosp_quantiles, + command = rbind( + all_hosp_model_quantiles, + all_ww_hosp_quantiles + ) |> + dplyr::group_by(location) |> + targets::tar_group(), + iteration = "group" + ), + tar_target( + name = plot_quantile_comparison, + command = get_plot_quantile_comparison( + all_hosp_quantiles + ), + pattern = map(all_hosp_quantiles), + iteration = "list" + ) +) +# Generate figures and results from outputs----------------------------------- +list( + upstream_targets, + mapped_ww, + mapped_hosp, + combined_ww_scores, + combined_hosp_scores, + combined_ww_hosp_quantiles, + combined_hosp_model_quantiles, + downstream_targets +) diff --git a/src/setup_eval.R b/src/setup_eval.R new file mode 100644 index 00000000..3895e488 --- /dev/null +++ b/src/setup_eval.R @@ -0,0 +1,20 @@ +# This is a script used to generate the global inputs needed to run the +# _targets_eval.R function. Editing the locations and forecast dates +# and model types tells the evaluation pipeline which combinations to +# pass to tar map to iterate over. +source(file.path("src", "write_eval_config.R")) +write_eval_config( + locations = c("MA", "WA"), + forecast_dates = c( + "2023-10-16", "2023-10-23", "2023-10-30", + "2023-11-06", "2023-11-13", "2023-11-20", + "2023-11-27", "2023-12-04", "2023-12-11", + "2023-12-18", "2023-12-25", "2024-01-01", + "2024-01-08", "2024-01-15", "2024-01-22", + "2024-01-29", "2024-02-05", "2024-02-12", + "2024-02-19" + ), + scenarios = c("Status quo", "One site per jurisdiction"), + config_dir = file.path("input", "config", "eval"), + eval_date = "2024-03-25" +) diff --git a/src/write_eval_config.R b/src/write_eval_config.R new file mode 100644 index 00000000..5f6a2088 --- /dev/null +++ b/src/write_eval_config.R @@ -0,0 +1,105 @@ +#' Write evaluation config file +#' +#' @param locations locations to iterate through, for a full run this should +#' be all 50 states + PR +#' @param forecast_dates the forecast dates we want to run the model on +#' @param scenatios the scenarios (which will pertain to site ids) to +#' run the model on +#' @param config_dir the directory where we want to save the config file +#' +#' @return +#' @export +#' +#' @examples +write_eval_config <- function(locations, forecast_dates, + scenarios, + config_dir, + eval_date) { + # get a dataframe of all the iteration combinations for the wastewater mapping + df_ww <- expand.grid( + location = locations, + forecast_date = forecast_dates, + scenario = scenarios + ) + # No scenarios for the hosp admissions only, so we just need all combos + # of locations and forecast dates + df_hosp <- expand.grid( + location = locations, + forecast_date = forecast_dates + ) + + # Specify other variables + ww_data_dir <- file.path("input", "ww_data", "monday_datasets") + scenario_dir <- file.path("input", "config", "eval", "scenarios") + hosp_data_dir <- file.path("input", "hosp_data", "vintage_datasets") + # stan_models_dir <- system.file("stan", package = "cfaforecastrenewalww") #nolint + stan_models_dir <- file.path("cfaforecastrenewalww", "inst", "stan") + init_dir <- file.path("input", "init_lists") + + ww_data_mapping <- "Monday: Monday, Wednesday: Monday" + calibration_time <- 90 + forecast_time <- 28 + + iter_warmup <- 750 + iter_sampling <- 500 + n_chains <- 4 + n_parallel_chains <- 4 + adapt_delta <- 0.95 + max_treedepth <- 12 + seed <- 123 + + init_fps <- c() + for (i in 1:n_chains) { + init_fps <- c(init_fps, file.path(init_dir, glue::glue("init_{i}.json"))) + } + + # Pre-specified delay distributions + generation_interval <- read.csv(here::here( + "input", "saved_pmfs", + "generation_interval.csv" + )) |> + dplyr::pull(probability_mass) + inf_to_hosp <- read.csv(here::here( + "input", "saved_pmfs", + "inf_to_hosp.csv" + )) |> + dplyr::pull(probability_mass) + + config <- list( + location_ww = df_ww |> dplyr::pull(location) |> as.vector(), + forecast_date_ww = df_ww |> dplyr::pull(forecast_date) |> as.vector(), + scenario = df_ww |> dplyr::pull(scenario) |> as.vector(), + location_hosp = df_hosp |> dplyr::pull(location) |> as.vector(), + forecast_date_hosp = df_hosp |> dplyr::pull(forecast_date) |> as.vector(), + eval_date = eval_date, + ww_data_dir = ww_data_dir, + scenario_dir = scenario_dir, + hosp_data_dir = hosp_data_dir, + stan_models_dir = stan_models_dir, + init_dir = init_dir, + init_fps = init_fps, + calibration_time = calibration_time, + forecast_time = forecast_time, + ww_data_mapping = ww_data_mapping, + # MCMC settings + iter_warmup = iter_warmup, + iter_sampling = iter_sampling, + n_chains = n_chains, + n_parallel_chains = n_parallel_chains, + adapt_delta = adapt_delta, + max_treedepth = max_treedepth, + seed = seed, + # Input delay distributions + generation_interval = generation_interval, + infection_feedback_pmf = generation_interval, + inf_to_hosp = inf_to_hosp + ) + + cfaforecastrenewalww::create_dir(config_dir) + yaml::write_yaml(config, file = file.path( + config_dir, + glue::glue("eval_config.yaml") + )) + + return(config) +} diff --git a/wweval/.Rbuildignore b/wweval/.Rbuildignore new file mode 100644 index 00000000..ea59d3fb --- /dev/null +++ b/wweval/.Rbuildignore @@ -0,0 +1,2 @@ +^wweval\.Rproj$ +^\.Rproj\.user$ diff --git a/wweval/DESCRIPTION b/wweval/DESCRIPTION new file mode 100644 index 00000000..145b3e6b --- /dev/null +++ b/wweval/DESCRIPTION @@ -0,0 +1,75 @@ +Package: wweval +Title: Evaluation of wastewater informed COVID-19 hospital admissions forecasting +Version: 0.0.0.9000 +Authors@R: + c(person(given = "Kaitlyn", + family = "Johnson", + role = c("aut", "cre"), + email = "uox1@cdc.gov", + comment = c(ORCID = "0000-0001-8011-0012")), + person(given = "Sam", + family = "Abbott", + role = c("aut"), + email = "contact@samabbott.co.uk", + comment = c(ORCID = "0000-0001-8057-8037")), + person(given = "Zachary", + family = "Susswein", + role = c("aut"), + email = "utb2@cdc.gov"), + person(given = "Andrew", + family = "Magee", + role = c("aut"), + email = "rzg0@cdc.gov"), + person(given = "Dylan", + family = "Morris", + role = c("aut"), + email = "dylan@dylanhmorris.com", + comment = c(ORCID = "0000-0002-3655-406X")), + person(given = "Scott", + family = "Olesen", + role = c("aut"), + email = "ulp7@cdc.gov"), + person(given = "George", + family = "Vega Yon", + role = c("ctb"), + email = "g.vegayon@gmail.com", + comment = c(ORCID = "0000-0002-3171-0844")) + ) +Description: This package provides helper functions designed to + preprocess, postprocess, and analyze outputs from fitting a + semi-mechanistic renewal model with and without incorporating + wastewater. +License: Apache License (== 2.0) +URL: https://github.com/cdcgov/wastewater-informed-covid-forecasting/ +BugReports: https://github.com/cdcgov/wastewater-informed-covid-forecasting/issues/ +Depends: + R (>= 4.3.0) +Imports: + dplyr, + tidybayes, + lubridate, + ggplot2, + stats, + tidyr, + arrow, + glue, + jsonlite, + readr, + rlang, + cli, + covidcast, + data.table, + purrr, + scales, + scoringutils, + tibble, + cmdstanr (>= 0.7.1), + cfaforecastrenewalww +Additional_repositories: https://mc-stan.org/r-packages/ +SystemRequirements: CmdStan (>=2.34.1) +Encoding: UTF-8 +Roxygen: list(markdown = TRUE) +Config/Needs/check: rcmdcheck, testthat +RoxygenNote: 7.3.1 +Remotes: + cfaforecastrenewalww=local::../cfaforecastrenewalww diff --git a/wweval/NAMESPACE b/wweval/NAMESPACE new file mode 100644 index 00000000..df342309 --- /dev/null +++ b/wweval/NAMESPACE @@ -0,0 +1,65 @@ +# Generated by roxygen2: do not edit by hand + +export(add_time_indexing) +export(clean_ww_data) +export(date_of_ww_data) +export(filter_sites_by_scenario) +export(get_full_scores) +export(get_hosp_data_sizes) +export(get_hosp_indices) +export(get_hosp_values) +export(get_input_hosp_data) +export(get_input_ww_data) +export(get_last_hosp_data_date) +export(get_model_draws_w_data) +export(get_model_path) +export(get_plot_hosp_data_comparison) +export(get_plot_ww_data_comparison) +export(get_stan_data_list) +export(get_state_level_quantiles) +export(get_subpop_data) +export(get_ww_data_indices) +export(get_ww_data_sizes) +export(get_ww_values) +export(make_df) +export(sample_model) +importFrom(arrow,read_ipc_stream) +importFrom(arrow,read_parquet) +importFrom(arrow,write_parquet) +importFrom(cmdstanr,cmdstan_model) +importFrom(dplyr,arrange) +importFrom(dplyr,as_tibble) +importFrom(dplyr,distinct) +importFrom(dplyr,filter) +importFrom(dplyr,group_by) +importFrom(dplyr,left_join) +importFrom(dplyr,mutate) +importFrom(dplyr,pull) +importFrom(dplyr,rename) +importFrom(dplyr,row_number) +importFrom(dplyr,select) +importFrom(dplyr,ungroup) +importFrom(ggplot2,facet_grid) +importFrom(ggplot2,facet_wrap) +importFrom(ggplot2,geom_bar) +importFrom(ggplot2,geom_hline) +importFrom(ggplot2,geom_line) +importFrom(ggplot2,geom_point) +importFrom(ggplot2,geom_ribbon) +importFrom(ggplot2,geom_vline) +importFrom(ggplot2,ggplot) +importFrom(ggplot2,labs) +importFrom(ggplot2,scale_colour_discrete) +importFrom(ggplot2,scale_fill_discrete) +importFrom(ggplot2,scale_x_date) +importFrom(ggplot2,scale_y_continuous) +importFrom(ggplot2,theme) +importFrom(glue,glue) +importFrom(jsonlite,fromJSON) +importFrom(lubridate,ymd) +importFrom(readr,read_csv) +importFrom(readr,write_csv) +importFrom(rlang,sym) +importFrom(tidybayes,spread_draws) +importFrom(tidyr,pivot_longer) +importFrom(tidyr,pivot_wider) diff --git a/wweval/R/fake_fitting.R b/wweval/R/fake_fitting.R new file mode 100644 index 00000000..2f0f555a --- /dev/null +++ b/wweval/R/fake_fitting.R @@ -0,0 +1,63 @@ +generate_data <- function(model_type, location, forecast_date, + input_ww_data, + input_hosp_data, + params, + n = 10) { + true_beta <- stats::rnorm(n = 1, mean = 0, sd = 1) + x <- seq(from = -1, to = 1, length.out = n) + y <- stats::rnorm(n, x * true_beta, 1) + + return(list(n = n, x = x, y = y, true_beta = true_beta)) +} + + +orig_sample_model <- function(standata, compiled_model, init_lists, + iter_warmup = 250, + iter_sampling = 250, + max_treedepth = 12, + adapt_delta = 0.95, + n_chains = 4, + seed = 123) { + fit_model <- function(compiled_model, standata) { + fit <- compiled_model$sample( + data = standata, + init = init_lists, + iter_warmup = iter_warmup, + iter_sampling = iter_sampling, + max_treedepth = max_treedepth, + adapt_delta = adapt_delta, + num_chains = n_chains, + seed = seed + ) + print(fit) + return(fit) + } + + safe_fit_model <- purrr::safely(fit_model) + fit <- safe_fit_model(compiled_model, standata) + obj <- fit$result + return(obj) +} + + + + +get_draws_for_key_pars <- function(fit, pars) { + raw_draws <- tidybayes::spread_draws(fit, !!!syms(pars)) + return(raw_draws) +} + +get_summary <- function(fit) { + summary <- fit$summary() + return(summary) +} + +get_model_diagnostics <- function(fit) { + diagnostics <- fit$sampler_diagnostics(format = "df") + return(diagnostics) +} + +get_draws <- function(fit) { + draws <- fit$draws() + return(draws) +} diff --git a/wweval/R/filepath_mapping.R b/wweval/R/filepath_mapping.R new file mode 100644 index 00000000..0153c177 --- /dev/null +++ b/wweval/R/filepath_mapping.R @@ -0,0 +1,37 @@ +#' Get model path +#' +#' @param model_type string specifying the model to be run, options are either +#' 'hosp' for the hospital admissions only model or +#' 'ww' for the site-level infection dyanmics model using wastewater +#' @param stan_models_dir directory where stan files are located +#' +#' @return string indicating path to correct stan file +#' @export +#' +#' @examples model_path <- get_model_path("hosp", system.file("stan", +#' package = "cfaforecastrenewalww" +#' )) +get_model_path <- function(model_type, stan_models_dir) { + stopifnot("Model type is empty" = !is.null(model_type)) + model_file_name <- if (model_type == "ww") { + "renewal_ww_hosp_site_level_inf_dynamics" + } else if (model_type == "hosp") { + "renewal_ww_hosp" + } else { + NULL + } + + stopifnot("Model type is not specified properly" = !is.null(model_file_name)) + fp <- file.path(stan_models_dir, paste0(model_file_name, ".stan")) + return(fp) +} + +get_fake_model_path <- function(model_type, path_to_stan_models) { + model_file_name <- if (model_type == "modx") { + "x" + } else { + "y" + } + fp <- file.path(path_to_stan_models, paste0(model_file_name, ".stan")) + return(fp) +} diff --git a/wweval/R/format_data_for_stan.R b/wweval/R/format_data_for_stan.R new file mode 100644 index 00000000..4469f12e --- /dev/null +++ b/wweval/R/format_data_for_stan.R @@ -0,0 +1,801 @@ +#' Get stan data +#' +#' @param model_type string indicating which model we are getting data for +#' Options are `ww` or `hosp` +#' @param forecast_date string indicating the forecast date +#' @param forecast_time integer indicating the number of days to make a forecast +#' for +#' @param input_ww_data a dataframe with the input wastewater data +#' @param input_hosp_data a dataframe with the input hospital admissions data +#' @param generation_interval a vector with a zero-truncated normalized pmf of +#' the generation interval +#' @param inf_to_hosp a vector with a normalized pmf of the delay from infection +#' to hospital admissions +#' @param infection_feedback_pmf a vector with a normalized pmf dictating the +#' delay of infection feedback +#' @param params a dataframe of parameter names and numeric values +#' @param compute_likelihood indicator variable telling stan whether or not to +#' compute the likelihood, default = `1` +#' @param ww_outlier_col_name A string representing the name of the +#' column in the input_ww_data that provides a 0 if the data point is not an +#' outlier to be excluded from the model fit, or a 1 if it is to be excluded +#' default value is `flag_as_ww_outlier` +#' @param lod_col_name A string representing the name of the +#' column in the input_ww_data that provides a 0 if the data point is not above +#' the LOD and a 1 if the data is below the LOD, default value is `below_LOD` +#' @param ww_measurement_col_name A string representing the name of the column +#' in the input_ww_data that indicates the wastewater measurement value in +#' natural scale, default is `ww` +#' @param ww_value_lod_col_name A string representing the name of the column +#' in the input_ww_data that indicates the value of the LOD in natural scale, +#' default is `lod_sewage` +#' @param hosp_value_col_name A string represeting the name of the column in the +#' input_hosp-data that indicates the number of daily hospital admissions, +#' default is `daily_hosp_admits` +#' +#' @return a list named variables to pass to stan +#' @export +get_stan_data_list <- function(model_type, + forecast_date, + forecast_time, + input_ww_data, + input_hosp_data, + generation_interval, + inf_to_hosp, + infection_feedback_pmf, + params, + compute_likelihood = 1, + ww_outlier_col_name = "flag_as_ww_outlier", + lod_col_name = "below_LOD", + ww_measurement_col_name = "ww", + ww_value_lod_col_name = "lod_sewage", + hosp_value_col_name = "daily_hosp_admits") { + # Assign parameter names + par_names <- colnames(params) + for (i in seq_along(par_names)) { + assign(par_names[i], as.double(params[i])) + } + + # Indicator variable whether or not to include ww in likelihood + include_ww <- ifelse(model_type == "ww", 1, 0) + + last_hosp_data_date <- get_last_hosp_data_date(input_hosp_data) + + # Get state pop + pop <- input_hosp_data |> + dplyr::select(pop) |> + unique() |> + dplyr::pull(pop) + + stopifnot( + "More than one population size in training data" = + length(pop) == 1 + ) + + + if (include_ww == 1) { + # Test for presence of column names + stopifnot( + "Outlier column name isn't present in input dataset" = + ww_outlier_col_name %in% colnames(input_ww_data) + ) + + # Filter out wastewater outliers and arrange data for indexing + ww_data <- input_ww_data |> + dplyr::filter({{ ww_outlier_col_name }} != 1) |> + dplyr::arrange(date, site_index) + + ww_data_sizes <- get_ww_data_sizes(ww_data, lod_col_name) + ww_indices <- get_ww_data_indices(ww_data, input_hosp_data, + owt = ww_data_sizes$owt, + lod_col_name + ) + ww_values <- get_ww_values( + ww_data, ww_measurement_col_name, + ww_value_lod_col_name + ) + + stopifnot( + "Wastewater sampled times not equal to length of input ww data" = + length(ww_indices$ww_sampled_times) == ww_data_sizes$owt + ) + + + message("Prop of population size covered by wastewater: ", sum(ww_values$pop_ww) / pop) + + # Logic to determine the number of subpopulations to estimate R(t) for: + # First determine if we need to add an additional subpopulation + add_auxiliary_site <- ifelse(pop >= sum(ww_values$pop_ww), TRUE, FALSE) + # Then get the number of subpopulations, the population to normalize by + # (sum of the subpopulations), and the vector of sizes of each subpopulation + subpop_data <- get_subpop_data(add_auxiliary_site, + state_pop = pop, + pop_ww = ww_values$pop_ww, + n_ww_sites = ww_data_sizes$n_ww_sites + ) + } else { # Hospital admissions only model) + # Still need to specify wastewater input data, so set as 0s. Won't get + # used by stan to compute the likelihood. None of these will be used. + owt <- 1 + ww_sampled_times <- c(1) + log_conc <- c(1) + } + + # Get the remaining things needed for both models + hosp_data <- add_time_indexing(input_hosp_data) + hosp_data_sizes <- get_hosp_data_sizes( + hosp_data, + forecast_date, + forecast_time, + last_hosp_data_date, + uot, + hosp_value_col_name + ) + hosp_indices <- get_hosp_indices(hosp_data) + hosp_values <- get_hosp_values( + hosp_data, + ot = hosp_data_sizes$ot, + ht = hosp_data_sizes$ht, + hosp_value_col_name + ) + + if (include_ww == 1) { + message("Removed ", nrow(input_ww_data) - ww_data_sizes$owt, " outliers from WW data") + } + + # matrix to transform IHR from weekly to daily + ind_m <- get_ind_m( + hosp_data_sizes$ot + hosp_data_sizes$ht, + hosp_data_sizes$n_weeks + ) + # matrix to transform p_hosp RW from weekly to daily + p_hosp_m <- get_ind_m( + uot + hosp_data_sizes$ot + hosp_data_sizes$ht, + hosp_data_sizes$tot_weeks + ) + + # Estimate of number of initial infections + i0 <- mean(hosp_values$hosp_admits[1:7], na.rm = TRUE) / p_hosp_mean + + # package up parameters for stan data object + viral_shedding_pars <- c( + t_peak_mean, t_peak_sd, viral_peak_mean, viral_peak_sd, + duration_shedding_mean, duration_shedding_sd + ) + + hosp_delay_max <- length(inf_to_hosp) + + if (model_type == "ww") { + data_renewal <- list( + gt_max = gt_max, + hosp_delay_max = hosp_delay_max, + inf_to_hosp = inf_to_hosp, + dur_inf = dur_inf, + mwpd = ml_of_ww_per_person_day, + ot = hosp_data_sizes$ot, + n_subpops = subpop_data$n_subpops, + n_ww_sites = ww_data_sizes$n_ww_sites, + n_ww_lab_sites = ww_data_sizes$n_ww_lab_sites, + owt = ww_data_sizes$owt, + oht = hosp_data_sizes$oht, + n_censored = ww_data_sizes$n_censored, + n_uncensored = ww_data_sizes$n_uncensored, + uot = uot, + ht = hosp_data_sizes$ht, + n_weeks = hosp_data_sizes$n_weeks, + ind_m = ind_m, + tot_weeks = hosp_data_sizes$tot_weeks, + p_hosp_m = p_hosp_m, + generation_interval = generation_interval, + ts = 1:gt_max, + state_pop = pop, + subpop_size = subpop_data$subpop_size, + norm_pop = subpop_data$norm_pop, + ww_sampled_times = ww_indices$ww_sampled_times, + hosp_times = hosp_indices$hosp_times, + ww_sampled_lab_sites = ww_indices$ww_sampled_lab_sites, + ww_log_lod = ww_values$ww_lod, + ww_censored = ww_indices$ww_censored, + ww_uncensored = ww_indices$ww_uncensored, + hosp = hosp_values$hosp_admits, + day_of_week = hosp_values$day_of_week, + log_conc = ww_values$log_conc, + compute_likelihood = compute_likelihood, + include_ww = include_ww, + include_hosp = 1, + if_l = length(infection_feedback_pmf), + infection_feedback_pmf = infection_feedback_pmf, + # All the priors! + viral_shedding_pars = viral_shedding_pars, # tpeak, viral peak, dur_shed + autoreg_rt_a = autoreg_rt_a, + autoreg_rt_b = autoreg_rt_b, + autoreg_p_hosp_a = autoreg_p_hosp_a, + autoreg_p_hosp_b = autoreg_p_hosp_b, + inv_sqrt_phi_prior_mean = inv_sqrt_phi_prior_mean, + inv_sqrt_phi_prior_sd = inv_sqrt_phi_prior_sd, + r_prior_mean = r_prior_mean, + r_prior_sd = r_prior_sd, + log10_g_prior_mean = log10_g_prior_mean, + log10_g_prior_sd = log10_g_prior_sd, + i0_over_n_prior_a = 1 + i0_certainty * (i0 / pop), + i0_over_n_prior_b = 1 + i0_certainty * (1 - (i0 / pop)), + wday_effect_prior_mean = wday_effect_prior_mean, + wday_effect_prior_sd = wday_effect_prior_sd, + initial_growth_prior_mean = initial_growth_prior_mean, + initial_growth_prior_sd = initial_growth_prior_sd, + sigma_ww_site_prior_mean_mean = sigma_ww_site_prior_mean_mean, + sigma_ww_site_prior_mean_sd = sigma_ww_site_prior_mean_sd, + sigma_ww_site_prior_sd_mean = sigma_ww_site_prior_sd_mean, + sigma_ww_site_prior_sd_sd = sigma_ww_site_prior_sd_sd, + eta_sd_sd = eta_sd_sd, + sigma_i0_prior_mode = sigma_i0_prior_mode, + sigma_i0_prior_sd = sigma_i0_prior_sd, + p_hosp_prior_mean = p_hosp_mean, + p_hosp_sd_logit = p_hosp_sd_logit, + p_hosp_w_sd_sd = p_hosp_w_sd_sd, + ww_site_mod_sd_sd = ww_site_mod_sd_sd, + inf_feedback_prior_logmean = infection_feedback_prior_logmean, + inf_feedback_prior_logsd = infection_feedback_prior_logsd, + sigma_rt_prior = sigma_rt_prior, + log_phi_g_prior_mean = log_phi_g_prior_mean, + log_phi_g_prior_sd = log_phi_g_prior_sd, + ww_sampled_sites = ww_indices$ww_sampled_sites, + lab_site_to_site_map = ww_indices$lab_site_to_site_map + ) + } else if (model_type == "hosp") { + data_renewal <- list( + gt_max = gt_max, + hosp_delay_max = hosp_delay_max, + inf_to_hosp = inf_to_hosp, + dur_inf = dur_inf, # this is used bc drift approach needs currently infected + mwpd = ml_of_ww_per_person_day, + ot = hosp_data_sizes$ot, + owt = owt, + oht = hosp_data_sizes$oht, + uot = uot, + ht = hosp_data_sizes$ht, + n_weeks = hosp_data_sizes$n_weeks, + ind_m = ind_m, + tot_weeks = hosp_data_sizes$tot_weeks, + p_hosp_m = p_hosp_m, + generation_interval = generation_interval, + ts = 1:gt_max, + n = pop, + hosp_times = hosp_indices$hosp_times, + ww_sampled_times = ww_sampled_times, + hosp = hosp_values$hosp_admits, + day_of_week = hosp_values$day_of_week, + log_conc = log_conc, + compute_likelihood = compute_likelihood, + include_ww = include_ww, + include_hosp = 1, + if_l = length(infection_feedback_pmf), + infection_feedback_pmf = infection_feedback_pmf, + # Priors + viral_shedding_pars = viral_shedding_pars, # tpeak, viral peak, + # duration shedding + autoreg_rt_a = autoreg_rt_a, + autoreg_rt_b = autoreg_rt_b, + autoreg_p_hosp_a = autoreg_p_hosp_a, + autoreg_p_hosp_b = autoreg_p_hosp_b, + inv_sqrt_phi_prior_mean = inv_sqrt_phi_prior_mean, + inv_sqrt_phi_prior_sd = inv_sqrt_phi_prior_sd, + r_prior_mean = r_prior_mean, + r_prior_sd = r_prior_sd, + log10_g_prior_mean = log10_g_prior_mean, + log10_g_prior_sd = log10_g_prior_sd, + i0_over_n_prior_a = 1 + i0_certainty * (i0 / pop), + i0_over_n_prior_b = 1 + i0_certainty * (1 - (i0 / pop)), + wday_effect_prior_mean = wday_effect_prior_mean, + wday_effect_prior_sd = wday_effect_prior_sd, + initial_growth_prior_mean = initial_growth_prior_mean, + initial_growth_prior_sd = initial_growth_prior_sd, + sigma_ww_prior_mean = sigma_ww_site_prior_mean_mean, + eta_sd_sd = eta_sd_sd, + p_hosp_prior_mean = p_hosp_mean, + p_hosp_sd_logit = p_hosp_sd_logit, + p_hosp_w_sd_sd = p_hosp_w_sd_sd, + inf_feedback_prior_logmean = infection_feedback_prior_logmean, + inf_feedback_prior_logsd = infection_feedback_prior_logsd + ) + } else { + cli::cli_abort("Unknown model") + data_renewal <- list() + } + + stopifnot("Model type not specified properly" = !purrr::is_empty(data_renewal)) + + return(data_renewal) +} + + + +get_inits <- function(model_type, stan_data, params, + n_chains) { + # Assign parmeter names + par_names <- colnames(params) + for (i in seq_along(par_names)) { + assign(par_names[i], as.double(params[i])) + } + + pop <- ifelse(model_type == "ww", stan_data$state_pop, stan_data$n) + + n_weeks <- as.numeric(stan_data$n_weeks) + tot_weeks <- as.numeric(stan_data$tot_weeks) + ot <- as.numeric(stan_data$ot) + ht <- as.numeric(stan_data$ht) + + # Estimate of number of initial infections + i0 <- mean(stan_data$hosp[1:7], na.rm = TRUE) / p_hosp_mean + + if (model_type == "ww") { + n_subpops <- as.numeric(stan_data$n_subpops) + n_ww_lab_sites <- as.numeric(stan_data$n_ww_lab_sites) + get_init <- function() { + init_list <- list( + w = stats::rnorm(n_weeks - 1, 0, 0.01), + eta_sd = abs(stats::rnorm(1, 0, 0.01)), + eta_i0 = abs(stats::rnorm(n_subpops, 0, 0.01)), + sigma_i0 = abs(stats::rnorm(1, 0, 0.01)), + eta_growth = abs(stats::rnorm(n_subpops, 0, 0.01)), + sigma_growth = abs(stats::rnorm(1, 0, 0.01)), + autoreg_rt = abs(stats::rnorm(1, autoreg_rt_a / (autoreg_rt_a + autoreg_rt_b), 0.05)), + log_r_mu_intercept = stats::rnorm(1, convert_to_logmean(1, 0.1), convert_to_logsd(1, 0.1)), + error_site = matrix( + stats::rnorm(n_subpops * n_weeks, + mean = 0, + sd = 0.1 + ), + n_subpops, + n_weeks + ), + autoreg_rt_site = abs(stats::rnorm(1, 0.5, 0.05)), + autoreg_p_hosp = abs(stats::rnorm(1, 1 / 100, 0.001)), + sigma_rt = abs(stats::rnorm(1, 0, 0.01)), + i0_over_n = stats::plogis(stats::rnorm(1, stats::qlogis(i0 / pop), 0.05)), + initial_growth = stats::rnorm(1, 0, 0.001), + inv_sqrt_phi_h = 1 / sqrt(200) + stats::rnorm(1, 1 / 10000, 1 / 10000), + sigma_ww_site_mean = abs(stats::rnorm( + 1, sigma_ww_site_prior_mean_mean, + 0.1 * sigma_ww_site_prior_mean_sd + )), + sigma_ww_site_sd = abs(stats::rnorm( + 1, sigma_ww_site_prior_sd_mean, + 0.1 * sigma_ww_site_prior_sd_sd + )), + sigma_ww_site_raw = abs(stats::rnorm(n_ww_lab_sites, 0, 0.05)), + p_hosp_mean = stats::rnorm(1, stats::qlogis(p_hosp_mean), 0.01), + p_hosp_w = stats::rnorm(tot_weeks, 0, 0.01), + p_hosp_w_sd = abs(stats::rnorm(1, 0.01, 0.001)), + t_peak = stats::rnorm(1, t_peak_mean, 0.1 * t_peak_sd), + viral_peak = stats::rnorm(1, viral_peak_mean, 0.1 * viral_peak_sd), + dur_shed = stats::rnorm(1, duration_shedding_mean, 0.1 * duration_shedding_sd), + log10_g = stats::rnorm(1, log10_g_prior_mean, 0.5), + ww_site_mod_raw = abs(stats::rnorm(n_ww_lab_sites, 0, 0.05)), + ww_site_mod_sd = abs(stats::rnorm(1, 0, 0.05)), + hosp_wday_effect = to_simplex(stats::rnorm(7, 1 / 7, 0.01)), + infection_feedback = abs(stats::rnorm(1, 500, 20)) + ) + return(init_list) + } + } else if (model_type == "hosp") { + get_init <- function() { + init_list <- list( + w = stats::rnorm(n_weeks - 1, 0, 0.01), + eta_sd = abs(stats::rnorm(1, 0, 0.01)), + autoreg_rt = abs(stats::rnorm(1, autoreg_rt_a / (autoreg_rt_a + autoreg_rt_b), 0.05)), + log_r = stats::rnorm(1, convert_to_logmean(1, 0.1), convert_to_logsd(1, 0.1)), + i0_over_n = stats::plogis(stats::rnorm(1, stats::qlogis(i0 / pop), 0.05)), + initial_growth = stats::rnorm(1, 0, 0.001), + inv_sqrt_phi_h = 1 / (sqrt(200)) + stats::rnorm(1, 1 / 10000, 1 / 10000), + sigma_ww = abs(stats::rnorm(1, 0, 0.5)), + p_hosp_mean = stats::rnorm(1, stats::qlogis(p_hosp_mean), 0.01), + autoreg_p_hosp = rep(0, 0), + p_hosp_w = rep(0, 0), + p_hosp_w_sd = rep(0, 0), + t_peak = stats::rnorm(1, t_peak_mean, 0.1 * t_peak_sd), + viral_peak = stats::rnorm(1, viral_peak_mean, 0.1 * viral_peak_sd), + dur_shed = stats::rnorm(1, duration_shedding_mean, 0.1 * duration_shedding_sd), + log10_g = stats::rnorm(1, log10_g_prior_mean, 0.5), + hosp_wday_effect = to_simplex(stats::rnorm(7, 1 / 7, 0.01)), + infection_feedback = abs(stats::rnorm(1, 500, 20)) + ) + return(init_list) + } + } else { + message("model type specified incorrectly") + } + + init_lists <- c() + for (i in 1:n_chains) { # Run for-loop over lists + init_lists[[i]] <- get_init() + } + + return(init_lists) +} + + +#' Get the integer sizes of the wastewater input data +#' +#' @param ww_data Input wastewater dataframe containing one row +#' per observation, with outliers already removed +#' @param lod_col_name A string representing the name of the +#' column in the input_ww_data that provides a 0 if the data point is not above +#' the LOD and a 1 if the data is below the LOD, default value is `below_LOD` +#' +#' @return A list containing the integer sizes of the follow variables that +#' the stan model requires: +#' owt: number of wastewater observations +#' n_censored: number of censored wastewater observations (below the LOD) +#' n_uncensored: number of uncensored wastewter observations (above the LOD) +#' n_ww_sites: number of wastewater sites +#' n_ww_lab_sites: number of unique wastewater site-lab combinations +#' +#' @export +get_ww_data_sizes <- function(ww_data, + lod_col_name = "below_LOD") { + # Test for presence of column names + stopifnot( + "LOD column name isn't present in input dataset" = + lod_col_name %in% colnames(ww_data) + ) + + # Number of wastewater observations + owt <- nrow(ww_data) + # Number of censored wastewater observations + n_censored <- sum(ww_data[lod_col_name] == 1) + # Number of uncensored wastewater observations + n_uncensored <- owt - n_censored + + # Number of ww sites + n_ww_sites <- dplyr::n_distinct(ww_data$site_index) + + # Number of unique combinations of wastewater sites and labs + n_ww_lab_sites <- dplyr::n_distinct(ww_data$lab_site_index) + + data_sizes <- list( + owt = owt, + n_censored = n_censored, + n_uncensored = n_uncensored, + n_ww_sites = n_ww_sites, + n_ww_lab_sites = n_ww_lab_sites + ) + + return(data_sizes) +} + +#' Get wastewater data indices +#' +#' @param ww_data Input wastewater dataframe containing one row +#' per observation, with outliers already removed +#' @param input_hosp_data Input hospital admissions data frame with one row +#' per day and location +#' @param owt number of wastewater observations +#' @param lod_col_name A string representing the name of the +#' column in the input_ww_data that provides a 0 if the data point is not above +#' the LOD and a 1 if the data is below the LOD, default value is `below_LOD` +#' +#' @return A list containing the necessary vectors of indices that +#' the stan model requires: +#' ww_censored: the vector of time points that the wastewater observations are +#' censored (below the LOD) in order of the date and the site index +#' ww_uncensored: the vector of time points that the wastewater observations are +#' uncensored (above the LOD) in order of the date and the site index +#' ww_sampled_times: the vector of time points that the wastewater observations +#' are passed in in log_conc in order of the date and the site index +#' ww_sampled_sites: the vector of sites that correspond to the observations +#' passed in in log_conc in order of the date and the site index +#' ww_sampled_lab_sites: the vector of unique combinations of site and labs +#' that correspond to the observations passed in in log_conc in order of the +#' date and the site index +#' lab_site_to_site_map: the vector of sites that correspond to each lab-site +#' @export +get_ww_data_indices <- function(ww_data, + input_hosp_data, + owt, + lod_col_name = "below_LOD") { + # Vector of indices along the list of wastewater concentrations that + # correspond to censored observations + ww_data_with_index <- ww_data |> + dplyr::mutate(ind_rel_to_sampled_times = dplyr::row_number()) + ww_censored <- ww_data_with_index |> + dplyr::filter(.data[[lod_col_name]] == 1) |> + dplyr::pull(ind_rel_to_sampled_times) + ww_uncensored <- ww_data_with_index |> + dplyr::filter(.data[[lod_col_name]] == 0) |> + dplyr::pull(ind_rel_to_sampled_times) + stopifnot( + "Length of censored vectors incorrect" = + length(ww_censored) + length(ww_uncensored) == owt + ) + + + # Need to get the times of wastewater sampling, starting at the first + # day of hospital admissions data + ww_date_df <- data.frame( + date = seq( + from = min(input_hosp_data$date), + to = max(ww_data$date), + by = "days" + ), + t = 1:(as.integer(max(ww_data$date) - min(input_hosp_data$date)) + 1) + ) + + # Left join the data mapped to time to the wastewater data + spine_ww <- ww_data |> + dplyr::left_join(ww_date_df, by = "date") + + # Pull just the vector of times of wastewater observations + ww_sampled_times <- spine_ww |> + dplyr::pull(t) + + # Pull just the indexes of the sites that correspond to the vector of + # sampled times + ww_sampled_sites <- ww_data$site_index + + # Pull just the indexes of the lab-sites that correspond to the vector of + # sampled times + ww_sampled_lab_sites <- ww_data$lab_site_index + + # Need a vector of indices indicating the site for each lab-site + lab_site_to_site_map <- ww_data |> + dplyr::select(lab_site_index, site_index) |> + dplyr::arrange(lab_site_index, "desc") |> + dplyr::distinct() |> + dplyr::pull(site_index) + + ww_data_indices <- list( + ww_censored = ww_censored, + ww_uncensored = ww_uncensored, + ww_sampled_times = ww_sampled_times, + ww_sampled_sites = ww_sampled_sites, + ww_sampled_lab_sites = ww_sampled_lab_sites, + lab_site_to_site_map = lab_site_to_site_map + ) + + return(ww_data_indices) +} + +#' Get wastewater data values +#' +#' @param ww_data Input wastewater dataframe containing one row +#' per observation, with outliers already removed +#' @param ww_measurement_col_name A string representing the name of the column +#' in the input_ww_data that indicates the wastewater measurement value in +#' natural scale, default is `ww` +#' @param ww_lod_value_col_name A string representing the name of the column +#' in the ww_data that indicates the value of the LOD in natural scale, +#' default is `lod_sewage` +#' @param ww_site_pop_col_name A string representing the name of the column in +#' the ww_data that indicates the number of people represented by that wastewater +#' catchment +#' @param one_pop_per_site a boolean variable indicating if there should only +#' be on catchment area population per site, default is `TRUE` bc this is what +#' the stan model expects +#' +#' @return A list containing the necessary vectors of values that +#' the stan model requires: +#' ww_lod: a vector of the LODs of the corresponding wastewater measurement +#' pop_ww: a vector of the population sizes of the wastewater catchment areas +#' in order of the sites by site_index +#' log_conc: a vector of the log of the wastewater concentration observation +#' @export +get_ww_values <- function(ww_data, + ww_measurement_col_name = "ww", + ww_lod_value_col_name = "lod_sewage", + ww_site_pop_col_name = "ww_pop", + one_pop_per_site = TRUE) { + # Get the vector of log LOD values corresponding to each observation + ww_lod <- ww_data |> + dplyr::pull({{ ww_lod_value_col_name }}) |> + log() + + # Get a vector of population sizes + if (isTRUE(one_pop_per_site)) { + # Want one population per site during the model calibration period, + # so just take the average across the populations reported for each observation + pop_ww <- ww_data |> + dplyr::select(site_index, {{ ww_site_pop_col_name }}) |> + dplyr::group_by(site_index) |> + dplyr::summarise(pop_avg = mean(.data[[ww_site_pop_col_name]])) |> + dplyr::arrange(site_index, "desc") |> + dplyr::pull(pop_avg) + } else { + # Want a vector of length of the number of observations, corresponding to + # the population at that time + pop_ww <- ww_data |> + dplyr::pull({{ ww_site_pop_col_name }}) + } + + + # Get the vector of log wastewater concentrations + log_conc <- ww_data |> + dplyr::mutate(log_conc = as.numeric(log(!!sym(ww_measurement_col_name) + 1e-8))) |> + dplyr::pull(log_conc) + + ww_values <- list( + ww_lod = ww_lod, + pop_ww = pop_ww, + log_conc = log_conc + ) + + return(ww_values) +} + +#' Add time indexing to hospital admissions data +#' +#' @param input_hosp_data data frame with dates and admissions, +#' but without time indexing. +#' +#' @return The same data frame, with an added +#' time index, including NA rows if dates internal +#' to the timeseries are missing admissions data. +#' @export +#' +#' @examples +#' hosp_data_example <- tibble::tibble( +#' date = lubridate::ymd("2024-01-01", "2024-01-02", "2024-01-06"), +#' daily_hosp_admits = c(5, 3, 8) +#' ) +#' hosp_data_w_t <- add_time_indexing(hosp_data_example) +add_time_indexing <- function(input_hosp_data) { + date_df <- tibble::tibble(date = seq( + from = min(input_hosp_data$date), + to = max(input_hosp_data$date), + by = "days" + )) |> + dplyr::mutate(t = dplyr::row_number()) + + hosp_data <- input_hosp_data |> + dplyr::left_join(date_df, by = "date") |> + dplyr::arrange(date) + + return(hosp_data) +} + +#' Get subpopulation data +#' +#' @param add_auxiliary_site Boolean indicating whether to add another +#' subpopulation in addition to the wastewater sites to estimate R(t) of +#' @param state_pop The state population size +#' @param pop_ww The population size in each of the wastewater sites +#' @param n_ww_sites The number of wastewater sites +#' +#' @return A list containing the necessary integers and vectors that stan +#' needs to estiamte infection dynamics for each subpopulation +#' @export +#' +#' @examples subpop_data <- get_subpop_data(TRUE, 100000, c(1000, 500), 2) +get_subpop_data <- function(add_auxiliary_site, + state_pop, + pop_ww, + n_ww_sites) { + if (add_auxiliary_site) { + # In most cases, wastewater catchment coverage < entire state. + # So here we add a subpopulation that represents the population not + # covered by wastewater surveillance + norm_pop <- state_pop + n_subpops <- n_ww_sites + 1 + subpop_size <- c(pop_ww, state_pop - sum(pop_ww)) + } else { + message("Sum of wastewater catchment areas is greater than state pop") + norm_pop <- sum(pop_ww) + # If sum catchment areas > state pop, + # use sum of catchment area pop to normalize + n_subpops <- n_ww_sites # Only divide the state into n_site subpops + subpop_size <- pop_ww + } + + subpop_data <- list( + norm_pop = norm_pop, + n_subpops = n_subpops, + subpop_size = subpop_size + ) + return(subpop_data) +} + +#' Get hospital data integer sizes for stan +#' +#' @param input_hosp_data a dataframe with the input hospital admissions data +#' @param forecast_date string indicating the forecast date +#' @param forecast_time integer indicating the number of days to make a forecast +#' for +#' @param last_hosp_data_date string indicating the date of the last observed +#' hospital admission +#' @param uot integer indicating the time of model initialization when there are +#' no observations +#' @param hosp_value_col_name A string represeting the name of the column in the +#' input_hosp-data that indicates the number of daily hospital admissions, +#' default is `daily_hosp_admits` +#' +#' @return A list containing the integer sizes of the follow variables that +#' the stan model requires: +#' ht: integer indicating horizon time for the model(hospital admissions +#' nowcast + forecast time in days) +#' ot: integer indicating the total duration of time that the hospital admissions +#' model has available calibration data +#' oht: integer indicating the number of hospital admission observations +#' n_weeks: number of weeks (rounded up) that hospital admissions are generated +#' from the model +#' tot_weeks: number of week(rounded up) that infections are generated for +#' @export +get_hosp_data_sizes <- function(input_hosp_data, + forecast_date, + forecast_time, + last_hosp_data_date, + uot, + hosp_value_col_name = "daily_hosp_admits") { + nowcast_time <- as.integer(lubridate::ymd(forecast_date) - lubridate::ymd(last_hosp_data_date)) + ht <- nowcast_time + forecast_time + ot <- nrow(input_hosp_data) + oht <- input_hosp_data |> + dplyr::filter(!is.na(.data[[hosp_value_col_name]])) |> + nrow() + n_weeks <- ceiling((ot + ht) / 7) + tot_weeks <- ceiling((ot + uot + ht) / 7) + hosp_data_sizes <- list( + ht = ht, + ot = ot, + oht = oht, + n_weeks = n_weeks, + tot_weeks = tot_weeks + ) + return(hosp_data_sizes) +} +#' Get hospital admissions indices +#' +#' @param input_hosp_data a dataframe with the input hospital admissions data +#' +#' @return A list containing the vectors of indices that +#' the stan model requires: +#' hosp_times: a vector of integer times corresponding to the times when the +#' hospital admissions observations were made +#' @export +get_hosp_indices <- function(input_hosp_data) { + hosp_times <- input_hosp_data |> + dplyr::pull(t) + + hosp_indices <- list( + hosp_times = hosp_times + ) + return(hosp_indices) +} + +#' Get hospital admissions values +#' +#' @param input_hosp_data a dataframe with the input hospital admissions data +#' @param ot integer indicating the total duration of time that the hospital admissions +#' model has available calibration data in days +#' @param ht integer indicating the number of days to produce hospital admissions +#' outside the calibration period (forecast + nowcast time) in days +#' @param hosp_value_col_name A string represeting the name of the column in the +#' input_hosp-data that indicates the number of daily hospital admissions, +#' default is `daily_hosp_admits` +#' +#' @return A list containing the necessary vectors of values that +#' the stan model requires: +#' hosp_admits: a vector of number of daily hospital admissions observations +#' day_of_week: a vector indicating the day of the week of each of the dates +#' in the calibration and forecast period +# +#' @export +get_hosp_values <- function(input_hosp_data, + ot, + ht, + hosp_value_col_name = "daily_hosp_admits") { + hosp_admits <- input_hosp_data |> + dplyr::pull({{ hosp_value_col_name }}) + + full_dates <- seq( + from = min(input_hosp_data$date), + to = min(input_hosp_data$date) + lubridate::days(ht + ot - 1), + by = "days" + ) + day_of_week <- lubridate::wday(full_dates, week_start = 1) + + hosp_values <- list( + hosp_admits = hosp_admits, + day_of_week = day_of_week + ) + return(hosp_values) +} diff --git a/wweval/R/get_input_data.R b/wweval/R/get_input_data.R new file mode 100644 index 00000000..dc9a74aa --- /dev/null +++ b/wweval/R/get_input_data.R @@ -0,0 +1,264 @@ +#' Get input wastewater data +#' +#' @param forecast_date_i The forecast date for this iteration, +#' formatted as a character string in IS08601 format (YYYY-MM-DD). +#' @param location_i The location (state or other jurisdiction) +#' for this iteration, formatted as a string (uppercase USPS two-letter +#' abbreviation, e.g. AK for Alaska, DC for the District of Columbia, +#' PR for Puerto Rico). +#' @param scenario_i The scenario for this iteration, formatted as a +#' string +#' @param scenario_dir A string indicating the path to the directory +#' containing the csvs with the wwtp ids needed for each scenario +#' @param ww_data_dir A string indicating the path to the directory +#' containing time stamped wastewater datasets +#' @param calibration_time The duration of the model calibration period +#' (relative to the last hospital admissions data point) in units of +#' model timesteps (typically days). +#' @param last_hosp_data_date A date indicating the date of last reported +#' hospital admission as of the forecast date +#' @param ww_data_mapping A string indicating how to map the +#' forecast date to the wastewater dates (see [date_of_ww_data()] +#' for more details) +#' +#' @return a dataframe containing the transformed and clean NWSS data +#' at the site and lab label for the forecast date and location specified +#' @export +get_input_ww_data <- function(forecast_date_i, + location_i, + scenario_i, + scenario_dir, + ww_data_dir, + calibration_time, + last_hosp_data_date, + ww_data_mapping) { + # Load in the appropriate time-stamped NWSS dataset. This depends on + # the date `ww_data_mapping` which is a string that we will specify + # in the config + date_to_pull <- date_of_ww_data( + forecast_date_i, ww_data_mapping, + ww_data_dir + ) + + ww_data_path <- file.path(ww_data_dir, paste0(date_to_pull, ".csv")) + raw_nwss_data <- readr::read_csv(ww_data_path, show_col_types = FALSE) + + # Use package functions to subset NWSS data + ww_data <- raw_nwss_data |> + init_subset_nwss_data() + # Get the data corresponding to the scenario + subsetted_ww_data <- filter_sites_by_scenario( + ww_data, scenario_i, + scenario_dir + ) + ww <- subsetted_ww_data |> + clean_ww_data() |> + filter( + location %in% c(!!location_i), + date >= lubridate::ymd(!!last_hosp_data_date) - + lubridate::days(!!calibration_time) + lubridate::days(1) + ) + + # Get extra columns that identify wastewater outliers + ww_w_outliers <- flag_ww_outliers(ww) |> + select( + date, location, ww, site, lab, lab_wwtp_unique_id, + ww_pop, below_LOD, lod_sewage, flag_as_ww_outlier + ) + # If more than one location, than this data isn't being used for fitting + # And we don't wanto generate these + if (length(location_i) == 1) { + site_map <- ww_w_outliers |> + distinct(site) |> + mutate(site_index = row_number()) + site_lab_map <- ww_w_outliers |> + distinct(lab_wwtp_unique_id) |> + mutate(lab_site_index = row_number()) + + ww <- ww_w_outliers |> + left_join(site_map, by = "site") |> + left_join(site_lab_map, by = "lab_wwtp_unique_id") + } else { + ww <- ww_w_outliers + } + + + return(ww) +} + +#' Filter sites by scenario +#' +#' @param init_subset_nwss_data a dataframe of the raw NWSS data +#' filtered to exclude solids and upstream sites +#' @param scenario a string indicating what scenario we are running. +#' Default is "Status quo" which uses all the data we have available +#' @param scenario_dir a string indicating the file path where the +#' scenario csvs will live. Default is NA because we don't need this +#' for the status quo scenario +#' +#' @return a dataframe that only contains the ww data from +#' the sites in the list pertaining to the scenario +#' @export +filter_sites_by_scenario <- function(init_subset_nwss_data, + scenario = "Status quo", + scenario_dir = NA) { + list_of_wwtp_ids <- if (scenario == "Status quo") { + list_of_wwtp_ids <- unique(init_subset_nwss_data$wwtp_name) + } else { + list_of_wwtp_ids <- utils::read.csv(file.path( + scenario_dir, + glue::glue("{scenario}.csv") + )) |> + pull(wwtp_name) + } + + filtered_nwss_data <- init_subset_nwss_data |> + dplyr::filter(wwtp_name %in% !!list_of_wwtp_ids) + + return(filtered_nwss_data) +} + +#' Get input hospital admissions data +#' +#' @param forecast_date_i The forecast date for this iteration +#' @param location_i The location (state) for this iteration +#' @param hosp_data_dir A string indicating the path to the directory containing +#' time stamped hospital admissions datasets +#' @param calibration_time A numeric indicating the duration of model +#' calibration (based on the last hospital admissions data point) +#' @param load_from_covidcast boolean indicating whether or not the hospital +#' admissions datasets should be loaded directly from covidcast. +#' `default = FALSE` because we are assuming that we have already created a +#' folder with time stamped datasets +#' +#' @return a dataframe containing the cleaned hospital admissions needed as +#' an input to the stan model for the specified forecast date and location +#' @export +get_input_hosp_data <- function(forecast_date_i, location_i, + hosp_data_dir, calibration_time, + load_from_covidcast = FALSE) { + fp <- file.path(hosp_data_dir, paste0(forecast_date_i, ".csv")) + + # Load in the appropriate time-stamped hospital admissions dataset + if (isTRUE(load_from_covidcast)) { + hosp_raw <- quiet(covidcast::covidcast_signal( + "hhs", "confirmed_admissions_covid_1d", + geo_type = "state", + geo_values = "*", + as_of = forecast_date_i + )) + + hosp <- hosp_raw |> + as_tibble() |> + mutate(abbreviation = toupper(geo_value)) |> + left_join(state_population_table, by = "abbreviation") |> + rename( + date = time_value, + daily_hosp_admits = value, + pop = population + ) |> + select(date, ABBR = abbreviation, daily_hosp_admits, pop) + message("Writing full time stamped dataset to local storage") + + readr::write_csv(hosp, fp) + } else { + hosp <- readr::read_csv(fp) + } + last_hosp_data_date <- max(hosp$date, na.rm = TRUE) + input_hosp <- hosp |> + rename(location = ABBR) |> + filter( + location %in% c(!!location_i), + date >= ( + ymd(!!last_hosp_data_date) - + lubridate::days(!!calibration_time) + + lubridate::days(1) + ) + ) + return(input_hosp) +} + +#' Date of wastewater data +#' +#' @param forecast_date the forecast date for this iteration +#' @param ww_data_mapping a string that tells this function how to pick +#' data pull dates from forecast dates. This function needs to be configured +#' for each new string. +#' @param ww_data_dir A string indicating the path to the directory containing +#' time stamped wastewater datasets +#' +#' @return the date to get the ww data from +#' @export +date_of_ww_data <- function(forecast_date, ww_data_mapping, + ww_data_dir) { + if (is.null(ww_data_mapping)) { + dates <- gsub(".{4}$", "", list.files(ww_data_dir)) + # Get the nearest date less than the forecast date + date_to_pull <- as.character(max(dates[dates < ymd(forecast_date)], na.rm = TRUE)) + } else if (ww_data_mapping == "Monday: Monday, Wednesday: Monday") { + # Error if mapping is Monday to Wednesday and forecast date is neither + stopifnot( + "Forecast date is not a Monday or Wednesday" = + lubridate::wday(forecast_date) == 2 || lubridate::wday(forecast_date) == 4 + ) + + if (lubridate::wday(forecast_date) == 2) { + date_to_pull <- as.character(ymd(forecast_date)) + } else if (lubridate::wday(forecast_date) == 4) { + date_to_pull <- as.character( + ymd(forecast_date) - lubridate::days(2) + ) + } + } else { # Anything else right now we don't have algorithm written for, + # so prompt + date_to_pull <- NA + } + + stopifnot( + "Need to write case to specify which wastewater data to pull" = + !is.na(date_to_pull) + ) + + return(date_to_pull) +} + +#' Get last hospital admissions data point date +#' +#' @param input_hosp the hospital admissions dataset for that location and +#' forecast date +#' +#' @return a date indicating the last day of observed data +#' @export +get_last_hosp_data_date <- function(input_hosp) { + last_hosp_data_date <- max(input_hosp$date, na.rm = TRUE) + return(last_hosp_data_date) +} + + +#' Clean wastewater data +#' +#' @param nwss_subset the raw nwss data filtered down to only the columns we use +#' +#' @return A site-lab level dataset with names and variables that can be used +#' for model fitting +#' @export +clean_ww_data <- function(nwss_subset) { + ww_data <- nwss_subset |> + ungroup() |> + rename( + date = sample_collect_date, + ww = pcr_target_avg_conc, + ww_pop = population_served + ) |> + mutate( + location = toupper(wwtp_jurisdiction), + site = wwtp_name, + lab = lab_id + ) |> + select( + date, location, ww, site, lab, lab_wwtp_unique_id, ww_pop, + below_LOD, lod_sewage + ) + + return(ww_data) +} diff --git a/wweval/R/plots.R b/wweval/R/plots.R new file mode 100644 index 00000000..6528ed1e --- /dev/null +++ b/wweval/R/plots.R @@ -0,0 +1,305 @@ +#' Get plot of wastewater data compared to model draws +#' +#' @param draws_w_data A long tidy dataframe containing draws from the model of +#' the estimated wastewater concentrations in each site joined with both the data +#' the model was calibrated to and the later observed data for evaluating the +#' future predicted concentrations against. +#' @param location the jursidiction the data is from +#' @param model_type type of model the output is from, default is `ww` +#' @param n_draws number of draws to plot, default = 100 +#' +#' @return a ggplot object faceted by site showing the draws +#' @export +get_plot_ww_data_comparison <- function(draws_w_data, + location, + model_type = "ww", + n_draws = 100) { + sampled_draws <- sample(1:max(draws_w_data$draw), n_draws) + draws_w_data_subsetted <- draws_w_data |> + dplyr::filter( + draw %in% !!sampled_draws, + name == "pred_ww" + ) + + p <- ggplot(draws_w_data_subsetted) + + geom_line(aes(x = date, y = value, group = draw, color = site_lab_name), + linewidth = 0.1, alpha = 0.1, + show.legend = FALSE + ) + + geom_point(aes(x = date, y = eval_data), + fill = "white", size = 1, shape = 21, + show.legend = FALSE + ) + + geom_point(aes(x = date, y = calib_data), + color = "black", + show.legend = FALSE + ) + + geom_vline(aes(xintercept = ymd(forecast_date)), linetype = "dashed") + + scale_y_continuous(trans = "log10") + + facet_wrap(~site_lab_name, scales = "free") + + geom_point( + data = draws_w_data_subsetted |> filter(below_LOD == 1), + aes(x = date, y = value), color = "red", size = 1.1 + ) + + geom_point( + data = draws_w_data_subsetted |> filter(flag_as_ww_outlier == 1), + aes(x = date, y = value), color = "blue", size = 1.1 + ) + + xlab("") + + ylab("Genome copies per mL") + + ggtitle(glue::glue( + "Site-level expected observed wastewater concentration in {location} from {model_type} model" + )) + + theme_bw() + + scale_color_discrete() + + scale_fill_discrete() + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") + ) + + theme_bw() + + theme( + axis.text.x = element_text( + size = 8, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 12), + axis.title.y = element_text(size = 12), + plot.title = element_text( + size = 10, + vjust = 0.5, hjust = 0.5 + ) + ) + return(p) +} + +#' Get plot of hospital admissions data compared to model draws +#' +#' @param draws_w_data A long tidy dataframe containing draws from the model of +#' the estimated hospital admissions joined with both the data +#' the model was calibrated to and the later observed data for evaluating the +#' future predicted concentrations against. +#' @param location the jursidiction the data is from +#' @param model_type type of model the output is from, options are +#' "ww" or "hosp" +#' @param n_draws number of draws to plot, default = 100 +#' +#' @return a ggplot object showing the model draws of hospital admissions +#' alongside the calibration and forecast data +#' @export +get_plot_hosp_data_comparison <- function(draws_w_data, + location, + model_type, + n_draws = 100) { + sampled_draws <- sample(1:max(draws_w_data$draw), n_draws) + draws_w_data_subsetted <- draws_w_data |> + dplyr::filter( + draw %in% !!sampled_draws, + name == "pred_hosp" + ) + + plot_color <- ifelse(model_type == "ww", "cornflowerblue", "purple4") + + + p <- ggplot(draws_w_data_subsetted) + + geom_line(aes(x = date, y = value, group = draw), + color = plot_color, + linewidth = 0.1, alpha = 0.1, + show.legend = FALSE + ) + + geom_point(aes(x = date, y = eval_data), + fill = "white", size = 1, shape = 21, + show.legend = FALSE + ) + + geom_point(aes(x = date, y = calib_data), + color = "black", + show.legend = FALSE + ) + + geom_vline(aes(xintercept = ymd(forecast_date)), linetype = "dashed") + + scale_y_continuous(trans = "log10") + + xlab("") + + ylab("Daily hospital admissions") + + ggtitle(glue::glue( + "Site-level expected observed hospital admissions in {location} from {model_type} model" + )) + + theme_bw() + + scale_color_discrete() + + scale_fill_discrete() + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") + ) + + theme_bw() + + theme( + axis.text.x = element_text( + size = 8, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 12), + axis.title.y = element_text(size = 12), + plot.title = element_text( + size = 10, + vjust = 0.5, hjust = 0.5 + ) + ) + return(p) +} + +get_plot_quantile_comparison <- function(hosp_quantiles, + days_to_show_forecast = 7) { + location <- hosp_quantiles |> + pull(location) |> + unique() + n_scenarios <- hosp_quantiles |> + dplyr::pull(scenario) |> + unique() |> + length() + + + quantiles_wide <- hosp_quantiles |> + dplyr::filter(quantile %in% c(0.025, 0.25, 0.5, 0.75, 0.975)) |> + tidyr::pivot_wider( + id_cols = c( + forecast_date, period, scenario, + date, t, eval_data + ), + names_from = quantile, + values_from = value + ) + + p <- ggplot(quantiles_wide) + + geom_point(aes(x = date, y = eval_data)) + + geom_line(aes(x = date, y = eval_data)) + + geom_line( + data = quantiles_wide |> filter( + date >= forecast_date, + date <= forecast_date + days_to_show_forecast + ), + aes( + x = date, y = `0.5`, group = forecast_date, + color = scenario + ) + ) + + geom_ribbon( + data = quantiles_wide |> filter( + date >= forecast_date, + date <= forecast_date + days_to_show_forecast + ), + aes( + x = date, ymin = `0.025`, ymax = `0.975`, + fill = scenario, group = forecast_date + ), alpha = 0.1, + show.legend = FALSE + ) + + geom_ribbon( + data = quantiles_wide |> filter( + date >= forecast_date, + date <= forecast_date + days_to_show_forecast + ), + aes( + x = date, ymin = `0.25`, ymax = `0.75`, + fill = scenario, group = forecast_date + ), alpha = 0.1, + show.legend = FALSE + ) + + facet_wrap(~scenario, nrow = n_scenarios) + + theme_bw() + + xlab("") + + ylab("Daily hospital admissions") + + ggtitle(glue::glue( + "Forecasted vs later observed hospital admissions in {location}" + )) + + scale_color_discrete() + + scale_fill_discrete() + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") + ) + + theme_bw() + + theme( + axis.text.x = element_text( + size = 8, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 12), + axis.title.y = element_text(size = 12), + plot.title = element_text( + size = 10, + vjust = 0.5, hjust = 0.5 + ) + ) + + return(p) +} + + +get_plot_raw_scores <- function(all_scores, + score_metric = "crps") { + ggplot(all_scores) + + geom_point(aes(x = date, y = crps, color = scenario, group = c(forecast_date))) + + facet_grid(forecast_date ~ location) + + theme_bw() +} + +get_plot_summarized_scores <- function(all_scores, + score_metric = "crps") { + scores <- all_scores |> + dplyr::select(-tar_group) |> + data.table::as.data.table() + + summarized_scores <- all_scores |> + data.table::as.data.table() |> + scoringutils::summarize_scores( + by = c("scenario", "period", "forecast_date", "location") + ) + + n_periods <- summarized_scores |> + dplyr::pull(period) |> + unique() |> + length() + location <- summarized_scores |> + pull(location) |> + unique() + + summary_across_dates <- all_scores |> + data.table::as.data.table() |> + scoringutils::summarize_scores( + by = c("period", "location", "scenario") + ) + + p <- ggplot(summarized_scores) + + geom_bar(aes(x = forecast_date, y = {{ score_metric }}, fill = scenario), + stat = "identity", position = "dodge", alpha = 0.5 + ) + + geom_hline( + data = summary_across_dates, + aes(yintercept = get(score_metric), color = scenario) + ) + + facet_wrap(~period, nrow = n_periods) + + theme_bw() + + xlab("") + + ylab(glue::glue("{score_metric} by forecast date and scenario")) + + ggtitle(glue::glue( + "Score comparison over time in {location}" + )) + + scale_color_discrete() + + scale_fill_discrete() + + scale_x_date( + date_breaks = "1 weeks", + labels = scales::date_format("%Y-%m-%d") + ) + + theme( + axis.text.x = element_text( + size = 8, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 12), + axis.title.y = element_text(size = 12), + plot.title = element_text( + size = 10, + vjust = 0.5, hjust = 0.5 + ) + ) + + return(p) +} diff --git a/wweval/R/post_processing.R b/wweval/R/post_processing.R new file mode 100644 index 00000000..b6d09249 --- /dev/null +++ b/wweval/R/post_processing.R @@ -0,0 +1,170 @@ +#' Get model draws combined with input and evaluation data +#' +#' @param model_output the type of model expected observation you want, +#' options are "hosp" and "ww" +#' @param model_type The type of model, options are "ww" and "hosp" +#' @param draws The raw draws dataframe +#' @param forecast_date The date the forecast was made +#' @param scenario A name for the scenario that the input +#' data represents, as a string. +#' @param location The location for which the model is being run. +#' @param input_data The input dataset used for fitting the model. +#' @param eval_data The retrospective dataset used to evaluate the model +#' (should have data beyond the forecast date) +#' @param last_hosp_data_date The date of the last hospital admissions data point +#' that the model is calibrated to +#' @param ot An integer indicating the number of days the model is calibrated +#' to hospital admissions data +#' @param forecast_time An integer indicate the time in days of the forecast +#' +#' @return a dataframe of model draws subsetted to only the specified output +#' type, joined with the evaluation data and the input calibration data +#' @export +get_model_draws_w_data <- function(model_output, + model_type, + draws, + forecast_date, + scenario, + location, + input_data, + eval_data, + last_hosp_data_date, + ot, + forecast_time = 28) { + nowcast_time <- as.integer(ymd(forecast_date) - ymd(last_hosp_data_date)) + ht <- nowcast_time + forecast_time + # Date spine for joining data + date_df <- tibble::tibble(date = seq( + from = min(input_data$date), + to = min(input_data$date) + ot + ht, + by = "days" + )) |> + mutate(t = row_number()) + + eval_data <- eval_data |> + dplyr::filter(location == !!location) + stopifnot( + "More than one location in eval data that is getting joined" = + eval_data |> dplyr::pull(location) |> unique() |> length() == 1 + ) + + + # Dataframe with columns + if (model_output == "hosp") { + pop <- input_data |> + pull(pop) |> + unique() + + draws_w_data <- draws |> + spread_draws(pred_hosp[t]) |> + rename(value = pred_hosp) |> + mutate( + draw = `.draw`, + name = "pred_hosp" + ) |> + select(name, t, value, draw) |> + left_join(date_df, by = "t") |> + left_join(input_data |> select(-pop, -location), + by = c("date") + ) |> + rename(calib_data = daily_hosp_admits) |> + left_join(eval_data |> select(-pop, -location), + by = c("date") + ) |> + rename(eval_data = daily_hosp_admits) |> + mutate( + forecast_date = lubridate::ymd(!!forecast_date), + model_type = !!model_type, + location = !!location, + pop = !!pop, + scenario = !!scenario + ) |> + ungroup() + } + + if (model_output == "ww") { + # Then we also want to output the wastewater predictions + lab_site_map <- input_data |> + select(lab_site_index, site, lab, location) |> + distinct() + # Get mean population in the site over the calibration period, this + # is the same pop size we use in the model fitting + site_pop_map <- input_data |> + select(site, ww_pop, date) |> + group_by(site) |> + summarise(ww_pop = mean(ww_pop)) + + draws_w_data <- draws |> + spread_draws(pred_ww[lab_site_index, t]) |> + rename(value = pred_ww) |> + mutate( + draw = `.draw`, + name = "pred_ww", + value = exp(value) + ) |> + select(name, lab_site_index, t, value, draw) |> + left_join(date_df, by = "t") |> + left_join(lab_site_map, by = "lab_site_index") |> + left_join(site_pop_map, by = c("site")) |> + left_join( + input_data |> select( + date, lab_site_index, + ww, below_LOD, + lod_sewage, flag_as_ww_outlier + ), + by = c("date", "lab_site_index") + ) |> + rename(calib_data = ww) |> + left_join(eval_data |> select(date, ww, lab, site), + by = c("date", "lab", "site") + ) |> + rename(eval_data = ww) |> + mutate( + forecast_date = ymd(!!forecast_date), + model_type = !!model_type, + scenario = !!scenario, + site_lab_name = glue::glue("Site: {site}, Lab: {lab}"), + location = !!location + ) |> + ungroup() + } + return(draws_w_data) +} + + +#' Get quantiles for state-level generated quantities +#' +#' @param draws a dataframe containing all the draws from the model estimated +#' state-level quantities +#' +#' @return a dataframe containing the quantile value for the quantiles +#' required for the Hub submission +#' @export +get_state_level_quantiles <- function(draws) { + quantiles <- cfaforecastrenewalww::trajectories_to_quantiles( + draws, + timepoint_cols = "date", + value_col = "value", + id_cols = c("location", "name", "scenario", "model_type") + ) |> + dplyr::rename( + quantile = quantile_level, + value = quantile_value + ) |> + dplyr::left_join( + draws |> + select(-draw, -value) |> + unique(), + by = c("date", "name", "location", "scenario", "model_type") + ) |> + dplyr::mutate( + period = dplyr::case_when( + !is.na(calib_data) ~ "calibration", + date <= forecast_date ~ "nowcast", + TRUE ~ "forecast" + ), + quantile = round(quantile, 4) + ) + + return(quantiles) +} diff --git a/wweval/R/sample_model.R b/wweval/R/sample_model.R new file mode 100644 index 00000000..d72d70aa --- /dev/null +++ b/wweval/R/sample_model.R @@ -0,0 +1,55 @@ +#' Fit the model +#' +#' @param standata a list of elements to pass to stan +#' @param compiled_model the compiled model object +#' @param init_lists nested list of initial parameter values for each chain +#' @param iter_warmup number of iterations to save in MCMC sampling, +#' default = 250 +#' @param iter_sampling number of iterations to save in MCMC sampling, +#' default = 250 +#' @param max_treedepth maximum treedepth of MCMC sampling, defauly = 12 +#' @param adapt_delta MCMC accaptance probability, default = 0.95 +#' @param n_chains number of independent MCMC chains to run, default = 4 +#' @param seed seed of random number generator default = 123 +#' +#' +#' @return a list containing draws, diagnostics, and summary_diagnostics +#' @export +sample_model <- function(standata, + compiled_model, + init_lists, + iter_warmup = 250, + iter_sampling = 250, + max_treedepth = 12, + adapt_delta = 0.95, + n_chains = 4, + seed = 123) { + if (!inherits(compiled_model, "CmdStanModel")) { + cli::cli_abort(paste0( + "Argument `compiled_model` must be a ", + "cmdstanr::CmdStanModel object; got a ", + "{class(compiled_model)} object instead" + )) + } + fit <- compiled_model$sample( + data = standata, + init = init_lists, + iter_warmup = iter_warmup, + iter_sampling = iter_sampling, + max_treedepth = max_treedepth, + adapt_delta = adapt_delta, + chains = n_chains, + seed = seed + ) + + draws <- fit$draws() + diagnostics <- fit$sampler_diagnostics(format = "df") + summary_diagnostics <- fit$diagnostic_summary() + + draws_and_diagnostics <- list( + draws = draws, + diagnostics = diagnostics, + summary_diagnostics = summary_diagnostics + ) + return(draws_and_diagnostics) +} diff --git a/wweval/R/score.R b/wweval/R/score.R new file mode 100644 index 00000000..d2de36f6 --- /dev/null +++ b/wweval/R/score.R @@ -0,0 +1,54 @@ +#' Get the scores for ever day for a particular location and forecast date +#' +#' @param draws a dataframe of the model estimated quantity you are evaluating +#' alongside the evaluation data +#' @param scenario a string indicating the wastewater data scenario we're +#' running +#' @param metrics Vector of scoring metrics to output, passed as the +#' `metrics` argument to [scoringutils::score()]. Default +#' `c("crps", "dss", "bias", "mad", "ae_median", "se_mean")`. +#' @param ... Additional named arguments passed to [scoringutils::score()]. +#' +#' @return a dataframe containing a score for each day in the nowcast +#' and forecast period +#' @export +get_full_scores <- function(draws, + scenario, + metrics = c( + "crps", "dss", "bias", + "mad", "ae_median", "se_mean" + ), + ...) { + # Filter to after the last date + last_calib_date <- max(draws$date[!is.na(draws$calib_data)]) + + forecasted_draws <- draws |> + filter(date > last_calib_date) |> + ungroup() |> + # Rename for scoring utils + rename( + true_value = eval_data, + prediction = value, + sample = draw, + model = model_type + ) |> + select( + location, + forecast_date, + date, + true_value, + prediction, + sample, + model + ) |> + mutate( + period = ifelse(date <= forecast_date, "nowcast", "forecast"), + scenario = !!scenario + ) + + + scores <- forecasted_draws |> + scoringutils::score(metrics = metrics, ...) + + return(scores) +} diff --git a/wweval/R/set_up.R b/wweval/R/set_up.R new file mode 100644 index 00000000..c616bf09 --- /dev/null +++ b/wweval/R/set_up.R @@ -0,0 +1,16 @@ +#' Make dataframe +#' +#' @param config a config file with a list of locations, forecast_dates, and +#' model types +#' +#' @return an expanded dataframe that contains 3 columns with all combinations +#' of locations, forecast_dates and model types +#' @export +make_df <- function(config) { + df <- as.data.frame(expand.grid( + location = config$location, + forecast_date = config$forecast_date, + model_type = config$model_type + )) + return(df) +} diff --git a/wweval/R/utils.R b/wweval/R/utils.R new file mode 100644 index 00000000..c3f61d60 --- /dev/null +++ b/wweval/R/utils.R @@ -0,0 +1,15 @@ +#' @title Suppress output and messages for code. +#' @description Used in the pipeline. +#' @return The result of running the code. +#' @param code Code to run quietly. +#' @examples +#' library(cmdstanr) +#' compile_model("stan/model.stan") +#' quiet(fit_model("stan/model.stan", simulate_data_discrete())) +#' out +#' @noRd +quiet <- function(code) { + sink(nullfile()) + on.exit(sink()) + suppressMessages(code) +} diff --git a/wweval/R/wweval-package.R b/wweval/R/wweval-package.R new file mode 100644 index 00000000..ad4946f3 --- /dev/null +++ b/wweval/R/wweval-package.R @@ -0,0 +1,19 @@ +#' @keywords internal +"_PACKAGE" + + +#' @importFrom arrow read_parquet read_ipc_stream write_parquet +#' @importFrom jsonlite fromJSON +#' @importFrom readr read_csv write_csv +#' @importFrom glue glue +#' @importFrom rlang sym +#' @importFrom lubridate ymd +#' @importFrom tidybayes spread_draws +#' @importFrom dplyr filter left_join select pull distinct mutate as_tibble +#' rename ungroup arrange row_number group_by +#' @importFrom tidyr pivot_wider pivot_longer +#' @importFrom ggplot2 ggplot facet_wrap geom_line geom_hline geom_point geom_bar +#' theme scale_y_continuous scale_colour_discrete scale_fill_discrete geom_ribbon +#' scale_x_date facet_grid geom_vline labs +#' @importFrom cmdstanr cmdstan_model +NULL diff --git a/wweval/man/add_time_indexing.Rd b/wweval/man/add_time_indexing.Rd new file mode 100644 index 00000000..a17f6bbd --- /dev/null +++ b/wweval/man/add_time_indexing.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{add_time_indexing} +\alias{add_time_indexing} +\title{Add time indexing to hospital admissions data} +\usage{ +add_time_indexing(input_hosp_data) +} +\arguments{ +\item{input_hosp_data}{data frame with dates and admissions, +but without time indexing.} +} +\value{ +The same data frame, with an added +time index, including NA rows if dates internal +to the timeseries are missing admissions data. +} +\description{ +Add time indexing to hospital admissions data +} +\examples{ +hosp_data_example <- tibble::tibble( + date = lubridate::ymd("2024-01-01", "2024-01-02", "2024-01-06"), + daily_hosp_admits = c(5, 3, 8) +) +hosp_data_w_t <- add_time_indexing(hosp_data_example) +} diff --git a/wweval/man/clean_ww_data.Rd b/wweval/man/clean_ww_data.Rd new file mode 100644 index 00000000..bc6309f1 --- /dev/null +++ b/wweval/man/clean_ww_data.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_input_data.R +\name{clean_ww_data} +\alias{clean_ww_data} +\title{Clean wastewater data} +\usage{ +clean_ww_data(nwss_subset) +} +\arguments{ +\item{nwss_subset}{the raw nwss data filtered down to only the columns we use} +} +\value{ +A site-lab level dataset with names and variables that can be used +for model fitting +} +\description{ +Clean wastewater data +} diff --git a/wweval/man/date_of_ww_data.Rd b/wweval/man/date_of_ww_data.Rd new file mode 100644 index 00000000..c2758974 --- /dev/null +++ b/wweval/man/date_of_ww_data.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_input_data.R +\name{date_of_ww_data} +\alias{date_of_ww_data} +\title{Date of wastewater data} +\usage{ +date_of_ww_data(forecast_date, ww_data_mapping, ww_data_dir) +} +\arguments{ +\item{forecast_date}{the forecast date for this iteration} + +\item{ww_data_mapping}{a string that tells this function how to pick +data pull dates from forecast dates. This function needs to be configured +for each new string.} + +\item{ww_data_dir}{A string indicating the path to the directory containing +time stamped wastewater datasets} +} +\value{ +the date to get the ww data from +} +\description{ +Date of wastewater data +} diff --git a/wweval/man/filter_sites_by_scenario.Rd b/wweval/man/filter_sites_by_scenario.Rd new file mode 100644 index 00000000..5eafd73d --- /dev/null +++ b/wweval/man/filter_sites_by_scenario.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_input_data.R +\name{filter_sites_by_scenario} +\alias{filter_sites_by_scenario} +\title{Filter sites by scenario} +\usage{ +filter_sites_by_scenario( + init_subset_nwss_data, + scenario = "Status quo", + scenario_dir = NA +) +} +\arguments{ +\item{init_subset_nwss_data}{a dataframe of the raw NWSS data +filtered to exclude solids and upstream sites} + +\item{scenario}{a string indicating what scenario we are running. +Default is "Status quo" which uses all the data we have available} + +\item{scenario_dir}{a string indicating the file path where the +scenario csvs will live. Default is NA because we don't need this +for the status quo scenario} +} +\value{ +a dataframe that only contains the ww data from +the sites in the list pertaining to the scenario +} +\description{ +Filter sites by scenario +} diff --git a/wweval/man/get_full_scores.Rd b/wweval/man/get_full_scores.Rd new file mode 100644 index 00000000..1229c8cd --- /dev/null +++ b/wweval/man/get_full_scores.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/score.R +\name{get_full_scores} +\alias{get_full_scores} +\title{Get the scores for ever day for a particular location and forecast date} +\usage{ +get_full_scores( + draws, + scenario, + metrics = c("crps", "dss", "bias", "mad", "ae_median", "se_mean"), + ... +) +} +\arguments{ +\item{draws}{a dataframe of the model estimated quantity you are evaluating +alongside the evaluation data} + +\item{scenario}{a string indicating the wastewater data scenario we're +running} + +\item{metrics}{Vector of scoring metrics to output, passed as the +\code{metrics} argument to \code{\link[scoringutils:score]{scoringutils::score()}}. Default +\code{c("crps", "dss", "bias", "mad", "ae_median", "se_mean")}.} + +\item{...}{Additional named arguments passed to \code{\link[scoringutils:score]{scoringutils::score()}}.} +} +\value{ +a dataframe containing a score for each day in the nowcast +and forecast period +} +\description{ +Get the scores for ever day for a particular location and forecast date +} diff --git a/wweval/man/get_hosp_data_sizes.Rd b/wweval/man/get_hosp_data_sizes.Rd new file mode 100644 index 00000000..d507182d --- /dev/null +++ b/wweval/man/get_hosp_data_sizes.Rd @@ -0,0 +1,48 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_hosp_data_sizes} +\alias{get_hosp_data_sizes} +\title{Get hospital data integer sizes for stan} +\usage{ +get_hosp_data_sizes( + input_hosp_data, + forecast_date, + forecast_time, + last_hosp_data_date, + uot, + hosp_value_col_name = "daily_hosp_admits" +) +} +\arguments{ +\item{input_hosp_data}{a dataframe with the input hospital admissions data} + +\item{forecast_date}{string indicating the forecast date} + +\item{forecast_time}{integer indicating the number of days to make a forecast +for} + +\item{last_hosp_data_date}{string indicating the date of the last observed +hospital admission} + +\item{uot}{integer indicating the time of model initialization when there are +no observations} + +\item{hosp_value_col_name}{A string represeting the name of the column in the +input_hosp-data that indicates the number of daily hospital admissions, +default is \code{daily_hosp_admits}} +} +\value{ +A list containing the integer sizes of the follow variables that +the stan model requires: +ht: integer indicating horizon time for the model(hospital admissions +nowcast + forecast time in days) +ot: integer indicating the total duration of time that the hospital admissions +model has available calibration data +oht: integer indicating the number of hospital admission observations +n_weeks: number of weeks (rounded up) that hospital admissions are generated +from the model +tot_weeks: number of week(rounded up) that infections are generated for +} +\description{ +Get hospital data integer sizes for stan +} diff --git a/wweval/man/get_hosp_indices.Rd b/wweval/man/get_hosp_indices.Rd new file mode 100644 index 00000000..fdbc4f7a --- /dev/null +++ b/wweval/man/get_hosp_indices.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_hosp_indices} +\alias{get_hosp_indices} +\title{Get hospital admissions indices} +\usage{ +get_hosp_indices(input_hosp_data) +} +\arguments{ +\item{input_hosp_data}{a dataframe with the input hospital admissions data} +} +\value{ +A list containing the vectors of indices that +the stan model requires: +hosp_times: a vector of integer times corresponding to the times when the +hospital admissions observations were made +} +\description{ +Get hospital admissions indices +} diff --git a/wweval/man/get_hosp_values.Rd b/wweval/man/get_hosp_values.Rd new file mode 100644 index 00000000..610c6a36 --- /dev/null +++ b/wweval/man/get_hosp_values.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_hosp_values} +\alias{get_hosp_values} +\title{Get hospital admissions values} +\usage{ +get_hosp_values( + input_hosp_data, + ot, + ht, + hosp_value_col_name = "daily_hosp_admits" +) +} +\arguments{ +\item{input_hosp_data}{a dataframe with the input hospital admissions data} + +\item{ot}{integer indicating the total duration of time that the hospital admissions +model has available calibration data in days} + +\item{ht}{integer indicating the number of days to produce hospital admissions +outside the calibration period (forecast + nowcast time) in days} + +\item{hosp_value_col_name}{A string represeting the name of the column in the +input_hosp-data that indicates the number of daily hospital admissions, +default is \code{daily_hosp_admits}} +} +\value{ +A list containing the necessary vectors of values that +the stan model requires: +hosp_admits: a vector of number of daily hospital admissions observations +day_of_week: a vector indicating the day of the week of each of the dates +in the calibration and forecast period +} +\description{ +Get hospital admissions values +} diff --git a/wweval/man/get_input_hosp_data.Rd b/wweval/man/get_input_hosp_data.Rd new file mode 100644 index 00000000..49b6249d --- /dev/null +++ b/wweval/man/get_input_hosp_data.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_input_data.R +\name{get_input_hosp_data} +\alias{get_input_hosp_data} +\title{Get input hospital admissions data} +\usage{ +get_input_hosp_data( + forecast_date_i, + location_i, + hosp_data_dir, + calibration_time, + load_from_covidcast = FALSE +) +} +\arguments{ +\item{forecast_date_i}{The forecast date for this iteration} + +\item{location_i}{The location (state) for this iteration} + +\item{hosp_data_dir}{A string indicating the path to the directory containing +time stamped hospital admissions datasets} + +\item{calibration_time}{A numeric indicating the duration of model +calibration (based on the last hospital admissions data point)} + +\item{load_from_covidcast}{boolean indicating whether or not the hospital +admissions datasets should be loaded directly from covidcast. +\code{default = FALSE} because we are assuming that we have already created a +folder with time stamped datasets} +} +\value{ +a dataframe containing the cleaned hospital admissions needed as +an input to the stan model for the specified forecast date and location +} +\description{ +Get input hospital admissions data +} diff --git a/wweval/man/get_input_ww_data.Rd b/wweval/man/get_input_ww_data.Rd new file mode 100644 index 00000000..fdde7e09 --- /dev/null +++ b/wweval/man/get_input_ww_data.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_input_data.R +\name{get_input_ww_data} +\alias{get_input_ww_data} +\title{Get input wastewater data} +\usage{ +get_input_ww_data( + forecast_date_i, + location_i, + scenario_i, + scenario_dir, + ww_data_dir, + calibration_time, + last_hosp_data_date, + ww_data_mapping +) +} +\arguments{ +\item{forecast_date_i}{The forecast date for this iteration, +formatted as a character string in IS08601 format (YYYY-MM-DD).} + +\item{location_i}{The location (state or other jurisdiction) +for this iteration, formatted as a string (uppercase USPS two-letter +abbreviation, e.g. AK for Alaska, DC for the District of Columbia, +PR for Puerto Rico).} + +\item{scenario_i}{The scenario for this iteration, formatted as a +string} + +\item{scenario_dir}{A string indicating the path to the directory +containing the csvs with the wwtp ids needed for each scenario} + +\item{ww_data_dir}{A string indicating the path to the directory +containing time stamped wastewater datasets} + +\item{calibration_time}{The duration of the model calibration period +(relative to the last hospital admissions data point) in units of +model timesteps (typically days).} + +\item{last_hosp_data_date}{A date indicating the date of last reported +hospital admission as of the forecast date} + +\item{ww_data_mapping}{A string indicating how to map the +forecast date to the wastewater dates (see \code{\link[=date_of_ww_data]{date_of_ww_data()}} +for more details)} +} +\value{ +a dataframe containing the transformed and clean NWSS data +at the site and lab label for the forecast date and location specified +} +\description{ +Get input wastewater data +} diff --git a/wweval/man/get_last_hosp_data_date.Rd b/wweval/man/get_last_hosp_data_date.Rd new file mode 100644 index 00000000..2055d93c --- /dev/null +++ b/wweval/man/get_last_hosp_data_date.Rd @@ -0,0 +1,18 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_input_data.R +\name{get_last_hosp_data_date} +\alias{get_last_hosp_data_date} +\title{Get last hospital admissions data point date} +\usage{ +get_last_hosp_data_date(input_hosp) +} +\arguments{ +\item{input_hosp}{the hospital admissions dataset for that location and +forecast date} +} +\value{ +a date indicating the last day of observed data +} +\description{ +Get last hospital admissions data point date +} diff --git a/wweval/man/get_model_draws_w_data.Rd b/wweval/man/get_model_draws_w_data.Rd new file mode 100644 index 00000000..df963c2c --- /dev/null +++ b/wweval/man/get_model_draws_w_data.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/post_processing.R +\name{get_model_draws_w_data} +\alias{get_model_draws_w_data} +\title{Get model draws combined with input and evaluation data} +\usage{ +get_model_draws_w_data( + model_output, + model_type, + draws, + forecast_date, + scenario, + location, + input_data, + eval_data, + last_hosp_data_date, + ot, + forecast_time = 28 +) +} +\arguments{ +\item{model_output}{the type of model expected observation you want, +options are "hosp" and "ww"} + +\item{model_type}{The type of model, options are "ww" and "hosp"} + +\item{draws}{The raw draws dataframe} + +\item{forecast_date}{The date the forecast was made} + +\item{scenario}{A name for the scenario that the input +data represents, as a string.} + +\item{location}{The location for which the model is being run.} + +\item{input_data}{The input dataset used for fitting the model.} + +\item{eval_data}{The retrospective dataset used to evaluate the model +(should have data beyond the forecast date)} + +\item{last_hosp_data_date}{The date of the last hospital admissions data point +that the model is calibrated to} + +\item{ot}{An integer indicating the number of days the model is calibrated +to hospital admissions data} + +\item{forecast_time}{An integer indicate the time in days of the forecast} +} +\value{ +a dataframe of model draws subsetted to only the specified output +type, joined with the evaluation data and the input calibration data +} +\description{ +Get model draws combined with input and evaluation data +} diff --git a/wweval/man/get_model_path.Rd b/wweval/man/get_model_path.Rd new file mode 100644 index 00000000..72a8a6f2 --- /dev/null +++ b/wweval/man/get_model_path.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/filepath_mapping.R +\name{get_model_path} +\alias{get_model_path} +\title{Get model path} +\usage{ +get_model_path(model_type, stan_models_dir) +} +\arguments{ +\item{model_type}{string specifying the model to be run, options are either +'hosp' for the hospital admissions only model or +'ww' for the site-level infection dyanmics model using wastewater} + +\item{stan_models_dir}{directory where stan files are located} +} +\value{ +string indicating path to correct stan file +} +\description{ +Get model path +} +\examples{ +model_path <- get_model_path("hosp", system.file("stan", + package = "cfaforecastrenewalww" +)) +} diff --git a/wweval/man/get_plot_hosp_data_comparison.Rd b/wweval/man/get_plot_hosp_data_comparison.Rd new file mode 100644 index 00000000..5120af2f --- /dev/null +++ b/wweval/man/get_plot_hosp_data_comparison.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plots.R +\name{get_plot_hosp_data_comparison} +\alias{get_plot_hosp_data_comparison} +\title{Get plot of hospital admissions data compared to model draws} +\usage{ +get_plot_hosp_data_comparison( + draws_w_data, + location, + model_type, + n_draws = 100 +) +} +\arguments{ +\item{draws_w_data}{A long tidy dataframe containing draws from the model of +the estimated hospital admissions joined with both the data +the model was calibrated to and the later observed data for evaluating the +future predicted concentrations against.} + +\item{location}{the jursidiction the data is from} + +\item{model_type}{type of model the output is from, options are +"ww" or "hosp"} + +\item{n_draws}{number of draws to plot, default = 100} +} +\value{ +a ggplot object showing the model draws of hospital admissions +alongside the calibration and forecast data +} +\description{ +Get plot of hospital admissions data compared to model draws +} diff --git a/wweval/man/get_plot_ww_data_comparison.Rd b/wweval/man/get_plot_ww_data_comparison.Rd new file mode 100644 index 00000000..b17ed461 --- /dev/null +++ b/wweval/man/get_plot_ww_data_comparison.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plots.R +\name{get_plot_ww_data_comparison} +\alias{get_plot_ww_data_comparison} +\title{Get plot of wastewater data compared to model draws} +\usage{ +get_plot_ww_data_comparison( + draws_w_data, + location, + model_type = "ww", + n_draws = 100 +) +} +\arguments{ +\item{draws_w_data}{A long tidy dataframe containing draws from the model of +the estimated wastewater concentrations in each site joined with both the data +the model was calibrated to and the later observed data for evaluating the +future predicted concentrations against.} + +\item{location}{the jursidiction the data is from} + +\item{model_type}{type of model the output is from, default is \code{ww}} + +\item{n_draws}{number of draws to plot, default = 100} +} +\value{ +a ggplot object faceted by site showing the draws +} +\description{ +Get plot of wastewater data compared to model draws +} diff --git a/wweval/man/get_stan_data_list.Rd b/wweval/man/get_stan_data_list.Rd new file mode 100644 index 00000000..d0746dee --- /dev/null +++ b/wweval/man/get_stan_data_list.Rd @@ -0,0 +1,78 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_stan_data_list} +\alias{get_stan_data_list} +\title{Get stan data} +\usage{ +get_stan_data_list( + model_type, + forecast_date, + forecast_time, + input_ww_data, + input_hosp_data, + generation_interval, + inf_to_hosp, + infection_feedback_pmf, + params, + compute_likelihood = 1, + ww_outlier_col_name = "flag_as_ww_outlier", + lod_col_name = "below_LOD", + ww_measurement_col_name = "ww", + ww_value_lod_col_name = "lod_sewage", + hosp_value_col_name = "daily_hosp_admits" +) +} +\arguments{ +\item{model_type}{string indicating which model we are getting data for +Options are \code{ww} or \code{hosp}} + +\item{forecast_date}{string indicating the forecast date} + +\item{forecast_time}{integer indicating the number of days to make a forecast +for} + +\item{input_ww_data}{a dataframe with the input wastewater data} + +\item{input_hosp_data}{a dataframe with the input hospital admissions data} + +\item{generation_interval}{a vector with a zero-truncated normalized pmf of +the generation interval} + +\item{inf_to_hosp}{a vector with a normalized pmf of the delay from infection +to hospital admissions} + +\item{infection_feedback_pmf}{a vector with a normalized pmf dictating the +delay of infection feedback} + +\item{params}{a dataframe of parameter names and numeric values} + +\item{compute_likelihood}{indicator variable telling stan whether or not to +compute the likelihood, default = \code{1}} + +\item{ww_outlier_col_name}{A string representing the name of the +column in the input_ww_data that provides a 0 if the data point is not an +outlier to be excluded from the model fit, or a 1 if it is to be excluded +default value is \code{flag_as_ww_outlier}} + +\item{lod_col_name}{A string representing the name of the +column in the input_ww_data that provides a 0 if the data point is not above +the LOD and a 1 if the data is below the LOD, default value is \code{below_LOD}} + +\item{ww_measurement_col_name}{A string representing the name of the column +in the input_ww_data that indicates the wastewater measurement value in +natural scale, default is \code{ww}} + +\item{ww_value_lod_col_name}{A string representing the name of the column +in the input_ww_data that indicates the value of the LOD in natural scale, +default is \code{lod_sewage}} + +\item{hosp_value_col_name}{A string represeting the name of the column in the +input_hosp-data that indicates the number of daily hospital admissions, +default is \code{daily_hosp_admits}} +} +\value{ +a list named variables to pass to stan +} +\description{ +Get stan data +} diff --git a/wweval/man/get_state_level_quantiles.Rd b/wweval/man/get_state_level_quantiles.Rd new file mode 100644 index 00000000..c77882f6 --- /dev/null +++ b/wweval/man/get_state_level_quantiles.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/post_processing.R +\name{get_state_level_quantiles} +\alias{get_state_level_quantiles} +\title{Get quantiles for state-level generated quantities} +\usage{ +get_state_level_quantiles(draws) +} +\arguments{ +\item{draws}{a dataframe containing all the draws from the model estimated +state-level quantities} +} +\value{ +a dataframe containing the quantile value for the quantiles +required for the Hub submission +} +\description{ +Get quantiles for state-level generated quantities +} diff --git a/wweval/man/get_subpop_data.Rd b/wweval/man/get_subpop_data.Rd new file mode 100644 index 00000000..55f8f60b --- /dev/null +++ b/wweval/man/get_subpop_data.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_subpop_data} +\alias{get_subpop_data} +\title{Get subpopulation data} +\usage{ +get_subpop_data(add_auxiliary_site, state_pop, pop_ww, n_ww_sites) +} +\arguments{ +\item{add_auxiliary_site}{Boolean indicating whether to add another +subpopulation in addition to the wastewater sites to estimate R(t) of} + +\item{state_pop}{The state population size} + +\item{pop_ww}{The population size in each of the wastewater sites} + +\item{n_ww_sites}{The number of wastewater sites} +} +\value{ +A list containing the necessary integers and vectors that stan +needs to estiamte infection dynamics for each subpopulation +} +\description{ +Get subpopulation data +} +\examples{ +subpop_data <- get_subpop_data(TRUE, 100000, c(1000, 500), 2) +} diff --git a/wweval/man/get_ww_data_indices.Rd b/wweval/man/get_ww_data_indices.Rd new file mode 100644 index 00000000..a855d4f4 --- /dev/null +++ b/wweval/man/get_ww_data_indices.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_ww_data_indices} +\alias{get_ww_data_indices} +\title{Get wastewater data indices} +\usage{ +get_ww_data_indices(ww_data, input_hosp_data, owt, lod_col_name = "below_LOD") +} +\arguments{ +\item{ww_data}{Input wastewater dataframe containing one row +per observation, with outliers already removed} + +\item{input_hosp_data}{Input hospital admissions data frame with one row +per day and location} + +\item{owt}{number of wastewater observations} + +\item{lod_col_name}{A string representing the name of the +column in the input_ww_data that provides a 0 if the data point is not above +the LOD and a 1 if the data is below the LOD, default value is \code{below_LOD}} +} +\value{ +A list containing the necessary vectors of indices that +the stan model requires: +ww_censored: the vector of time points that the wastewater observations are +censored (below the LOD) in order of the date and the site index +ww_uncensored: the vector of time points that the wastewater observations are +uncensored (above the LOD) in order of the date and the site index +ww_sampled_times: the vector of time points that the wastewater observations +are passed in in log_conc in order of the date and the site index +ww_sampled_sites: the vector of sites that correspond to the observations +passed in in log_conc in order of the date and the site index +ww_sampled_lab_sites: the vector of unique combinations of site and labs +that correspond to the observations passed in in log_conc in order of the +date and the site index +lab_site_to_site_map: the vector of sites that correspond to each lab-site +} +\description{ +Get wastewater data indices +} diff --git a/wweval/man/get_ww_data_sizes.Rd b/wweval/man/get_ww_data_sizes.Rd new file mode 100644 index 00000000..fec18968 --- /dev/null +++ b/wweval/man/get_ww_data_sizes.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_ww_data_sizes} +\alias{get_ww_data_sizes} +\title{Get the integer sizes of the wastewater input data} +\usage{ +get_ww_data_sizes(ww_data, lod_col_name = "below_LOD") +} +\arguments{ +\item{ww_data}{Input wastewater dataframe containing one row +per observation, with outliers already removed} + +\item{lod_col_name}{A string representing the name of the +column in the input_ww_data that provides a 0 if the data point is not above +the LOD and a 1 if the data is below the LOD, default value is \code{below_LOD}} +} +\value{ +A list containing the integer sizes of the follow variables that +the stan model requires: +owt: number of wastewater observations +n_censored: number of censored wastewater observations (below the LOD) +n_uncensored: number of uncensored wastewter observations (above the LOD) +n_ww_sites: number of wastewater sites +n_ww_lab_sites: number of unique wastewater site-lab combinations +} +\description{ +Get the integer sizes of the wastewater input data +} diff --git a/wweval/man/get_ww_values.Rd b/wweval/man/get_ww_values.Rd new file mode 100644 index 00000000..f8af5716 --- /dev/null +++ b/wweval/man/get_ww_values.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/format_data_for_stan.R +\name{get_ww_values} +\alias{get_ww_values} +\title{Get wastewater data values} +\usage{ +get_ww_values( + ww_data, + ww_measurement_col_name = "ww", + ww_lod_value_col_name = "lod_sewage", + ww_site_pop_col_name = "ww_pop", + one_pop_per_site = TRUE +) +} +\arguments{ +\item{ww_data}{Input wastewater dataframe containing one row +per observation, with outliers already removed} + +\item{ww_measurement_col_name}{A string representing the name of the column +in the input_ww_data that indicates the wastewater measurement value in +natural scale, default is \code{ww}} + +\item{ww_lod_value_col_name}{A string representing the name of the column +in the ww_data that indicates the value of the LOD in natural scale, +default is \code{lod_sewage}} + +\item{ww_site_pop_col_name}{A string representing the name of the column in +the ww_data that indicates the number of people represented by that wastewater +catchment} + +\item{one_pop_per_site}{a boolean variable indicating if there should only +be on catchment area population per site, default is \code{TRUE} bc this is what +the stan model expects} +} +\value{ +A list containing the necessary vectors of values that +the stan model requires: +ww_lod: a vector of the LODs of the corresponding wastewater measurement +pop_ww: a vector of the population sizes of the wastewater catchment areas +in order of the sites by site_index +log_conc: a vector of the log of the wastewater concentration observation +} +\description{ +Get wastewater data values +} diff --git a/wweval/man/make_df.Rd b/wweval/man/make_df.Rd new file mode 100644 index 00000000..fb48a58e --- /dev/null +++ b/wweval/man/make_df.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/set_up.R +\name{make_df} +\alias{make_df} +\title{Make dataframe} +\usage{ +make_df(config) +} +\arguments{ +\item{config}{a config file with a list of locations, forecast_dates, and +model types} +} +\value{ +an expanded dataframe that contains 3 columns with all combinations +of locations, forecast_dates and model types +} +\description{ +Make dataframe +} diff --git a/wweval/man/sample_model.Rd b/wweval/man/sample_model.Rd new file mode 100644 index 00000000..e5f1084d --- /dev/null +++ b/wweval/man/sample_model.Rd @@ -0,0 +1,45 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sample_model.R +\name{sample_model} +\alias{sample_model} +\title{Fit the model} +\usage{ +sample_model( + standata, + compiled_model, + init_lists, + iter_warmup = 250, + iter_sampling = 250, + max_treedepth = 12, + adapt_delta = 0.95, + n_chains = 4, + seed = 123 +) +} +\arguments{ +\item{standata}{a list of elements to pass to stan} + +\item{compiled_model}{the compiled model object} + +\item{init_lists}{nested list of initial parameter values for each chain} + +\item{iter_warmup}{number of iterations to save in MCMC sampling, +default = 250} + +\item{iter_sampling}{number of iterations to save in MCMC sampling, +default = 250} + +\item{max_treedepth}{maximum treedepth of MCMC sampling, defauly = 12} + +\item{adapt_delta}{MCMC accaptance probability, default = 0.95} + +\item{n_chains}{number of independent MCMC chains to run, default = 4} + +\item{seed}{seed of random number generator default = 123} +} +\value{ +a list containing draws, diagnostics, and summary_diagnostics +} +\description{ +Fit the model +} diff --git a/wweval/man/wweval-package.Rd b/wweval/man/wweval-package.Rd new file mode 100644 index 00000000..322a1710 --- /dev/null +++ b/wweval/man/wweval-package.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/wweval-package.R +\docType{package} +\name{wweval-package} +\alias{wweval} +\alias{wweval-package} +\title{wweval: Evaluation of wastewater informed COVID-19 hospital admissions forecasting} +\description{ +This package provides helper functions designed to preprocess, postprocess, and analyze outputs from fitting a semi-mechanistic renewal model with and without incorporating wastewater. +} +\seealso{ +Useful links: +\itemize{ + \item \url{https://github.com/cdcgov/wastewater-informed-covid-forecasting/} + \item Report bugs at \url{https://github.com/cdcgov/wastewater-informed-covid-forecasting/issues/} +} + +} +\author{ +\strong{Maintainer}: Kaitlyn Johnson \email{uox1@cdc.gov} (\href{https://orcid.org/0000-0001-8011-0012}{ORCID}) + +Authors: +\itemize{ + \item Sam Abbott \email{contact@samabbott.co.uk} (\href{https://orcid.org/0000-0001-8057-8037}{ORCID}) + \item Zachary Susswein \email{utb2@cdc.gov} + \item Andrew Magee \email{rzg0@cdc.gov} + \item Dylan Morris \email{dylan@dylanhmorris.com} (\href{https://orcid.org/0000-0002-3655-406X}{ORCID}) + \item Scott Olesen \email{ulp7@cdc.gov} +} + +Other contributors: +\itemize{ + \item George Vega Yon \email{g.vegayon@gmail.com} (\href{https://orcid.org/0000-0002-3171-0844}{ORCID}) [contributor] +} + +} +\keyword{internal} diff --git a/wweval/tests/testthat/test_filepath_mapping.R b/wweval/tests/testthat/test_filepath_mapping.R new file mode 100644 index 00000000..4c07fa79 --- /dev/null +++ b/wweval/tests/testthat/test_filepath_mapping.R @@ -0,0 +1,35 @@ +test_that("get_model_path returns correct paths", { + # Assume stan_models_dir is a directory containing your .stan files + stan_models_dir <- tempdir() + + # Create fake .stan files for testing purposes + file.create(file.path(stan_models_dir, "renewal_ww_hosp_site_level_inf_dynamics.stan")) + file.create(file.path(stan_models_dir, "renewal_ww_hosp.stan")) + + + # Test for 'ww' model type + expect_equal( + get_model_path("ww", stan_models_dir), + file.path(stan_models_dir, "renewal_ww_hosp_site_level_inf_dynamics.stan") + ) + + # Test for 'hosp' model type + expect_equal( + get_model_path("hosp", stan_models_dir), + file.path(stan_models_dir, "renewal_ww_hosp.stan") + ) +}) + +test_that("get_model_path throws an error for invalid model types", { + # Test for invalid model type + expect_error( + get_model_path("invalid_type", stan_models_dir), + "Model type is not specified properly" + ) + + # Test for NULL (missing) model type + expect_error( + get_model_path(NULL, stan_models_dir), + "Model type is empty" + ) +}) diff --git a/wweval/tests/testthat/test_get_input_data.R b/wweval/tests/testthat/test_get_input_data.R new file mode 100644 index 00000000..59a76905 --- /dev/null +++ b/wweval/tests/testthat/test_get_input_data.R @@ -0,0 +1,125 @@ +test_that("Test data_of_ww_data returns the correct date to pull the wastewater data", { + create_fake_ww_files <- function(temp_dir, dates) { + file_paths <- file.path(temp_dir, paste0(dates, ".csv")) + file.create(file_paths) + } + # Create fake wastewater data files for specific dates + sample_dates <- c("2024-03-16", "2024-03-17", "2024-03-18") + temp_dir <- tempdir() + create_fake_ww_files(temp_dir, sample_dates) + + + # Test case where ww_data_mapping is NULL + forecast_date <- "2024-03-18" + expect_equal( + date_of_ww_data(forecast_date, NULL, temp_dir), + "2024-03-17" # Most recent date + ) + + # Test case where ww_data_mapping has a specific rule (Monday: Monday, Wednesday: Monday) + forecast_date <- "2024-03-20" # Wednesday + expect_equal( + date_of_ww_data(forecast_date, "Monday: Monday, Wednesday: Monday", temp_dir), + "2024-03-18" # Previous Monday + ) + + forecast_date <- "2024-03-18" # Monday + expect_equal( + date_of_ww_data(forecast_date, "Monday: Monday, Wednesday: Monday", temp_dir), + "2024-03-18" # Same day (Monday) + ) + + # Test case where ww_data_mapping is something else and should return NA + forecast_date <- "2024-03-18" + expect_error( + date_of_ww_data(forecast_date, "Some other mapping", temp_dir), + "Need to write case to specify which wastewater data to pull" + ) + + # Test case where forecast date is not a Monday or a Wednesday + forecast_date <- "2024-03-19" + expect_error( + date_of_ww_data(forecast_date, "Monday: Monday, Wednesday: Monday", temp_dir), + "Forecast date is not a Monday or Wednesday" + ) +}) + + +test_that("Last hospital admissions data is returned properly", { + fake_df <- data.frame( + date = seq( + from = ymd("2024-03-01"), + to = ymd("2024-03-10"), + by = "days" + ), + daily_hosp_admits = sample.int(10, + size = 10, + replace = TRUE + ) + ) + returned_last_hosp_data_date <- get_last_hosp_data_date(fake_df) + # Test the fake df returns what we'd expect + testthat::expect_equal( + returned_last_hosp_data_date, + ymd("2024-03-10") + ) +}) + + +test_that("clean_ww_data correctly cleans and renames columns", { + fake_nwss_subset <- tibble::tibble( + sample_collect_date = as.Date(c("2021-01-01", "2021-01-02")), + pcr_target_avg_conc = c(100, 200), + population_served = c(50000, 60000), + wwtp_jurisdiction = c("ak", "al"), + wwtp_name = c(132, 142), + lab_id = c(5, 5), + lab_wwtp_unique_id = c(1, 2), + below_LOD = c(FALSE, FALSE), + lod_sewage = c(10, 20) + ) + + # Expected output after cleaning + expected_nwss_output <- tibble::tibble( + date = as.Date(c("2021-01-01", "2021-01-02")), + location = c("AK", "AL"), + ww = c(100, 200), + site = c(132, 142), + lab = c(5, 5), + lab_wwtp_unique_id = c(1, 2), + ww_pop = c(50000, 60000), + below_LOD = c(FALSE, FALSE), + lod_sewage = c(10, 20) + ) + # Apply the cleaning function to our sample data + cleaned_data <- clean_ww_data(fake_nwss_subset) + + # Check if the cleaned data matches our expected output + expect_equal(cleaned_data, expected_nwss_output) + + # Check if all expected columns are present + expect_true(all(names(expected_nwss_output) %in% names(cleaned_data))) + + # Check if 'location' and 'site' columns are correctly transformed to uppercase + expect_true(all(toupper(fake_nwss_subset$wwtp_jurisdiction) == cleaned_data$location)) + expect_true(all(cleaned_data$site == fake_nwss_subset$wwtp_name)) + + # Check if the date column is renamed correctly + expect_equal(cleaned_data$date, fake_nwss_subset$sample_collect_date) + + # Check if the ww (wastewater) column is renamed correctly + expect_equal(cleaned_data$ww, fake_nwss_subset$pcr_target_avg_conc) + + # Check if the ww_pop (population served) column is renamed correctly + expect_equal(cleaned_data$ww_pop, fake_nwss_subset$population_served) + + # Check for correct renaming of lab_id to lab + expect_equal(cleaned_data$lab, fake_nwss_subset$lab_id) + + # Ensure no extra columns are present in the cleaned data + expected_colnames <- c( + "date", "location", "ww", "site", "lab", + "lab_wwtp_unique_id", "ww_pop", "below_LOD", "lod_sewage" + ) + expect_equal(sort(names(cleaned_data)), sort(expected_colnames)) +}) diff --git a/wweval/tests/testthat/test_get_state_level_quantiles.R b/wweval/tests/testthat/test_get_state_level_quantiles.R new file mode 100644 index 00000000..e867f1fb --- /dev/null +++ b/wweval/tests/testthat/test_get_state_level_quantiles.R @@ -0,0 +1,64 @@ +sample_draws <- data.frame( + date = as.Date("2023-01-01") + 0:4, + draw = rnorm(5), + location = rep("Location1", times = 5), + value = rnorm(5), + name = rep("Name1", times = 5), + calib_data = c(3, 4, NA, NA, NA), + forecast_date = as.Date("2023-01-03"), + scenario = rep("Status quo", 5), + model_type = rep("ww", 5) +) + + +# Test case: Check if the output has expected columns after processing +test_that("get_state_level_quantiles returns expected columns", { + result <- get_state_level_quantiles(sample_draws) + + expected_columns <- c("quantile", "value", "name", "period") + + expect_true(all(expected_columns %in% names(result))) +}) + +# Test case: Check if period column is calculated correctly +test_that("Period column is calculated correctly in get_state_level_quantiles", { + result <- get_state_level_quantiles(sample_draws) + + # Check if 'period' column has correct values based on 'date' and 'forecast_date' + expect_true(all(result$period %in% c("calibration", "nowcast", "forecast"))) + + # Ensure that rows with date <= forecast_date without calibration data are labeled as "nowcast" + expect_true(all(result$period[result$date <= sample_draws$forecast_date & # nolint + is.na(result$calib_data)] == "nowcast")) # nolint + + # Ensure that rows with date > forecast_date are labeled as "forecast" + expect_true(all(result$period[result$date > sample_draws$forecast_date] == "forecast")) +}) + +# Test case: Check if the join preserves all unique combinations of 't' and 'name' +test_that("Join operation in get_state_level_quantiles preserves uniqueness", { + result <- get_state_level_quantiles(sample_draws) + + unique_combinations <- unique(sample_draws[c("date", "name")]) + + result_combinations <- unique(result[c("date", "name")]) + + expect_equal(nrow(unique_combinations), nrow(result_combinations)) +}) + +# Test case: Check if quantile levels are correctly assigned +test_that("Quantile levels are correctly assigned in get_state_level_quantiles", { + result <- get_state_level_quantiles(sample_draws) + + # Assuming trajectories_to_quantiles returns a fixed set of quantile levels + expected_quantile_levels <- c(0.025, 0.25, 0.5, 0.75, 0.975) + + expect_true(all(expected_quantile_levels %in% round(result$quantile, 4))) +}) + +# Test case: Check that there are no NAs in the date values +test_that("There aren't NAs where there shouldn't be, all dates compelte", { + result <- get_state_level_quantiles(sample_draws) + + expect_true(!any(is.na(result$date))) +}) diff --git a/wweval/tests/testthat/test_get_ww_stan_data.R b/wweval/tests/testthat/test_get_ww_stan_data.R new file mode 100644 index 00000000..c9505c10 --- /dev/null +++ b/wweval/tests/testthat/test_get_ww_stan_data.R @@ -0,0 +1,215 @@ +# Sample data to use in tests +sample_ww_data <- dplyr::tibble( + date = as.Date(c("2020-12-30", "2021-01-02", "2021-01-03", "2021-01-04")), + site_index = c(1, 1, 2, 2), + lab_site_index = c(1, 2, 3, 3), + below_LOD = c(0, 1, 0, 1), + ww = c(100, 200, 150, 250), + lod_sewage = c(10, 20, 15, 25), + ww_pop = c(5000, 5500, 3000, 3000), + other_col = c(10, 20, 30, 40) +) + +# Create sample hospital admissions data +sample_hosp_data <- dplyr::tibble( + date = as.Date(c("2020-12-30", "2020-12-31", "2021-01-01")) +) + + +# Test that the function returns correct counts +test_that("Function returns correct counts", { + result <- get_ww_data_sizes(sample_ww_data) + + expect_equal(result$owt, nrow(sample_ww_data)) + expect_equal(result$n_censored, sum(sample_ww_data$below_LOD == 1)) + expect_equal(result$n_uncensored, sum(sample_ww_data$below_LOD == 0)) +}) + +# Test that function handles error properly when LOD column is missing +test_that("Error is thrown when LOD column is missing", { + expect_error( + get_ww_data_sizes(sample_ww_data, "nonexistent_column"), + "LOD column name isn't present in input dataset" + ) +}) + +# Test that function works with different LOD column names +test_that("Function works with different LOD column names", { + # Rename below_LOD to new_LOD_col for testing purposes + renamed_sample <- dplyr::rename(sample_ww_data, new_LOD_col = below_LOD) + + result <- get_ww_data_sizes(renamed_sample, "new_LOD_col") + + expect_equal(result$n_censored, sum(renamed_sample$new_LOD_col == 1)) +}) + +# Test that number of unique sites and lab_sites are calculated correctly +test_that("Number of unique sites and lab_sites are calculated correctly", { + result <- get_ww_data_sizes(sample_ww_data) + + expect_equal(result$n_ww_sites, length(unique(sample_ww_data$site_index))) + expect_equal(result$n_ww_lab_sites, length(unique(sample_ww_data$lab_site_index))) +}) + +# Test that the function returns a list with the correct names +test_that("Function returns a list with correct names", { + result <- get_ww_data_sizes(sample_ww_data) + + expected_names <- c("owt", "n_censored", "n_uncensored", "n_ww_sites", "n_ww_lab_sites") + expect_equal(names(result), expected_names) +}) + + + + +# Test that function returns correct indices +test_that("Function returns correct indices", { + result <- get_ww_data_indices(sample_ww_data, + sample_hosp_data, + owt = nrow(sample_ww_data) + ) + + expect_equal(result$ww_censored, which(sample_ww_data$below_LOD == 1)) + expect_equal(result$ww_uncensored, which(sample_ww_data$below_LOD == 0)) +}) + +# Test that function throws an error when owt does not match expected length +test_that("Error is thrown when owt does not match expected length", { + expect_error( + get_ww_data_indices(sample_ww_data, + sample_hosp_data, + owt = nrow(sample_ww_data) + 1 + ), + "Length of censored vectors incorrect" + ) +}) + +# Test that sampled times are calculated correctly +test_that("Sampled times are calculated correctly", { + result <- get_ww_data_indices(sample_ww_data, + sample_hosp_data, + owt = nrow(sample_ww_data) + ) + + expected_times <- data.frame( + date = seq( + from = min(sample_hosp_data$date), + to = max(sample_ww_data$date), by = "days" + ), + t = 1:as.integer(max(sample_ww_data$date) - + min(sample_hosp_data$date) + 1) # nolint + ) + ww_data <- sample_ww_data |> + dplyr::left_join(expected_times, by = "date") + expected_t <- ww_data$t + + expect_equal(result$ww_sampled_times, expected_t) +}) + +# Test that sampled sites and lab-sites indices are correct +test_that("Sampled sites and lab-sites indices are correct", { + result <- get_ww_data_indices(sample_ww_data, + sample_hosp_data, + owt = nrow(sample_ww_data) + ) + + expect_equal(result$ww_sampled_sites, sample_ww_data$site_index) + expect_equal(result$ww_sampled_lab_sites, sample_ww_data$lab_site_index) +}) + +# Test that lab-site to site map is correct +test_that("Lab-site to site map is correct", { + result <- get_ww_data_indices(sample_ww_data, + sample_hosp_data, + owt = nrow(sample_ww_data) + ) + + # Create the expected mapping manually for the test case + set_mapping <- c(1, 1, 2) + + expect_equal(result$lab_site_to_site_map, set_mapping) +}) + +# Test that function returns correct log LOD values +test_that("Function returns correct log LOD values", { + result <- get_ww_values(sample_ww_data) + + expected_lod <- log(sample_ww_data$lod_sewage) + expect_equal(result$ww_lod, expected_lod) +}) + +# Test that population averages are calculated correctly when one_pop_per_site is TRUE +test_that("Population averages are calculated correctly for one_pop_per_site = TRUE", { + result <- get_ww_values(sample_ww_data) + + expected_pop_avg <- sample_ww_data |> + dplyr::group_by(site_index) |> + dplyr::summarise(pop_avg = mean(ww_pop)) |> + dplyr::pull(pop_avg) + + expect_equal(result$pop_ww, expected_pop_avg) +}) + +# Test that population vector is returned correctly when one_pop_per_site is FALSE +test_that("Population vector is returned correctly for one_pop_per_site = FALSE", { + result <- get_ww_values(sample_ww_data, + one_pop_per_site = FALSE + ) + + expect_equal(result$pop_ww, sample_ww_data$ww_pop) +}) + +# Test that function returns correct log concentration values +test_that("Function returns correct log concentration values", { + result <- get_ww_values(sample_ww_data) + + # Adding a small constant to avoid taking log of zero + expected_log_conc <- log(sample_ww_data$ww + 1e-8) + + expect_equal(result$log_conc, expected_log_conc) +}) + +# Test that function handles different measurement column names +test_that("Function handles different measurement column names", { + # Rename 'ww' to 'new_ww_col' for testing purposes + renamed_sample <- dplyr::rename(sample_ww_data, new_ww_col = ww) + + result <- get_ww_values(renamed_sample, + ww_measurement_col_name = "new_ww_col" + ) + + expected_log_conc <- log(renamed_sample$new_ww_col + 1e-8) + + expect_equal(result$log_conc, expected_log_conc) +}) + +test_that("Function handles different LOD value column names", { + # Rename 'lod_sewage' to 'new_lod_sewage' for testing purposes + renamed_sample <- dplyr::rename(sample_ww_data, new_lod_sewage = lod_sewage) + + result <- get_ww_values(renamed_sample, + ww_lod_value_col_name = "new_lod_sewage" + ) + + expected_lod <- log(renamed_sample$new_lod_sewage) + + expect_equal(result$ww_lod, expected_lod) +}) + +# Test that function handles different population column names +test_that("Function handles different population column names", { + # Rename 'ww_pop' to 'new_ww_pop' for testing purposes + renamed_sample <- dplyr::rename(sample_ww_data, new_ww_pop = ww_pop) + + result_true <- get_ww_values(renamed_sample, + ww_site_pop_col_name = "new_ww_pop", + one_pop_per_site = TRUE + ) + + expected_pop_avg_true <- renamed_sample |> + dplyr::group_by(site_index) |> + dplyr::summarise(pop_avg = mean(new_ww_pop)) |> + dplyr::pull(pop_avg) + + expect_equal(result_true$pop_ww, expected_pop_avg_true) +}) diff --git a/wweval/tests/testthat/test_sample_model.R b/wweval/tests/testthat/test_sample_model.R new file mode 100644 index 00000000..96c33159 --- /dev/null +++ b/wweval/tests/testthat/test_sample_model.R @@ -0,0 +1,58 @@ +stan_program <- " +data { + int n; + vector[n] y; +} +parameters { + real mu; +} +model { + y ~ normal(mu, 1); +} +" + +stanfile <- cmdstanr::write_stan_file(stan_program) +stanmodel <- cmdstanr::cmdstan_model(stanfile) + +test_that(paste0( + "sample_model errors if given something other ", + "than a CmdStanModel object as `compiled_model`" +), { + expect_error( + sample_model( + list(), + c(1, 2, 3), + NULL + ), + "must be a cmdstanr::CmdStanModel object" + ) +}) + +test_that(paste0( + "sample_model gives a different error if passed a valid ", + "CmdStanModel but not the needed data to sample from it" +), { + expect_error( + quiet(suppressWarnings(sample_model( + list(), + stanmodel, + NULL + ))), + "No chains finished successfully" + ) +}) + +test_that(paste0( + "sample_model works if passed a valid model" +), { + expect_no_error( + quiet(sample_model( + list( + n = 5, + y = c(2.52, 2.1, 1.7, 1.2, 1.8) + ), + stanmodel, + NULL + )) + ) +})