diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index 4360c1b4..f6ae8c84 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -23,6 +23,8 @@ jobs: contents: write id-token: write pages: write + outputs: + page_artifact_id: ${{ steps.upload-artifact.outputs.artifact_id }} steps: - uses: actions/checkout@v4 @@ -52,6 +54,7 @@ jobs: shell: Rscript {0} - name: Upload artifact for GH pages deployment + id: upload-artifact uses: actions/upload-pages-artifact@v3 with: path: "docs/" @@ -72,3 +75,33 @@ jobs: steps: - name: Deploy to GitHub pages uses: actions/deploy-pages@v4 + + post-page-artifact: + # only comment on PRs + if: ${{ github.event_name == 'pull_request' }} + needs: build + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + env: + GH_TOKEN: ${{ github.token }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Find Comment + uses: peter-evans/find-comment@v3 + id: fc + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: 'github-actions[bot]' + body-includes: Your page is ready to preview + + - name: Create or update comment + uses: peter-evans/create-or-update-comment@v4 + with: + comment-id: ${{ steps.fc.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body: | + Thank you for your contribution, @${{ github.triggering_actor }} :rocket:! Your page is ready to preview [here](https://github.com/${{github.repository}}/actions/runs/${{ github.run_id }}/artifacts/${{ needs.build.outputs.page_artifact_id }}) + edit-mode: replace diff --git a/.github/workflows/r-cmd-check.yaml b/.github/workflows/r-cmd-check.yaml index 65f3caee..c35d2481 100644 --- a/.github/workflows/r-cmd-check.yaml +++ b/.github/workflows/r-cmd-check.yaml @@ -7,13 +7,16 @@ on: jobs: check-package: - runs-on: ubuntu-latest + strategy: + matrix: + r-version: ["4.1.0", "release"] + os: [windows-latest, ubuntu-latest] + runs-on: ${{matrix.os}} steps: - uses: actions/checkout@v4 - - uses: actions/checkout@v4 - uses: r-lib/actions/setup-r@v2 with: - r-version: "release" + r-version: ${{matrix.r-version}} use-public-rspm: true extra-repositories: "https://mc-stan.org/r-packages/" - name: "Set up dependencies for wwinference" @@ -24,6 +27,7 @@ jobs: uses: epinowcast/actions/install-cmdstan@v1 with: cmdstan-version: "latest" + num-cores: 2 - name: "Check wwinference package" uses: r-lib/actions/check-r-package@v2 with: diff --git a/DESCRIPTION b/DESCRIPTION index 29e66080..79d31046 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: wwinference Title: Jointly infers infection dynamics from wastewater data and epidemiological indicators -Version: 0.0.0.9000 +Version: 0.1.0 Authors@R: c( person(given = "Kaitlyn", family = "Johnson", @@ -23,7 +23,7 @@ Authors@R: c( email = "xuk0@cdc.gov"), person(given = "George", family = "Vega Yon", - role = c("ctb"), + role = c("aut"), email = "g.vegayon@gmail.com", comment = c(ORCID = "0000-0002-3171-0844")), person(given = "Damon", @@ -37,7 +37,25 @@ Authors@R: c( person(given = "Scott", family = "Olesen", role = c("aut"), - email = "ulp7@cdc.gov") + email = "ulp7@cdc.gov"), + person(given = "Adam", + family = "Howes", + role = c("ctb"), + email = "adamthowes@gmail.com", + comment = c(ORCID = "0000-0003-2386-4031")), + person(given = "Chirag", + family = "Kumar", + role = c("ctb"), + email = "kzs9@cdc.gov"), + person(given = "Alexander", + family = "Keyel", + role = c("ctb"), + email = "alexander.keyel@health.ny.gov", + comment = c(ORCID = "000-0001-5256-6274")), + person(given = "Hannah", + family = "Cohen", + role = c("ctb"), + email = "llg4@cdc.gov") ) Description: An implementation of a hierarchical semi-mechanistic renewal approach jointly calibrating to multiple wastewater concentrations datasets from @@ -54,7 +72,7 @@ License: Apache License (>= 2) URL: https://github.com/cdcgov/ww-inference-model/, https://cdcgov.github.io/ww-inference-model/ BugReports: https://github.com/cdcgov/ww-inference-model/issues/ Depends: - R (>= 4.3.0) + R (>= 4.1.0) SystemRequirements: CmdStan (>=2.35.0) Encoding: UTF-8 Roxygen: list(markdown = TRUE) diff --git a/NAMESPACE b/NAMESPACE index 6d110162..abe1299b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,11 +1,13 @@ # Generated by roxygen2: do not edit by hand -S3method(get_draws_df,data.frame) -S3method(get_draws_df,default) -S3method(get_draws_df,wwinference_fit) +S3method(get_draws,data.frame) +S3method(get_draws,default) +S3method(get_draws,wwinference_fit) S3method(get_model_diagnostic_flags,default) S3method(get_model_diagnostic_flags,wwinference_fit) +S3method(plot,wwinference_fit_draws) S3method(print,wwinference_fit) +S3method(print,wwinference_fit_draws) S3method(summary,wwinference_fit) export(add_pmfs) export(add_time_indexing) @@ -20,10 +22,14 @@ export(generate_simulated_data) export(get_count_data_sizes) export(get_count_indices) export(get_count_values) +export(get_date_time_spine) +export(get_draws) export(get_draws_df) export(get_ind_m) export(get_input_count_data_for_stan) export(get_input_ww_data_for_stan) +export(get_lab_site_site_spine) +export(get_lab_site_subpop_spine) export(get_mcmc_options) export(get_model_diagnostic_flags) export(get_model_spec) @@ -32,11 +38,10 @@ export(get_plot_forecasted_counts) export(get_plot_global_rt) export(get_plot_subpop_rt) export(get_plot_ww_conc) +export(get_site_subpop_spine) export(get_stan_data) -export(get_subpop_data) -export(get_ww_data_indices) export(get_ww_data_sizes) -export(get_ww_values) +export(get_ww_indices_and_values) export(independence_corr_func_r) export(indicate_ww_exclusions) export(parameter_diagnostics) @@ -44,6 +49,7 @@ export(preprocess_count_data) export(preprocess_ww_data) export(rand_corr_matrix_func) export(spatial_rt_process) +export(summary_diagnostics) export(to_simplex) export(validate_paramlist) export(wwinference) diff --git a/NEWS.md b/NEWS.md index 42c158eb..49dcb767 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,4 @@ -# wwinference 0.0.1 (dev) - +# wwinference 0.1.0 This is the first major release, focused on providing an initial version of the package. Note the package is still flagged as in development, though the authors plan on using it for production work in the coming weeks. diff --git a/R/checkers.R b/R/checkers.R index 4e28959b..79a43f94 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -13,19 +13,26 @@ #' @param date_vector vector of dates #' @param max_date string indicating the maximum date in ISO8601 convention #' e.g. YYYY-MM-DD +#' @param arg_dates string to print the name of the data you are checking the +#' dates for +#' @param arg_max_date string to print the name of the maximum date you are +#' checkign the data for #' @param call Calling environment to be passed to [cli::cli_abort()] for #' traceback. #' #' @return NULL, invisibly assert_no_dates_after_max <- function(date_vector, - max_date, call = rlang::caller_env()) { + max_date, + arg_dates = "y", + arg_max_date = "x", + call = rlang::caller_env()) { if (max(date_vector) > max_date) { cli::cli_abort( c( - "The data passed in has observations beyond the specified", - "maximum date. Either this is the incorrect vintaged", - "data, or the data needs to be filtered to only contain", - "observations before the maximum date" + "The {.arg_dates {arg_dates}} passed in has observations after the ", + "specified {.arg_max_date {arg_max_date}}. Check that this is the ", + "dataset you intended to use with the given ", + "{.arg_max_date {arg_max_date}}." ), call = call, class = "wwinference_input_data_error" @@ -211,6 +218,46 @@ assert_no_repeated_elements <- function(x, arg = "x", invisible() } +#' Check a set of columns in a data frame uniquely identify +#' data frame rows. +#' +#' @description +#' Equivalently, this checks that when grouping by the columns in question, +#' each group has a single entry +#' +#' @param df the dataframe to check +#' @param unique_key_columns Columns that, taken together, should +#' uniquely identify a row in the data frame. +#' @param arg the name of the unique grouping to check +#' @param call Calling environment to be passed to [cli::cli_abort()] for +#' traceback. +#' @param add_err_msg string containing an additional error message, +#' default is the empty string (`""`) +#' +#' @return NULL, invisibly +assert_cols_det_unique_row <- function(df, + unique_key_columns, + arg = "x", + call = rlang::caller_env(), + add_err_msg = "") { + duplicated_rows <- df |> dplyr::filter(dplyr::n() > 1, + .by = {{ unique_key_columns }} + ) + + if (nrow(duplicated_rows) != 0) { + cli::cli_abort( + c("The data has more than one observation per {.arg {arg}}", + add_err_msg, + "i" = "Multiple observations in a {.arg {arg}} are not", + "currently supported." + ), + call = call, + class = "wwinference_input_data_error" + ) + } + invisible() +} + #' Assert that a vector is either of a vector of integers or a vector of @@ -347,19 +394,15 @@ assert_req_ww_cols_present <- function(ww_data, #' traceback. #' #' @return NULL, invisibly -check_req_count_cols_present <- function(count_data, - count_col_name, - pop_size_col_name, - add_req_col_names = c("date"), - call = rlang::caller_env()) { +assert_req_count_cols_present <- function(count_data, + count_col_name, + pop_size_col_name, + add_req_col_names = c("date"), + call = rlang::caller_env()) { column_names <- colnames(count_data) expected_col_names <- c( - { - count_col_name - }, - { - pop_size_col_name - }, + count_col_name, + pop_size_col_name, add_req_col_names ) @@ -491,6 +534,9 @@ assert_daily_data <- function(dates, #' calibration time #' #' @param date_vector the vector of dates to check, must be of Date type +#' @param data_name What data correspond to the dates in `date_vector`. +#' Used to make the error message informative (e.g. +#' "hospital admissions data") #' @param calibration_time integer indicating the number of days that #' the dates must span #' @param call Calling environment to be passed to [cli::cli_abort()] for @@ -500,6 +546,7 @@ assert_daily_data <- function(dates, #' #' @return NULL invisible assert_sufficient_days_of_data <- function(date_vector, + data_name, calibration_time, call = rlang::caller_env(), add_err_msg = "") { @@ -511,7 +558,8 @@ assert_sufficient_days_of_data <- function(date_vector, if (!check_sufficient_data) { cli::cli_abort( c( - "Insufficient data for specified calibration time" + "Insufficient {.arg {data_name}} for the specified calibration time. ", + add_err_msg ), call = call, class = "wwinference_specification_error" @@ -540,9 +588,8 @@ assert_dates_within_frame <- function(dates1, checkmate::assert_date(dates1) checkmate::assert_date(dates2) check_dates2_win_frame <- min(dates1) <= max(dates2) & - min(dates2) >= min(dates1) & - max(dates2) <= max_date & - max(dates1) <= max_date + min(dates2) <= max(dates1) + if (!check_dates2_win_frame) { cli::cli_abort( c( @@ -556,6 +603,8 @@ assert_dates_within_frame <- function(dates1, invisible() } + + #' Assert that two tibbles of date and time mapping align #' #' @param first_data a tibble containing the columns `date` (with IS08601 diff --git a/R/data.R b/R/data.R index 7cc9f8ac..04b0e91d 100644 --- a/R/data.R +++ b/R/data.R @@ -42,12 +42,13 @@ #' Example wastewater dataset with independent site correlations. #' #' A dataset containing the simulated wastewater concentrations -#' (labeled here as `genome_copies_per_ml`) by sample collection date (`date`), -#' the site where the sample was collected (`site`) and the lab where the -#' samples were processed (`lab`). Additional columns that are required -#' attributes needed for the model are the limit of detection for that lab on -#' each day (labeled here as `lod`) and the population size of the wastewater -#' catchment area represented by the wastewater concentrations in each `site`. +#' (labeled here as `log_genome_copies_per_ml`) by sample collection date +#' (`date`), the site where the sample was collected (`site`) and the lab +#' where the samples were processed (`lab`). Additional columns that are +#' required attributes needed for the model are the limit of detection for +#' that lab on each day (labeled here as `log_lod`) and the population size of +#' the wastewater catchment area represented by the wastewater concentrations +#' in each `site`. #' #' This data is generated via the default values in the #' `generate_simulated_data()` function. They represent the bare minumum @@ -63,15 +64,18 @@ #' YYYY-MM-DD} #' \item{site}{The wastewater treatment plant where the sample was collected} #' \item{lab}{The lab where the sample was processed} -#' \item{genome_copies_per_ml}{The wastewater concentration measured on the -#' date specified, collected in the site specified, and processed in the lab -#' specified. The default parameters assume that this quantity is reported -#' as the genome copies per mL, on a natural scale.} -#' \item{lod}{The limit of detection in the site and lab on a particular day -#' of the quantification device (e.g. PCR). This is also by default reported -#' in terms of the genome copies per mL.} +#' \item{log_genome_copies_per_ml}{The natural log of the wastewater +#' concentration measured on the date specified, collected in the site +#' specified, and processed in the lab specified. The package expects +#' this quantity in units of log estimated genome copies per mL.} +#' \item{log_lod}{The log of the limit of detection in the site and lab on a +#' particular day of the quantification device (e.g. PCR). This should be in +#' units of log estimated genome copies per mL.} #' \item{site_pop}{The population size of the wastewater catchment area #' represented by the site variable} +#' \item{location}{ A string indicating the location that all of the +#' data is coming from. This is not a necessary column, but instead is +#' included to more realistically mirror a typical workflow} #' } #' @source vignette_data.R "ww_data_ind" @@ -106,6 +110,9 @@ #' hospital on that date, available as of the forecast date} #' \item{state_pop}{The number of people contributing to the daily hospital #' admissions} +#' \item{location}{ A string indicating the location that all of the +#' data is coming from. This is not a necessary column, but instead is +#' included to more realistically mirror a typical workflow} #' } #' @source vignette_data.R "hosp_data" diff --git a/R/figures.R b/R/figures.R index ee6b95ae..475fd7c7 100644 --- a/R/figures.R +++ b/R/figures.R @@ -30,10 +30,23 @@ get_plot_forecasted_counts <- function(draws, forecast_date, count_type = "hospital admissions", n_draws_to_plot = 100) { - sampled_draws <- sample(1:max(draws$draw), n_draws_to_plot) + n_draws_available <- max(draws$draw) + if (n_draws_available < n_draws_to_plot) { + stop( + sprintf( + "The number of draws to plot (%i) should be less or equal to ", + n_draws_to_plot + ), + sprintf( + "the number of draws in the data (%i).", + n_draws_available + ) + ) + } + + sampled_draws <- sample.int(n_draws_available, n_draws_to_plot) draws_to_plot <- draws |> dplyr::filter( - .data$name == "predicted counts", .data$draw %in% !!sampled_draws ) @@ -97,34 +110,36 @@ get_plot_ww_conc <- function(draws, draws_to_plot <- draws |> dplyr::filter( - .data$name == "predicted wastewater", .data$draw %in% !!sampled_draws - ) |> - dplyr::mutate( - site_lab_name = glue::glue("{subpop}, Lab: {lab}") ) p <- ggplot(draws_to_plot) + geom_line( aes( x = .data$date, y = .data$pred_value, - color = .data$subpop, + color = .data$subpop_name, group = .data$draw ), - alpha = 0.1, linewidth = 0.2, + alpha = 0.1, size = 0.2, show.legend = FALSE ) + geom_point(aes(x = .data$date, y = .data$observed_value), - color = "black", show.legend = FALSE + color = "black", show.legend = FALSE, size = 0.5 ) + - facet_wrap(~site_lab_name, scales = "free") + + geom_point( + data = draws_to_plot |> + dplyr::filter(.data$below_lod == 1), + aes(x = .data$date, y = .data$observed_value), + color = "blue", show.legend = FALSE, size = 0.5 + ) + + facet_wrap(~lab_site_name, scales = "free") + geom_vline( xintercept = lubridate::ymd(forecast_date), linetype = "dashed" ) + xlab("") + ylab("Log genome copies/mL") + - ggtitle("Lab-site level wastewater concentration") + + ggtitle("Lab-site level wastewater concentrations") + scale_x_date( date_breaks = "2 weeks", labels = scales::date_format("%Y-%m-%d") @@ -132,11 +147,13 @@ get_plot_ww_conc <- function(draws, theme_bw() + theme( axis.text.x = element_text( - size = 8, vjust = 1, + size = 5, vjust = 1, hjust = 1, angle = 45 ), axis.title.x = element_text(size = 12), + axis.text.y = element_text(size = 5), axis.title.y = element_text(size = 12), + strip.text = element_text(size = 6), plot.title = element_text( size = 10, vjust = 0.5, hjust = 0.5 @@ -163,10 +180,9 @@ get_plot_ww_conc <- function(draws, get_plot_global_rt <- function(draws, forecast_date, n_draws_to_plot = 100) { - sampled_draws <- sample(1:max(draws$draw), n_draws_to_plot) + sampled_draws <- sample.int(max(draws$draw), n_draws_to_plot) draws_to_plot <- draws |> dplyr::filter( - .data$name == "global R(t)", .data$draw %in% !!sampled_draws ) @@ -191,10 +207,11 @@ get_plot_global_rt <- function(draws, theme_bw() + theme( axis.text.x = element_text( - size = 8, vjust = 1, + size = 5, vjust = 1, hjust = 1, angle = 45 ), axis.title.x = element_text(size = 12), + axis.text.y = element_text(size = 5), axis.title.y = element_text(size = 12), plot.title = element_text( size = 10, @@ -222,10 +239,9 @@ get_plot_global_rt <- function(draws, get_plot_subpop_rt <- function(draws, forecast_date, n_draws_to_plot = 100) { - sampled_draws <- sample(1:max(draws$draw), n_draws_to_plot) + sampled_draws <- sample.int(max(draws$draw), n_draws_to_plot) draws_to_plot <- draws |> dplyr::filter( - .data$name == "subpopulation R(t)", .data$draw %in% !!sampled_draws ) @@ -233,7 +249,7 @@ get_plot_subpop_rt <- function(draws, geom_step( aes( x = .data$date, y = .data$pred_value, group = .data$draw, - color = .data$subpop + color = .data$subpop_name ), alpha = 0.1, linewidth = 0.2, show.legend = FALSE @@ -243,7 +259,7 @@ get_plot_subpop_rt <- function(draws, linetype = "dashed", show.legend = FALSE ) + - facet_wrap(~subpop, scales = "free") + + facet_wrap(~subpop_name, scales = "free") + geom_hline(aes(yintercept = 1), linetype = "dashed") + xlab("") + ylab("Subpopulation R(t)") + @@ -255,11 +271,13 @@ get_plot_subpop_rt <- function(draws, theme_bw() + theme( axis.text.x = element_text( - size = 8, vjust = 1, + size = 5, vjust = 1, hjust = 1, angle = 45 ), + axis.text.y = element_text(size = 5), axis.title.x = element_text(size = 12), axis.title.y = element_text(size = 12), + strip.text = element_text(size = 6), plot.title = element_text( size = 10, vjust = 0.5, hjust = 0.5 diff --git a/R/generate_simulated_data.R b/R/generate_simulated_data.R index 97eb0845..fbc4b33d 100644 --- a/R/generate_simulated_data.R +++ b/R/generate_simulated_data.R @@ -191,13 +191,6 @@ generate_simulated_data <- function(r_in_weeks = # nolint assert_ww_site_pops_lt_total(pop_size, ww_pop_sites) assert_site_lab_indices_align(site, lab) - - # Spatial bool check, if no spatial use ind. corr. func. with n+1 sites. - if (!use_spatial_corr) { - corr_function <- independence_corr_func - corr_fun_params <- list(num_sites = n_sites + 1) - } - # Expose the stan functions into the global environment-------------------- model <- cmdstanr::cmdstan_model( stan_file = system.file( @@ -220,6 +213,12 @@ generate_simulated_data <- function(r_in_weeks = # nolint ) spatial_fxns$expose_functions(global = TRUE) + # Spatial bool check, if no spatial use ind. corr. func. with n+1 sites. + if (!use_spatial_corr) { + corr_function <- independence_corr_func_r + corr_fun_params <- list(num_sites = n_sites + 1) + } + # Get other variables needed for forward simulation ------------------------ params <- get_params(input_params_path) # load in parameters diff --git a/R/get_draws.R b/R/get_draws.R new file mode 100644 index 00000000..5af9da11 --- /dev/null +++ b/R/get_draws.R @@ -0,0 +1,490 @@ +#' @title Postprocess to generate a draws dataframe +#' +#' @description +#' This function takes in the two input data sources, the CmdStan fit object, +#' and the 3 relevant mappings from stan indices to the real data, in order +#' to generate a dataframe containing the posterior draws of the counts (e.g. +#' hospital admissions), the wastewater concentration values, the "global" R(t), +#' and the "local" R(t) estimates + the critical metadata in the data. +#' This funtion has a default method that takes the two sets of input data, +#' the last of stan arguments, and the CmdStan fitting object, as well as an S3 +#' method for objects of class 'wwinference_fit' +#' +#' +#' @param x Either a dataframe of wastewater observations, or an object of +#' class wwinference_fit +#' @param ... additional arguments +#' @param what Character vector. Specifies the variables to extract from the +#' draws. It could be any from `"all"` `"predicted_counts"`, `"predicted_ww"`, +#' `"global_rt"`, or `"subpop_rt"`. When `what = "all"` (the default), +#' the function will extract all four variables. +#' @return A tibble containing the full set of posterior draws of the +#' estimated, nowcasted, and forecasted: counts, site-level wastewater +#' concentrations, "global"(e.g. state) R(t) estimate, and the "local" (site + +#' the one auxiliary subpopulation) R(t) estimates. In the instance where there +#' are observations, the data will be joined to each draw of the predicted +#' observation to facilitate plotting. +#' @export +get_draws <- function(x, ..., what = "all") { + UseMethod("get_draws") +} + +#' @rdname get_draws +#' @details +#' The function `get_draws_df()` has been deprecated in favor of `get_draws()`. +#' +#' @export +get_draws_df <- function(x, ...) { + .Deprecated("get_draws") +} + +#' S3 method for extracting posterior draws alongside data for a +#' wwinference_fit object +#' +#' This method overloads the generic `get_draws` function specifically +#' for objects of type 'wwinference_fit'. +#' +#' @rdname get_draws +#' @export +get_draws.wwinference_fit <- function(x, ..., what = "all") { + get_draws.data.frame( + x = x$raw_input_data$input_ww_data, + count_data = x$raw_input_data$input_count_data, + date_time_spine = x$raw_input_data$date_time_spine, + site_subpop_spine = x$raw_input_data$site_subpop_spine, + lab_site_subpop_spine = x$raw_input_data$lab_site_subpop_spine, + stan_data_list = x$stan_data_list, + fit_obj = x$fit, + what = what + ) +} + +#' @export +#' @rdname get_draws +get_draws.default <- function(x, ..., what = "all") { + stop( + "No method defined for get_draws for object of class(es) ", + paste(class(x), collapse = ", "), + ". Use directly on a wwinference_fit object or a", + "dataframe of wastewater observations.", + call. = FALSE + ) +} + +#' Vector of valid values for `what` in `get_draws` +#' @noRd +get_draws_what_ok <- c( + "all", "predicted_counts", "predicted_ww", "global_rt", "subpop_rt" +) + +#' @rdname get_draws +#' @param count_data A dataframe of the preprocessed daily count data (e.g. +#' hospital admissions) from the "global" population +#' @param date_time_spine tibble mapping dates to time in days +#' @param site_subpop_spine tibble mapping sites to subpopulations +#' @param lab_site_subpop_spine tibble mapping lab-sites to subpopulations +#' @param stan_data_list A list containing all the data passed to stan for +#' fitting the model +#' @param fit_obj a CmdStan object that is the output of fitting the model to +#' `x` and `count_data` +#' @export +get_draws.data.frame <- function(x, + count_data, + date_time_spine, + site_subpop_spine, + lab_site_subpop_spine, + stan_data_list, + fit_obj, + ..., + what = "all") { + # Checking we are getting all + what_ok <- get_draws_what_ok + + if (any(!what %in% what_ok)) { + idx <- which(!what %in% what_ok) + stop( + "The following invalid values were passed to `what`: ", + paste(what[idx], collapse = ", "), ". Valid values include: ", + paste(what_ok, collapse = ", "), "." + ) + } + + what_ok <- logical(length(what_ok)) + names(what_ok) <- get_draws_what_ok + what_ok[] <- FALSE + if ("all" %in% what) { + if (length(what) > 1) { + warning("Ignoring other values of `what` when `all` is present.") + } + what_ok[] <- TRUE + } else { + what_ok[what] <- TRUE + } + if (stan_data_list$include_ww == 0) { + if (any(c("predicted_ww", "subpop_rt") %in% what)) { + cli::cli_abort(c( + "Predicted wastewater concentrations and subpopulation R(t)s", + " can not be returned because the model wasn't fit to ", + " site-level wastewater data" + )) + } + what_ok["predicted_ww"] <- FALSE + what_ok["subpop_rt"] <- FALSE + if (what == "all") { + warning(c( + "Model wasn't fit to wastewater data. ", + "Predicted wastewater concentrations and subpopulation R(t)s", + "\nestimates will not be returned in the ", + "`wwinference_fit_draws` object" + )) + } + } + + draws <- fit_obj$result$draws() + + + count_draws <- if (what_ok["predicted_counts"]) { + draws |> # predicted_counts + tidybayes::spread_draws(!!str2lang("pred_hosp[t]")) |> + dplyr::rename("pred_value" = "pred_hosp") |> + dplyr::mutate( + draw = .data$`.draw`, + ) |> + dplyr::select("t", "pred_value", "draw") |> + dplyr::left_join(date_time_spine, by = "t") |> + dplyr::left_join( + count_data |> + dplyr::select(-"t"), + by = "date" + ) |> + dplyr::ungroup() |> + dplyr::rename("observed_value" = "count") |> + dplyr::select( + "date", + "draw", + "observed_value", + "pred_value", + "total_pop" + ) + } else { + NULL + } + + + ww_draws <- if (what_ok["predicted_ww"]) { + draws |> + tidybayes::spread_draws(!!str2lang("pred_ww[lab_site_index, t]")) |> + dplyr::rename("pred_value" = "pred_ww") |> + dplyr::mutate( + draw = .data$`.draw` + ) |> + dplyr::select("lab_site_index", "t", "pred_value", "draw") |> + dplyr::left_join(date_time_spine, by = "t") |> + dplyr::left_join(lab_site_subpop_spine, by = "lab_site_index") |> + dplyr::left_join( + x |> dplyr::distinct( + .data$log_genome_copies_per_ml, + .data$log_lod, + .data$date, + .data$below_lod, + .data$lab_site_index + ), + by = c( + "lab_site_index", "date" + ) + ) |> + dplyr::ungroup() |> + dplyr::mutate( + observed_value = .data$log_genome_copies_per_ml, + ) |> + dplyr::select( + "date", + "lab_site_name", + "pred_value", + "draw", + "observed_value", + "subpop_name", + "subpop_pop", + "site", + "lab", + "log_lod", + "below_lod", + "lab_site_index" + ) + } else { + NULL + } + + global_rt_draws <- if (what_ok["global_rt"]) { + draws |> + tidybayes::spread_draws(!!str2lang("rt[t]")) |> + dplyr::rename("pred_value" = "rt") |> + dplyr::mutate( + draw = .data$`.draw` + ) |> + dplyr::select("t", "pred_value", "draw") |> + dplyr::left_join(date_time_spine, by = "t") |> + dplyr::left_join( + count_data |> + dplyr::select(-"t"), + by = "date" + ) |> + dplyr::ungroup() |> + dplyr::select( + "date", + "pred_value", + "draw", + "total_pop" + ) + } else { + NULL + } + + subpop_rt_draws <- if (what_ok["subpop_rt"]) { + draws |> + tidybayes::spread_draws(!!str2lang("r_subpop_t[subpop_index, t]")) |> + dplyr::rename("pred_value" = "r_subpop_t") |> + dplyr::mutate( + draw = .data$`.draw`, + pred_value = .data$pred_value + ) |> + dplyr::select("subpop_index", "t", "pred_value", "draw") |> + dplyr::left_join(date_time_spine, by = "t") |> + dplyr::left_join(site_subpop_spine, by = "subpop_index") |> + dplyr::ungroup() |> + dplyr::select( + "date", + "pred_value", + "draw", + "subpop_name", + "subpop_pop", + ) + } else { + NULL + } + + return( + new_wwinference_fit_draws( + predicted_counts = count_draws, + predicted_ww = ww_draws, + global_rt = global_rt_draws, + subpop_rt = subpop_rt_draws + ) + ) +} + +#' @export +print.wwinference_fit_draws <- function(x, ...) { + # Computing the draws + draws <- c( + ifelse(length(x$predicted_counts) > 0, max(x$predicted_counts$draw), 0), + ifelse(length(x$predicted_ww) > 0, max(x$predicted_ww$draw), 0), + ifelse(length(x$global_rt) > 0, max(x$global_rt$draw), 0), + ifelse(length(x$subpop_rt) > 0, max(x$subpop_rt$draw), 0) + ) |> max() + + # This calculates the number of time points in each dataframe + timepoints <- c( + ifelse( + length(x$predicted_counts) > 0, + diff(range(x$predicted_counts$date)) + 1, 0 + ), + ifelse( + length(x$predicted_ww) > 0, + diff(range(x$predicted_ww$date)) + 1, 0 + ), + ifelse( + length(x$global_rt) > 0, + diff(range(x$global_rt$date)) + 1, 0 + ), + ifelse( + length(x$subpop_rt) > 0, + diff(range(x$subpop_rt$date)) + 1, 0 + ) + ) |> max() + + cat( + sprintf( + "Draws from the model featuring %i draws across %i days ", + draws, timepoints + ), + "in the following datasets:\n" + ) # Same draws and timepoints + + if (length(x$predicted_counts)) { + cat( + sprintf( + " - `$predicted_counts` with %i rows\n", + nrow(x$predicted_counts) + ) + ) + } + + if (length(x$predicted_ww)) { + cat( + sprintf( + " - `$predicted_ww` with %i rows across %i sites.\n", + nrow(x$predicted_ww), + length(unique(x$predicted_ww$lab_site_index)) + ) + ) + } + if (length(x$global_rt)) { + cat( + sprintf( + " - `$global_rt` with %i rows\n", + nrow(x$global_rt) + ) + ) + } + if (length(x$subpop_rt)) { + cat( + sprintf( + " - `$subpop_rt` with %i rows across %i subpopulations\n", + nrow(x$subpop_rt), + length(unique(x$subpop_rt$subpop_name)) + ) + ) + } + + cat("You can use $ to access the datasets.\n") + + invisible(x) +} + +#' Constructor for the new_wwinference_fit_draws +#' +#' Constructor running some checks on the contents of the data. +#' +#' @param predicted_counts Predicted counts +#' @param predicted_ww Predicted ww concentration +#' @param global_rt Global Rt() +#' @param site_level_r Site-level Rt()s +#' @noRd +new_wwinference_fit_draws <- function( + predicted_counts, + predicted_ww, + global_rt, + subpop_rt) { + # Checking colnames: Must match all exactly + predicted_counts_colnames <- c( + "date", "pred_value", "observed_value", "draw", "total_pop" + ) + if (length(predicted_counts)) { + checkmate::assert_names( + colnames(predicted_counts), + permutation.of = predicted_counts_colnames + ) + } + + predicted_ww_colnames <- c( + "below_lod", + "date", + "draw", + "lab", + "lab_site_name", + "log_lod", + "observed_value", + "pred_value", + "site", + "subpop_pop", + "subpop_name", + "lab_site_index" + ) + if (length(predicted_ww)) { + checkmate::assert_names( + colnames(predicted_ww), + permutation.of = predicted_ww_colnames + ) + } + + global_rt_colnames <- c( + "date", "draw", "pred_value", "total_pop" + ) + if (length(global_rt)) { + checkmate::assert_names( + colnames(global_rt), + permutation.of = global_rt_colnames + ) + } + + subpop_rt_colnames <- c( + "date", + "draw", + "pred_value", + "subpop_pop", + "subpop_name" + ) + if (length(subpop_rt)) { + checkmate::assert_names( + colnames(subpop_rt), + permutation.of = subpop_rt_colnames + ) + } + + structure( + list( + predicted_counts = predicted_counts, + predicted_ww = predicted_ww, + global_rt = global_rt, + subpop_rt = subpop_rt + ), + class = "wwinference_fit_draws" + ) +} + +#' @export +#' @rdname get_draws +#' @param x An object of class `get_draws`. +#' @param y Ignored in the the case of `plot`. +#' @details +#' The plot method for `wwinference_fit_draws` is a wrapper of +#' `get_plot_forecasted_counts`, `get_plot_ww_conc`, `get_plot_global_rt`, +#' and `get_plot_subpop_rt`. Depending on the value of `what`, the function +#' will call the appropriate method. +#' +plot.wwinference_fit_draws <- function(x, y = NULL, what, ...) { + if (length(what) != 1L) { + stop( + "The value provided to `what` must be a length one character vector. ", + "Currently, it is of length ", length(what), "." + ) + } + + which_what_are_ok <- setdiff(get_draws_what_ok, "all") + + if (!what %in% which_what_are_ok) { + stop( + sprintf( + paste0( + "The value provided to what (%s) is invalid. ", + "Valid values include \"%s\"." + ), + paste(what, collapse = ", "), + paste(which_what_are_ok, collapse = "\", \"") + ) + ) + } + + if (what == "predicted_counts") { + get_plot_forecasted_counts( + draws = x$predicted_counts, + ... + ) + } else if (what == "predicted_ww") { + get_plot_ww_conc( + x$predicted_ww, + ... + ) + } else if (what == "global_rt") { + get_plot_global_rt( + x$global_rt, + ... + ) + } else if (what == "subpop_rt") { + get_plot_subpop_rt( + x$subpop_rt, + ... + ) + } +} diff --git a/R/get_draws_df.R b/R/get_draws_df.R deleted file mode 100644 index 60d2eebe..00000000 --- a/R/get_draws_df.R +++ /dev/null @@ -1,224 +0,0 @@ -#' @title Postprocess to generate a draws dataframe -#' -#' @description -#' This function takes in the two input data sources, the CmdStan fit object, -#' and the 3 relevant mappings from stan indices to the real data, in order -#' to generate a dataframe containing the posterior draws of the counts (e.g. -#' hospital admissions), the wastewater concentration values, the "global" R(t), -#' and the "local" R(t) estimates + the critical metadata in the data. -#' This funtion has a default method that takes the two sets of input data, -#' the last of stan arguments, and the CmdStan fitting object, as well as an S3 -#' method for objects of class 'wwinference_fit' -#' -#' -#' @param x Either a dataframe of wastewater observations, or an object of -#' class wwinference_fit -#' @param count_data A dataframe of the preprocessed daily count data (e.g. -#' hospital admissions) from the "global" population -#' @param stan_data_list A list containing all the data passed to stan for -#' fitting the model -#' @param fit_obj a CmdStan object that is the output of fitting the model to -#' `x` and `count_data` -#' @param ... additional arguments -#' @return A tibble containing the full set of posterior draws of the -#' estimated, nowcasted, and forecasted: counts, site-level wastewater -#' concentrations, "global"(e.g. state) R(t) estimate, and the "local" (site + -#' the one auxiliary subpopulation) R(t) estimates. In the instance where there -#' are observations, the data will be joined to each draw of the predicted -#' observation to facilitate plotting. -#' @export -get_draws_df <- function(x, ...) { - UseMethod("get_draws_df") -} - -#' S3 method for extracting posterior draws alongside data for a -#' wwinference_fit object -#' -#' This method overloads the generic get_draws_df function specifically -#' for objects of type 'wwinference_fit'. -#' -#' @rdname get_draws_df -#' @export -get_draws_df.wwinference_fit <- function(x, ...) { - get_draws_df.data.frame( - x = x$raw_input_data$input_ww_data, - count_data = x$raw_input_data$input_count_data, - stan_data_list = x$stan_data_list, - fit_obj = x$fit - ) -} - -#' @export -#' @rdname get_draws_df -get_draws_df.default <- function(x, ...) { - stop( - "No method defined for get_draws_df for object of class(es) ", - paste(class(x), collapse = ", "), - ". Use directly on a wwinference_fit object or a", - "dataframe of wastewater observations.", - call. = FALSE - ) -} - -#' @rdname get_draws_df -#' @export -get_draws_df.data.frame <- function(x, - count_data, - stan_data_list, - fit_obj, - ...) { - draws <- fit_obj$result$draws() - - # Get the necessary mappings needed to join draws to data - date_time_spine <- tibble::tibble( - date = seq( - from = min(count_data$date), - to = min(count_data$date) + stan_data_list$ot + stan_data_list$ht, - by = "days" - ) - ) |> - dplyr::mutate(t = row_number()) - # Lab-site index to corresponding lab, site, and site population size - lab_site_spine <- x |> - dplyr::distinct(.data$site, .data$lab, .data$lab_site_index, .data$site_pop) - # Site index to corresponding site and subpopulation size - subpop_spine <- x |> - dplyr::distinct(.data$site, .data$site_index, .data$site_pop) |> - dplyr::mutate(site = as.factor(.data$site)) |> - dplyr::bind_rows(tibble::tibble( - site = "remainder of pop", - site_index = max(x$site_index) + 1, - site_pop = stan_data_list$subpop_size[ - length(unique(stan_data_list$subpop_size)) - ] - )) - - - count_draws <- draws |> - tidybayes::spread_draws(!!str2lang("pred_hosp[t]")) |> - dplyr::rename("pred_value" = "pred_hosp") |> - dplyr::mutate( - draw = .data$`.draw`, - name = "predicted counts" - ) |> - dplyr::select("name", "t", "pred_value", "draw") |> - dplyr::left_join(date_time_spine, by = "t") |> - dplyr::left_join( - count_data |> - dplyr::select(-"t"), - by = "date" - ) |> - dplyr::ungroup() |> - dplyr::rename("observed_value" = "count") |> - dplyr::mutate( - observation_type = "count", - type_of_quantity = "global", - lab_site_index = NA, - subpop = NA, - lab = NA, - site_pop = NA, - below_lod = NA, - log_lod = NA, - flag_as_ww_outlier = NA, - exclude = NA - ) |> - dplyr::select(-"t") - - ww_draws <- draws |> - tidybayes::spread_draws(!!str2lang("pred_ww[lab_site_index, t]")) |> - dplyr::rename("pred_value" = "pred_ww") |> - dplyr::mutate( - draw = .data$`.draw`, - name = "predicted wastewater", - ) |> - dplyr::select("name", "lab_site_index", "t", "pred_value", "draw") |> - dplyr::left_join(date_time_spine, by = "t") |> - dplyr::left_join(lab_site_spine, by = "lab_site_index") |> - dplyr::left_join( - x |> - dplyr::select(-"t"), - by = c( - "lab_site_index", "date", - "lab", "site", "site_pop" - ) - ) |> - dplyr::ungroup() |> - dplyr::mutate(observed_value = .data$log_genome_copies_per_ml) |> - dplyr::mutate( - observation_type = "log genome copies per mL", - type_of_quantity = "local", - total_pop = NA, - subpop = glue::glue("Site: {site}") - ) |> - dplyr::select(colnames(count_draws), -"t") - - global_rt_draws <- draws |> - tidybayes::spread_draws(!!str2lang("rt[t]")) |> - dplyr::rename("pred_value" = "rt") |> - dplyr::mutate( - draw = .data$`.draw`, - name = "global R(t)" - ) |> - dplyr::select("name", "t", "pred_value", "draw") |> - dplyr::left_join(date_time_spine, by = "t") |> - dplyr::left_join( - count_data |> - dplyr::select(-"t"), - by = "date" - ) |> - dplyr::ungroup() |> - dplyr::rename("observed_value" = "count") |> - dplyr::mutate( - observed_value = NA, - observation_type = "latent variable", - type_of_quantity = "global", - lab_site_index = NA, - subpop = NA, - lab = NA, - site_pop = NA, - below_lod = NA, - log_lod = NA, - flag_as_ww_outlier = NA, - exclude = NA - ) |> - dplyr::select(-"t") - - site_level_rt_draws <- draws |> - tidybayes::spread_draws(!!str2lang("r_site_t[site_index, t]")) |> - dplyr::rename("pred_value" = "r_site_t") |> - dplyr::mutate( - draw = .data$`.draw`, - name = "subpopulation R(t)", - pred_value = .data$pred_value - ) |> - dplyr::select("name", "site_index", "t", "pred_value", "draw") |> - dplyr::left_join(date_time_spine, by = "t") |> - dplyr::left_join(subpop_spine, by = "site_index") |> - dplyr::ungroup() |> - dplyr::mutate( - observed_value = NA, - lab_site_index = NA, - lab = NA, - below_lod = NA, - log_lod = NA, - flag_as_ww_outlier = NA, - exclude = NA, - observation_type = "latent variable", - type_of_quantity = "local", - total_pop = NA, - subpop = ifelse(.data$site != "remainder of pop", - glue::glue("Site: {site}"), "remainder of pop" - ) - ) |> - dplyr::select(colnames(count_draws), -"t") - - all_draws_df <- dplyr::bind_rows( - count_draws, - ww_draws, - global_rt_draws, - site_level_rt_draws - ) - - - return(all_draws_df) -} diff --git a/R/get_stan_data.R b/R/get_stan_data.R index a52757af..75d84950 100644 --- a/R/get_stan_data.R +++ b/R/get_stan_data.R @@ -14,7 +14,7 @@ get_input_count_data_for_stan <- function(preprocessed_count_data, input_count_data_filtered <- preprocessed_count_data |> dplyr::filter( - .data$date > last_count_data_date - lubridate::days(!!calibration_time) + .data$date > !!last_count_data_date - lubridate::days(!!calibration_time) ) count_data <- add_time_indexing(input_count_data_filtered) @@ -42,68 +42,229 @@ get_input_ww_data_for_stan <- function(preprocessed_ww_data, last_count_data_date, calibration_time) { # Test to see if ww_data_present - ww_data_present <- nrow(preprocessed_ww_data) != 0 + ww_data_present <- !is.null(preprocessed_ww_data) if (ww_data_present == FALSE) { message("No wastewater data present") - } - - if (all(sum(preprocessed_ww_data$flag_as_ww_outlier) > sum( - preprocessed_ww_data$exclude - ))) { - cli::cli_warn( - c( - "Wastewater data being passed to the model has outliers flagged,", - "but not all have been indicated for exclusion from model fit" + ww_data <- NULL + } else { + if (all(sum(preprocessed_ww_data$flag_as_ww_outlier) > sum( + preprocessed_ww_data$exclude + ))) { + cli::cli_warn( + c( + "Wastewater data being passed to the model has outliers flagged,", + "but not all have been indicated for exclusion from model fit" + ) ) + } + + # Test for presence of needed column names + assert_req_ww_cols_present(preprocessed_ww_data, + conc_col_name = "log_genome_copies_per_ml", + lod_col_name = "log_lod" ) + + # Filter out wastewater outliers, and remove extra wastewater + # data. Arrange data for indexing. This is what will be returned. + ww_data <- preprocessed_ww_data |> + dplyr::filter( + .data$exclude != 1, + .data$date > !!last_count_data_date - + lubridate::days(!!calibration_time) + ) |> + dplyr::arrange(.data$date, .data$lab_site_index) } + return(ww_data) +} - # Test for presence of needed column names - assert_req_ww_cols_present(preprocessed_ww_data, - conc_col_name = "log_genome_copies_per_ml", - lod_col_name = "log_lod" +#' Get date time spine to map to model output +#' +#' @param forecast_date a character string in ISO8601 format (YYYY-MM-DD) +#' indicating the date that the forecast is to be made. +#' @param input_count_data a dataframe of the count data to be passed +#' directly to stan, , must have the following columns: date, count, total_pop +#' @param last_count_data_date string indicating the date of the last observed +#' count data point in 1SO8601 format (YYYY-MM-DD) +#' @param calibration_time integer indicating the number of days to calibrate +#' the model for, default is `90` +#' @param forecast_horizon integer indicating the number of days, including the +#' forecast date, to produce forecasts for, default is `28` +#' +#' +#' @return a tibble containing an integer for time mapped to the corresponding +#' date, for the entire calibration and forecast period +#' @export +#' +get_date_time_spine <- function(forecast_date, + input_count_data, + last_count_data_date, + calibration_time, + forecast_horizon) { + nowcast_time <- as.integer( + lubridate::ymd(forecast_date) - last_count_data_date ) + date_time_spine <- tibble::tibble( + date = seq( + from = min(input_count_data$date), + to = min(input_count_data$date) + + calibration_time + + nowcast_time + + forecast_horizon, + by = "days" + ) + ) |> + dplyr::mutate(t = row_number()) + return(date_time_spine) +} - # Filter out wastewater outliers, and remove extra wastewater - # data. Arrange data for indexing. This is what will be returned. - ww_data <- preprocessed_ww_data |> - dplyr::filter( - .data$exclude != 1, - .data$date > !!last_count_data_date - - lubridate::days(!!calibration_time) - ) |> - dplyr::arrange(.data$date, .data$lab_site_index) +#' Get mapping from lab-site to site +#' +#' @param input_ww_data a dataframe of the wastewater data to be passed +#' directly to stan, must have the following columns: date, site, lab, +#' genome_copies_per_ml, site_pop, below_lod, and exclude +#' +#' @return a dataframe mapping the unique combinations of sites and labs +#' to their indices in the model and the population of the site in that +#' observation unit (lab_site) +#' @export +#' +get_lab_site_site_spine <- function(input_ww_data) { + ww_data_present <- !is.null(input_ww_data) + + if (ww_data_present) { + lab_site_site_spine <- + input_ww_data |> + dplyr::select( + "lab_site_index", "site_index", + "site", "lab", "site_pop" + ) |> + dplyr::arrange(.data$lab_site_index) |> + dplyr::distinct() |> + dplyr::mutate( + "lab_site_name" = glue::glue( + "Site: {site}, Lab: {lab}" + ) + ) + } else { + lab_site_site_spine <- tibble::tibble() + } - ww_data_sizes <- get_ww_data_sizes( - ww_data, - lod_col_name = "below_lod" - ) - ww_indices <- get_ww_data_indices( - ww_data, - first_count_data_date, - owt = ww_data_sizes$owt, - lod_col_name = "below_lod" - ) + return(lab_site_site_spine) +} + +#' Get site to subpopulation map +#' +#' @param input_ww_data a dataframe of the wastewater data to be passed +#' directly to stan, must have the following columns: date, site, lab, +#' genome_copies_per_ml, site_pop, below_lod, and exclude +#' @param input_count_data a dataframe of the count data to be passed +#' directly to stan, , must have the following columns: date, count, total_pop +#' +#' @return a dataframe mapping the sites to the corresponding subpopulation and +#' subpopulation index, plus the population in each subpopulation. Imposes +#' the logic to add a subpopulation if the total population is greater than +#' the sum of the site populations in the input wastewater data +#' @export +#' +get_site_subpop_spine <- function(input_ww_data, + input_count_data) { + ww_data_present <- !is.null(input_ww_data) - ww_data <- ww_data |> - dplyr::mutate( - t = ww_indices$ww_sampled_times + total_pop <- input_count_data |> + dplyr::distinct(.data$total_pop) |> + dplyr::pull() + + if (ww_data_present) { + add_auxiliary_subpop <- ifelse( + total_pop > sum(unique(input_ww_data$site_pop)), + TRUE, + FALSE ) + site_indices <- input_ww_data |> + dplyr::select("site_index", "site", "site_pop") |> + dplyr::distinct() |> + dplyr::arrange(.data$site_index) - return(ww_data) + if (add_auxiliary_subpop) { + aux_subpop <- tibble::tibble( + "site_index" = NA, + "site" = NA, + "site_pop" = total_pop - sum(site_indices$site_pop) + ) + } else { + aux_subpop <- tibble::tibble() + } + + site_subpop_spine <- aux_subpop |> + dplyr::bind_rows(site_indices) |> + dplyr::mutate( + subpop_index = dplyr::row_number() + ) |> + dplyr::mutate( + subpop_name = ifelse(!is.na(.data$site), + glue::glue("Site: {site}"), + "remainder of population" + ) + ) |> + dplyr::rename( + "subpop_pop" = "site_pop" + ) + } else { + site_subpop_spine <- tibble::tibble( + "site_index" = NA, + "site" = NA, + "subpop_pop" = total_pop, + "subpop_index" = 1, + "subpop_name" = "total population" + ) + } + + return(site_subpop_spine) +} + +#' Get lab-site subpopulation spine +#' +#' @param lab_site_site_spine tibble mapping lab-sites to sites +#' @param site_subpop_spine tibble mapping sites to subpopulations +#' +#' @return a tibble mapping lab-sites to subpopulations +#' @export +#' +get_lab_site_subpop_spine <- function(lab_site_site_spine, + site_subpop_spine) { + ww_data_present <- !nrow(lab_site_site_spine) == 0 + # Get lab_site to subpop spine + if (ww_data_present) { + lab_site_subpop_spine <- lab_site_site_spine |> + dplyr::left_join(site_subpop_spine, by = c("site_index", "site")) + } else { + lab_site_subpop_spine <- tibble::tibble( + subpop_index = numeric() + ) + } + + return(lab_site_subpop_spine) } + #' Get stan data for ww + hosp model #' -#' @param input_count_data a dataframe of the count data to be passed -#' directly to stan, , must have the following columns: date, count, total_pop -#' @param input_ww_data a dataframe of the wastewater data to be passed -#' directly to stan, must have the following columns: date, site, lab, -#' genome_copies_per_ml, site_pop, below_lod, and exclude + +#' @param input_count_data tibble with the input count data needed for stan +#' @param input_ww_data tibble with the input wastewater data and indices +#' needed for stan +#' @param date_time_spine tibble mapping dates to time in days +#' @param lab_site_site_spine tibble mapping lab-sites to sites +#' @param site_subpop_spine tibble mapping sites to subpopulations +#' @param lab_site_subpop_spine tibble mapping lab-sites to subpopulations +#' @param last_count_data_date string indicating the date of the last data +#' point in the count dataset in ISO8601 convention e.g. YYYY-MM-DD +#' @param first_count_data_date string indicating the date of the first data +#' point in the count dataset in ISO8601 convention e.g. YYYY-MM-DD #' @param forecast_date string indicating the forecast date in ISO8601 #' convention e.g. YYYY-MM-DD #' @param forecast_horizon integer indicating the number of days to make a @@ -197,9 +358,33 @@ get_input_ww_data_for_stan <- function(preprocessed_ww_data, #' last_count_data_date, #' calibration_time #' ) +#' date_time_spine <- get_date_time_spine( +#' forecast_date = forecast_date, +#' input_count_data = input_count_data_for_stan, +#' last_count_data_date = last_count_data_date, +#' forecast_horizon = forecast_horizon, +#' calibration_time = calibration_time +#' ) +#' lab_site_site_spine <- get_lab_site_site_spine( +#' input_ww_data = input_ww_data_for_stan +#' ) +#' site_subpop_spine <- get_site_subpop_spine( +#' input_ww_data = input_ww_data_for_stan, +#' input_count_data = input_count_data_for_stan +#' ) +#' lab_site_subpop_spine <- get_lab_site_subpop_spine( +#' lab_site_site_spine = lab_site_site_spine, +#' site_subpop_spine +#' ) #' stan_data_list <- get_stan_data( #' input_count_data_for_stan, #' input_ww_data_for_stan, +#' date_time_spine, +#' lab_site_site_spine, +#' site_subpop_spine, +#' lab_site_subpop_spine, +#' last_count_data_date, +#' first_count_data_date, #' forecast_date, #' forecast_horizon, #' calibration_time, @@ -213,6 +398,12 @@ get_input_ww_data_for_stan <- function(preprocessed_ww_data, #' ) get_stan_data <- function(input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -224,15 +415,10 @@ get_stan_data <- function(input_count_data, compute_likelihood = 1, dist_matrix, corr_structure_switch) { - # Assign parameter names - par_names <- colnames(params) - for (i in seq_along(par_names)) { - assign(par_names[i], as.double(params[i])) - } - # Get the last date that there were observations of the epidemiological # indicator (aka cases or hospital admissions counts) last_count_data_date <- max(input_count_data$date, na.rm = TRUE) + # Validate input pmfs---------------------------------------------------- validate_pmf(generation_interval, calibration_time, input_count_data, @@ -249,14 +435,36 @@ get_stan_data <- function(input_count_data, arg = "infection to count delay" ) - validate_both_datasets( - input_count_data, - input_ww_data, - calibration_time = calibration_time, - forecast_date = forecast_date + # Check that count data doesn't extend beyond forecast date + assert_no_dates_after_max( + date_vector = input_count_data$date, + max_date = forecast_date, + arg_dates = "wastewater data", + arg_max_date = "forecast date" ) + # Validate both datasets if both are used---------------------------------- + if (include_ww == 1) { + validate_both_datasets( + input_count_data = input_count_data, + input_ww_data = input_ww_data, + date_time_spine = date_time_spine, + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine, + lab_site_subpop_spine = lab_site_subpop_spine, + calibration_time = calibration_time, + forecast_date = forecast_date + ) + # Check that ww data doesn't extend beyond forecast date + assert_no_dates_after_max( + date_vector = input_ww_data$date, + max_date = forecast_date, + arg_dates = "wastewater data", + arg_max_date = "forecast date" + ) + } + # Define some global variables from the input data----------------------- # Get the total pop, coming from the larger population generating the # count data pop <- input_count_data |> @@ -271,59 +479,32 @@ get_stan_data <- function(input_count_data, ) ) - last_count_data_date <- max(input_count_data$date, na.rm = TRUE) - first_count_data_date <- min(input_count_data$date, na.rm = TRUE) - # Returns a list of the vectors of lod values, the site population sizes in - # order of the site index, a vector of observations of the log of - # the genome copies per ml - ww_values <- get_ww_values( - input_ww_data - ) + # Get wastewater inputs------------------------------------------------- # Returns a list with the numbers of elements needed for the stan model ww_data_sizes <- get_ww_data_sizes( input_ww_data ) - # Returns the vectors of indices you need to map latent variables to - # observations - ww_indices <- get_ww_data_indices( - input_ww_data |> dplyr::select(-"t"), - first_count_data_date, - owt = ww_data_sizes$owt + + ww_vals <- get_ww_indices_and_values( + input_ww_data = input_ww_data, + date_time_spine = date_time_spine, + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine, + lab_site_subpop_spine = lab_site_subpop_spine ) - # Ensure that both datasets have overlap with one another, are sufficient - # in length for the specified calibration time, and have proper time indexing stopifnot( "Wastewater sampled times not equal to length of input ww data" = - length(ww_indices$ww_sampled_times) == ww_data_sizes$owt + length(ww_vals$ww_sampled_times) == ww_data_sizes$owt ) message( "Prop of population size covered by wastewater: ", - sum(ww_values$pop_ww) / pop + sum(unique(input_ww_data$site_pop)) / pop ) - if (sum(ww_values$pop_ww) / pop > 1) { - cli::cli_warn(c( - "The sum of the wastewater site catchment area populations:", - "is greater than the global population. While the model supports this", - "we advise checking your input data to ensure it is specified correctly." - )) - } - - # Logic to determine the number of subpopulations to estimate R(t) for: - # First determine if we need to add an additional subpopulation - add_auxiliary_site <- ifelse(pop >= sum(ww_values$pop_ww), TRUE, FALSE) - # Then get the number of subpopulations, the population to normalize by - # (sum of the subpopulations), and the vector of sizes of each subpopulation - subpop_data <- get_subpop_data(add_auxiliary_site, - state_pop = pop, - pop_ww = ww_values$pop_ww, - n_ww_sites = ww_data_sizes$n_ww_sites - ) - - # Get the sizes of all the elements + # Get count data inputs----------------------------------------------- count_data_sizes <- get_count_data_sizes( input_count_data = input_count_data, forecast_date = forecast_date, @@ -371,7 +552,6 @@ get_stan_data <- function(input_count_data, ) inf_to_count_delay_max <- length(inf_to_count_delay) - # If user does / doesn't want spatial comps. # We can add an extra step here for when spatial desired and dist_matrix # not given. @@ -379,8 +559,8 @@ get_stan_data <- function(input_count_data, # This dist_matrix will not be used, only needed for stan data specs. dist_matrix <- matrix( 0, - nrow = subpop_data$n_subpops - 1, - ncol = subpop_data$n_subpops - 1 + nrow = length(ww_vals$subpop_pops) - 1, + ncol = length(ww_vals$subpop_pops) - 1 ) } if (!(corr_structure_switch %in% c(0, 1, 2))) { @@ -399,7 +579,7 @@ get_stan_data <- function(input_count_data, inf_to_hosp = inf_to_count_delay, mwpd = params$ml_of_ww_per_person_day, ot = count_data_sizes$ot, - n_subpops = subpop_data$n_subpops, + n_subpops = length(ww_vals$subpop_pops), n_ww_sites = ww_data_sizes$n_ww_sites, n_ww_lab_sites = ww_data_sizes$n_ww_lab_sites, owt = ww_data_sizes$owt, @@ -415,17 +595,19 @@ get_stan_data <- function(input_count_data, generation_interval = generation_interval, ts = 1:params$gt_max, state_pop = pop, - subpop_size = subpop_data$subpop_size, - norm_pop = subpop_data$norm_pop, - ww_sampled_times = ww_indices$ww_sampled_times, + subpop_size = ww_vals$subpop_pops, + norm_pop = sum(site_subpop_spine$subpop_pop), + ww_sampled_times = ww_vals$ww_sampled_times, hosp_times = count_indices$count_times, - ww_sampled_lab_sites = ww_indices$ww_sampled_lab_sites, - ww_log_lod = ww_values$ww_lod, - ww_censored = ww_indices$ww_censored, - ww_uncensored = ww_indices$ww_uncensored, + ww_sampled_subpops = ww_vals$ww_sampled_subpops, + lab_site_to_subpop_map = lab_site_subpop_spine$subpop_index, + ww_sampled_lab_sites = ww_vals$ww_sampled_lab_sites, + ww_log_lod = ww_vals$ww_lod, + ww_censored = ww_vals$ww_censored, + ww_uncensored = ww_vals$ww_uncensored, hosp = count_values$counts, day_of_week = count_values$day_of_week, - log_conc = ww_values$log_conc, + log_conc = ww_vals$log_conc, compute_likelihood = compute_likelihood, include_ww = include_ww, include_hosp = 1, @@ -435,8 +617,8 @@ get_stan_data <- function(input_count_data, viral_shedding_pars = viral_shedding_pars, # tpeak, viral peak, dur_shed autoreg_rt_a = params$autoreg_rt_a, autoreg_rt_b = params$autoreg_rt_b, - autoreg_rt_site_a = params$autoreg_rt_site_a, - autoreg_rt_site_b = params$autoreg_rt_site_b, + autoreg_rt_subpop_a = params$autoreg_rt_subpop_a, + autoreg_rt_subpop_b = params$autoreg_rt_subpop_b, autoreg_p_hosp_a = params$autoreg_p_hosp_a, autoreg_p_hosp_b = params$autoreg_p_hosp_b, inv_sqrt_phi_prior_mean = params$inv_sqrt_phi_prior_mean, @@ -479,8 +661,6 @@ get_stan_data <- function(input_count_data, sigma_rt_prior = params$sigma_rt_prior, log_phi_g_prior_mean = params$log_phi_g_prior_mean, log_phi_g_prior_sd = params$log_phi_g_prior_sd, - ww_sampled_sites = ww_indices$ww_sampled_sites, - lab_site_to_site_map = ww_indices$lab_site_to_site_map, log_phi_mu_prior = params$log_phi_mu_prior, log_phi_sd_prior = params$log_phi_sd_prior, l = params$l, @@ -489,10 +669,18 @@ get_stan_data <- function(input_count_data, log_scaling_factor_mu_prior = params$log_scaling_factor_mu_prior, log_scaling_factor_sd_prior = params$log_scaling_factor_sd_prior, dist_matrix = dist_matrix, - corr_structure_switch = corr_structure_switch + corr_structure_switch = corr_structure_switch, + offset_ref_log_r_t_prior_mean = params$offset_ref_log_r_t_prior_mean, + offset_ref_log_r_t_prior_sd = params$offset_ref_log_r_t_prior_sd, + offset_ref_logit_i_first_obs_prior_mean = + params$offset_ref_logit_i_first_obs_prior_mean, + offset_ref_logit_i_first_obs_prior_sd = + params$offset_ref_logit_i_first_obs_prior_sd, + offset_ref_initial_exp_growth_rate_prior_mean = + params$offset_ref_initial_exp_growth_rate_prior_mean, + offset_ref_initial_exp_growth_rate_prior_sd = + params$offset_ref_initial_exp_growth_rate_prior_sd ) - - return(stan_data_list) } @@ -557,191 +745,91 @@ get_ww_data_sizes <- function(ww_data, return(data_sizes) } -#' Get wastewater data indices +#' Get wastewater indices and values for stan #' -#' @param ww_data Input wastewater dataframe containing one row -#' per observation, with outliers already removed -#' @param first_count_data_date The earliest day with an observation in the ' -#' count dataset, in ISO8601 format YYYY-MM-DD -#' @param owt number of wastewater observations -#' @param lod_col_name A string representing the name of the -#' column in the input_ww_data that provides a 0 if the data point is not above -#' the LOD and a 1 if the data is below the LOD, default value is `below_LOD` +#' @param input_ww_data tibble with the input wastewater data and indices +#' needed for stan +#' @param date_time_spine tibble mapping dates to time in days +#' @param lab_site_site_spine tibble mapping lab-sites to sites +#' @param site_subpop_spine tibble mapping sites to subpopulations +#' @param lab_site_subpop_spine tibble mapping lab-sites to subpopulations #' -#' @return A list containing the necessary vectors of indices that -#' the stan model requires: -#' ww_censored: the vector of time points that the wastewater observations are -#' censored (below the LOD) in order of the date and the site index -#' ww_uncensored: the vector of time points that the wastewater observations are -#' uncensored (above the LOD) in order of the date and the site index -#' ww_sampled_times: the vector of time points that the wastewater observations -#' are passed in in log_conc in order of the date and the site index -#' ww_sampled_sites: the vector of sites that correspond to the observations -#' passed in in log_conc in order of the date and the site index -#' ww_sampled_lab_sites: the vector of unique combinations of site and labs -#' that correspond to the observations passed in in log_conc in order of the -#' date and the site index -#' lab_site_to_site_map: the vector of sites that correspond to each lab-site +#' @return a list of the vectors needed for stan #' @export -get_ww_data_indices <- function(ww_data, - first_count_data_date, - owt, - lod_col_name = "below_lod") { - # Vector of indices along the list of wastewater concentrations that - # correspond to censored observations - ww_data_present <- nrow(ww_data) != 0 +get_ww_indices_and_values <- function(input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine) { + ww_data_present <- !is.null(input_ww_data) + + # Get a vector of population sizes for each subpop + subpop_pops <- site_subpop_spine |> + dplyr::select("subpop_index", "subpop_pop") |> + dplyr::arrange(.data$subpop_index, "desc") |> + dplyr::pull(.data$subpop_pop) if (isTRUE(ww_data_present)) { - ww_data_with_index <- ww_data |> - dplyr::mutate(ind_rel_to_sampled_times = dplyr::row_number()) - ww_censored <- ww_data_with_index |> - dplyr::filter(.data[[lod_col_name]] == 1) |> + ww_data_joined <- input_ww_data |> + dplyr::left_join(date_time_spine, by = "date") |> + dplyr::left_join(site_subpop_spine, by = c("site_index", "site")) |> + dplyr::mutate("ind_rel_to_sampled_times" = dplyr::row_number()) + + owt <- nrow(ww_data_joined) + + # Get the vector of log LOD values corresponding to each observation + ww_lod <- ww_data_joined |> + dplyr::pull("log_lod") + + # Get the vector of log wastewater concentrations + log_conc <- ww_data_joined |> + dplyr::pull("log_genome_copies_per_ml") + + # Get censored and uncensored indices, which are relative to the vector + # of sampled times (e.g. 1:owt) + ww_censored <- ww_data_joined |> + dplyr::filter(.data$below_lod == 1) |> dplyr::pull(.data$ind_rel_to_sampled_times) - ww_uncensored <- ww_data_with_index |> - dplyr::filter(.data[[lod_col_name]] == 0) |> + ww_uncensored <- ww_data_joined |> + dplyr::filter(.data$below_lod == 0) |> dplyr::pull(.data$ind_rel_to_sampled_times) stopifnot( "Length of censored vectors incorrect" = length(ww_censored) + length(ww_uncensored) == owt ) + ww_sampled_times <- ww_data_joined |> dplyr::pull("t") + ww_sampled_subpops <- ww_data_joined |> dplyr::pull("subpop_index") + lab_site_to_subpop_spine <- lab_site_site_spine |> + dplyr::left_join(site_subpop_spine, by = "site_index") |> + pull("subpop_index") + ww_sampled_lab_sites <- ww_data_joined |> dplyr::pull("lab_site_index") - - # Need to get the times of wastewater sampling, starting at the first - # day of hospital admissions data - ww_date_df <- data.frame( - date = seq( - from = first_count_data_date, - to = max(ww_data$date), - by = "days" - ), - t = 1:(as.integer(max(ww_data$date) - first_count_data_date) + 1) - ) - - # Left join the data mapped to time to the wastewater data - spine_ww <- ww_data |> - dplyr::left_join(ww_date_df, by = "date") - - # Pull just the vector of times of wastewater observations - ww_sampled_times <- spine_ww |> - dplyr::pull(t) - - # Pull just the indexes of the sites that correspond to the vector of - # sampled times - ww_sampled_sites <- ww_data$site_index - - # Pull just the indexes of the lab-sites that correspond to the vector of - # sampled times - ww_sampled_lab_sites <- ww_data$lab_site_index - - # Need a vector of indices indicating the site for each lab-site - lab_site_to_site_map <- ww_data |> - dplyr::select("lab_site_index", "site_index") |> - dplyr::arrange(.data$lab_site_index, "desc") |> - dplyr::distinct() |> - dplyr::pull(.data$site_index) - - ww_data_indices <- list( + ww_values <- list( + ww_lod = ww_lod, + subpop_pops = subpop_pops, + log_conc = log_conc, ww_censored = ww_censored, ww_uncensored = ww_uncensored, ww_sampled_times = ww_sampled_times, - ww_sampled_sites = ww_sampled_sites, - ww_sampled_lab_sites = ww_sampled_lab_sites, - lab_site_to_site_map = lab_site_to_site_map - ) - } else { - ww_data_indices <- list( - ww_censored = c(), - ww_uncensored = c(), - ww_sampled_times = c(), - ww_sampled_sites = c(), - ww_sampled_lab_sites = c(), - lab_site_to_site_map = c() - ) - } - - - return(ww_data_indices) -} - -#' Get wastewater data values -#' -#' @param ww_data Input wastewater dataframe containing one row -#' per observation, with outliers already removed -#' @param ww_measurement_col_name A string representing the name of the column -#' in the input_ww_data that indicates the wastewater measurement value in -#' log scale, default is `log_genome_copies_per_ml` -#' @param ww_lod_value_col_name A string representing the name of the column -#' in the ww_data that indicates the value of the LOD in log scale, -#' default is `log_lod` -#' @param ww_site_pop_col_name A string representing the name of the column in -#' the ww_data that indicates the number of people represented by that -#' wastewater catchment, default is `site_pop` -#' @param one_pop_per_site a boolean variable indicating if there should only -#' be on catchment area population per site, default is `TRUE` because this is -#' what the stan model expects -#' @param padding_value an smaller numeric value to add to the the -#' concentration measurements to ensure that log transformation will produce -#' real numbers, default value is `1e-8` -#' -#' @return A list containing the necessary vectors of values that -#' the stan model requires: -#' ww_lod: a vector of the LODs of the corresponding wastewater measurement -#' pop_ww: a vector of the population sizes of the wastewater catchment areas -#' in order of the sites by site_index -#' log_conc: a vector of the log of the wastewater concentration observation -#' @export -get_ww_values <- function(ww_data, - ww_measurement_col_name = "log_genome_copies_per_ml", - ww_lod_value_col_name = "log_lod", - ww_site_pop_col_name = "site_pop", - one_pop_per_site = TRUE, - padding_value = 1e-8) { - ww_data_present <- nrow(ww_data) != 0 - - if (isTRUE(ww_data_present)) { - # Get the vector of log LOD values corresponding to each observation - ww_lod <- ww_data |> - dplyr::pull({{ ww_lod_value_col_name }}) - - # Get a vector of population sizes - if (isTRUE(one_pop_per_site)) { - # Want one population per site during the model calibration period, - # so just take the average across the populations reported for each - # observation - pop_ww <- ww_data |> - dplyr::select("site_index", {{ ww_site_pop_col_name }}) |> - dplyr::group_by(.data$site_index) |> - dplyr::summarise(pop_avg = mean(.data[[ww_site_pop_col_name]])) |> - dplyr::arrange(.data$site_index, "desc") |> - dplyr::pull(.data$pop_avg) - } else { - # Want a vector of length of the number of observations, corresponding to - # the population at that time - pop_ww <- ww_data |> - dplyr::pull({{ ww_site_pop_col_name }}) - } - - - # Get the vector of log wastewater concentrations - log_conc <- ww_data |> - dplyr::pull({{ ww_measurement_col_name }}) - ww_values <- list( - ww_lod = ww_lod, - pop_ww = pop_ww, - log_conc = log_conc + ww_sampled_subpops = ww_sampled_subpops, + ww_sampled_lab_sites = ww_sampled_lab_sites ) } else { ww_values <- list( - ww_lod = c(), - pop_ww = c(), - log_conc = c() + ww_lod = numeric(), + subpop_pops = subpop_pops, + log_conc = numeric(), + ww_censored = numeric(), + ww_uncensored = numeric(), + ww_sampled_times = numeric(), + ww_sampled_subpops = numeric(), + ww_sampled_lab_sites = numeric() ) } - - return(ww_values) } + #' Add time indexing to count data #' #' @param input_count_data data frame with dates and counts, @@ -773,46 +861,6 @@ add_time_indexing <- function(input_count_data) { return(count_data) } -#' Get subpopulation data -#' -#' @param add_auxiliary_site Boolean indicating whether to add another -#' subpopulation in addition to the wastewater sites to estimate R(t) of -#' @param state_pop The state population size -#' @param pop_ww The population size in each of the wastewater sites -#' @param n_ww_sites The number of wastewater sites -#' -#' @return A list containing the necessary integers and vectors that stan -#' needs to estiamte infection dynamics for each subpopulation -#' @export -#' -#' @examples subpop_data <- get_subpop_data(TRUE, 100000, c(1000, 500), 2) -get_subpop_data <- function(add_auxiliary_site, - state_pop, - pop_ww, - n_ww_sites) { - if (add_auxiliary_site) { - # In most cases, wastewater catchment coverage < entire state. - # So here we add a subpopulation that represents the population not - # covered by wastewater surveillance - norm_pop <- state_pop - n_subpops <- n_ww_sites + 1 - subpop_size <- c(pop_ww, state_pop - sum(pop_ww)) - } else { - message("Sum of wastewater catchment areas is greater than state pop") - norm_pop <- sum(pop_ww) - # If sum catchment areas > state pop, - # use sum of catchment area pop to normalize - n_subpops <- n_ww_sites # Only divide the state into n_site subpops - subpop_size <- pop_ww - } - - subpop_data <- list( - norm_pop = norm_pop, - n_subpops = n_subpops, - subpop_size = subpop_size - ) - return(subpop_data) -} #' Get count data integer sizes for stan #' diff --git a/R/initialization.R b/R/initialization.R index f5f17431..f146056c 100644 --- a/R/initialization.R +++ b/R/initialization.R @@ -30,10 +30,25 @@ get_inits_for_one_chain <- function(stan_data, stdev = 0.01) { init_list <- list( w = stats::rnorm(n_weeks - 1, 0, stdev), + offset_ref_log_r_t = stats::rnorm( + stan_data$n_subpops > 1, + stan_data$offset_ref_log_r_t_prior_mean, + stdev + ), + offset_ref_logit_i_first_obs = stats::rnorm( + stan_data$n_subpops > 1, + stan_data$offset_ref_logit_i_first_obs_prior_mean, + stdev + ), + offset_ref_initial_exp_growth_rate = stats::rnorm( + stan_data$n_subpops > 1, + stan_data$offset_ref_initial_exp_growth_rate_prior_mean, + stdev + ), eta_sd = abs(stats::rnorm(1, 0, stdev)), - eta_i_first_obs = abs(stats::rnorm(n_subpops, 0, stdev)), + eta_i_first_obs = abs(stats::rnorm((n_subpops - 1), 0, stdev)), sigma_i_first_obs = abs(stats::rnorm(1, 0, stdev)), - eta_initial_exp_growth_rate = abs(stats::rnorm(n_subpops, 0, stdev)), + eta_initial_exp_growth_rate = abs(stats::rnorm((n_subpops - 1), 0, stdev)), sigma_initial_exp_growth_rate = abs(stats::rnorm(1, 0, stdev)), autoreg_rt = abs(stats::rnorm( 1, @@ -41,20 +56,12 @@ get_inits_for_one_chain <- function(stan_data, stdev = 0.01) { (stan_data$autoreg_rt_a + stan_data$autoreg_rt_b), 0.05 )), - log_r_mu_intercept = stats::rnorm( + log_r_t_first_obs = stats::rnorm( 1, convert_to_logmean(1, stdev), convert_to_logsd(1, stdev) ), - error_site = matrix( - stats::rnorm(n_subpops * n_weeks, - mean = 0, - sd = stdev - ), - n_subpops, - n_weeks - ), - autoreg_rt_site = abs(stats::rnorm(1, 0.5, 0.05)), + autoreg_rt_subpop = abs(stats::rnorm(1, 0.5, 0.05)), autoreg_p_hosp = abs(stats::rnorm(1, 1 / 100, 0.001)), sigma_rt = abs(stats::rnorm(1, 0, stdev)), i_first_obs_over_n = @@ -99,23 +106,32 @@ get_inits_for_one_chain <- function(stan_data, stdev = 0.01) { log_sigma_generalized = stats::rnorm(1, log(0.05^(n_subpops - 1)), 0.5), log_phi = stats::rnorm(1, log(0.25), 0.1), log_scaling_factor = stats::rnorm(1, log(1), 0.1), - non_cent_spatial_dev_ns_mat = matrix( - stats::rnorm((n_subpops - 1) * n_weeks, - mean = 0, - sd = stdev - ), - (n_subpops - 1), - n_weeks - ), norm_vec_aux_site = stats::rnorm(n_weeks, 0, stdev), # Initialize the cholesky decomposition matrix if inferring # unstructured correlation matrix L_Omega = as.matrix(diag(2)) ) - if (stan_data$corr_structure_switch == 2) { init_list$L_Omega <- diag((n_subpops - 1)) } + if (stan_data$n_subpops > 1) { + init_list$error_rt_subpop <- matrix( + stats::rnorm((n_subpops - 1) * n_weeks, + mean = 0, + sd = stdev + ), + (n_subpops - 1), + n_weeks + ) + init_list$non_cent_spatial_dev_ns_mat <- matrix( + stats::rnorm((n_subpops - 1) * n_weeks, + mean = 0, + sd = stdev + ), + (n_subpops - 1), + n_weeks + ) + } return(init_list) } diff --git a/R/model_component_fwd_sim.R b/R/model_component_fwd_sim.R index c0d032f8..b5449646 100644 --- a/R/model_component_fwd_sim.R +++ b/R/model_component_fwd_sim.R @@ -339,7 +339,8 @@ downsample_ww_obs <- function(log_conc_lab_site, #' site combination #' #' @return a tidy dataframe containing observed wastewater concentrations -#' in log genome copies per mL for each site and lab at each time point +#' in log estimated genome copies per mL for each site and lab at each time +#' point format_ww_data <- function(log_obs_conc_lab_site, ot, ht, diff --git a/R/model_diagnostics.R b/R/model_diagnostics.R index 1f71b366..489f3a43 100644 --- a/R/model_diagnostics.R +++ b/R/model_diagnostics.R @@ -102,5 +102,17 @@ get_model_diagnostic_flags.default <- function(x, #' @family diagnostics #' @export parameter_diagnostics <- function(ww_fit, ...) { + ww_fit$fit$result$summary() +} + +#' Method for printing the CmdStan summary diagnostics for +#' wwinference_fit_object +#' +#' @param ww_fit An object of class wwinference_fit +#' @param ... additional arguments +#' +#' @family diagnostics +#' @export +summary_diagnostics <- function(ww_fit, ...) { ww_fit$fit$result$diagnostic_summary(quiet = TRUE) } diff --git a/R/preprocessing.R b/R/preprocessing.R index db622e95..29cbbcc6 100644 --- a/R/preprocessing.R +++ b/R/preprocessing.R @@ -46,12 +46,15 @@ preprocess_ww_data <- function(ww_data, lod_col_name = lod_col_name ) - + # Order by site population so the first site index corresponds largest pop + ww_data_ordered <- ww_data |> + dplyr::arrange(desc(.data$site_pop)) # Add some columns - ww_data_add_cols <- ww_data |> + ww_data_add_cols <- ww_data_ordered |> + dplyr::ungroup() |> dplyr::left_join( - ww_data |> + ww_data_ordered |> dplyr::distinct(.data$lab, .data$site) |> dplyr::mutate( lab_site_index = dplyr::row_number() @@ -59,7 +62,7 @@ preprocess_ww_data <- function(ww_data, by = c("lab", "site") ) |> dplyr::left_join( - ww_data |> + ww_data_ordered |> dplyr::distinct(.data$site) |> dplyr::mutate(site_index = dplyr::row_number()), by = "site" @@ -112,7 +115,7 @@ preprocess_count_data <- function(count_data, count_col_name = "daily_hosp_admits", pop_size_col_name = "state_pop") { # This checks that we have all the right column names - check_req_count_cols_present( + assert_req_count_cols_present( count_data, count_col_name, pop_size_col_name @@ -183,7 +186,7 @@ flag_ww_outliers <- function(ww_data, .data$n_data_points > !!threshold_n_dps ) |> dplyr::group_by(.data$lab_site_index) |> - dplyr::arrange(.data$date, "desc") |> + dplyr::arrange(desc(.data$date)) |> dplyr::mutate( log_conc = .data[[conc_col_name]], prev_log_conc = dplyr::lag(.data$log_conc, 1), diff --git a/R/sysdata.rda b/R/sysdata.rda deleted file mode 100644 index b7914563..00000000 Binary files a/R/sysdata.rda and /dev/null differ diff --git a/R/validate.R b/R/validate.R index 482e8c62..b9b5633b 100644 --- a/R/validate.R +++ b/R/validate.R @@ -31,6 +31,20 @@ validate_ww_conc_data <- function(ww_data, ) checkmate::assert_vector(ww_conc) + # Check for repeated wastewater observations within a site and lab + assert_cols_det_unique_row( + df = ww_data, + unique_key_columns = c("date", "site", "lab"), + arg = "lab-site-day", + add_err_msg = + c( + "Package expects either at most one ", + "wastewater observation per a given a site, lab, ", + "and sample collection date. Got date(s) with ", + "more than one observation for a given site and lab." + ) + ) + ww_lod <- ww_data |> dplyr::pull({ lod_col_name }) @@ -66,7 +80,6 @@ validate_ww_conc_data <- function(ww_data, assert_non_missingness(site_pops, arg, call) assert_elements_non_neg(site_pops, arg, call) - invisible() } @@ -146,18 +159,31 @@ validate_count_data <- function(count_data, #' been filtered and is ready to be passed into stan #' @param input_ww_data tibble containing the input wastewater data that has #' been filtered and is ready to be passed into stan +#' @param date_time_spine tibble mapping dates to time in days +#' @param lab_site_site_spine tibble mapping lab-sites to sites +#' @param site_subpop_spine tibble mapping sites to subpopulations +#' @param lab_site_subpop_spine tibble mapping lab-sites to subpopulations #' @param calibration_time integer indicating the calibration time #' @param forecast_date IS08 formatted date indicating the forecast date #' #' @return NULL, invisibly validate_both_datasets <- function(input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, calibration_time, forecast_date) { # check that you have sufficient count data for the calibration time assert_sufficient_days_of_data( input_count_data$date, - calibration_time + data_name = "input count data", + calibration_time, + add_err_msg = c( + "Check that the count data supplied has sufficient values", + " before the forecast date" + ) ) assert_elements_non_neg(calibration_time, @@ -190,12 +216,47 @@ validate_both_datasets <- function(input_count_data, ) # check that the time and date indices of both datasets line up + ww_data_sizes <- get_ww_data_sizes( + input_ww_data + ) + + ww_vals <- get_ww_indices_and_values( + input_ww_data = input_ww_data, + date_time_spine = date_time_spine, + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine, + lab_site_subpop_spine = lab_site_subpop_spine + ) + + input_ww_data_w_t <- input_ww_data |> + dplyr::mutate(t = ww_vals$ww_sampled_times) + assert_equivalent_indexing( input_count_data, - input_ww_data, + input_ww_data_w_t, arg1 = "count data", arg2 = "ww data" ) + + # Warn if sum(site pops) are greater than total pop. + # The package can handle this, but warn users that they may have an input + # data error. + sum_site_pops <- input_ww_data |> + dplyr::distinct(.data$site_pop) |> + sum() + total_pop <- input_count_data |> + dplyr::distinct(.data$total_pop) + if (sum_site_pops > total_pop) { + cli::cli_warn(c( + "The sum of the populations in the wastewater catchment areas is ", + "larger than the total population. While the model supports this ", + "we advise checking your input data to ensure it is specified ", + "correctly and to make sure that populations represented by the ", + "wastewater catchment areas are not overlapping (e.g. if both ", + " the larger wastewater treatment plant and the upstream manhole ", + "are included)." + )) + } invisible() } @@ -208,15 +269,18 @@ validate_both_datasets <- function(input_count_data, #' @param calibration_time integer indicating the calibration time #' @param count_data tibble containing the input count data ready to be passed #' to stan +#' @param tolerance numeric indicating the allowable difference between the +#' sum of the pmf and 1, default is `1e-6` #' @param arg name of the argument supplying the object #' @param call The calling environment to be reflected in the error message #' @return NULL, invisibly validate_pmf <- function(pmf, calibration_time, count_data, + tolerance = 1e-6, arg = "x", call = rlang::caller_env()) { - if (!all.equal(sum(pmf), 1)) { + if (!isTRUE(all.equal(sum(pmf), 1, tolerance = 1e-6))) { cli::cli_abort( c( "{.arg {arg}} does not sum to 1." diff --git a/R/wwinference.R b/R/wwinference.R index 1b7d97e9..734705fc 100644 --- a/R/wwinference.R +++ b/R/wwinference.R @@ -16,7 +16,7 @@ #' @param ww_data A dataframe containing the pre-processed, site-level #' wastewater concentration data for a model run. The dataframe must contain #' the following columns: `date`, `site`, `lab`, `log_genome_copies_per_ml`, -#' `lab_site_index`, `log_lod`, `below_lod`, `site_pop` `exclude` +#' `lab_site_index`, `log_lod`, `below_lod`, `site_pop` `exclude`. #' @param count_data A dataframe containing the pre-procssed, "global" (e.g. #' state) daily count data, pertaining to the number of events that are being #' counted on that day, e.g. number of daily cases or daily hospital admissions. @@ -31,13 +31,15 @@ #' `get_model_spec()`. The default here pertains to the `forecast_date` in the #' example data provided by the package, but this should be specified by the #' user based on the date they are producing a forecast -#' @param fit_opts The fit options, which in this case default to the -#' MCMC parameters as defined using `get_mcmc_options()`. This includes -#' the following arguments, which are passed to -#' [`$sample()`][cmdstanr::model-method-sample]: -#' the number of chains, the number of warmup -#' and sampling iterations, the maximum tree depth, the average acceptance -#' probability, and the stan PRNG seed +#' @param fit_opts MCMC fitting options, as a list of keys and values. +#' These are passed as keyword arguments to +#' [`compiled_model$sample()`][cmdstanr::model-method-sample]. +#' Where no option is specified, [wwinference()] will fall back first on a +#' package-specific default value given by [get_mcmc_options()], if one exists. +#' If no package-specific default exists, [wwinference()] will fall back on +#' the default value defined in [`$sample()`][cmdstanr::model-method-sample]. +#' See the documentation for [`$sample()`][cmdstanr::model-method-sample] for +#' details on available options. #' @param generate_initial_values Boolean indicating whether or not to specify #' the initialization of the sampler, default is `TRUE`, meaning that #' initialization lists will be generated and passed as the `init` argument @@ -132,24 +134,27 @@ #' calibration_time <- 90 #' forecast_horizon <- 28 #' include_ww <- 1 -#' ww_fit <- wwinference(input_ww_data, -#' input_count_data, +#' +#' ww_fit <- wwinference( +#' ww_data = input_ww_data, +#' count_data = input_count_data, +#' forecast_date = forecast_date, +#' calibration_time = calibration_time, +#' forecast_horizon = forecast_horizon, #' model_spec = get_model_spec( -#' forecast_date = forecast_date, -#' calibration_time = calibration_time, -#' forecast_horizon = forecast_horizon, #' generation_interval = generation_interval, -#' inf_to_count_delay = inf_to_coutn_delay, +#' inf_to_count_delay = inf_to_count_delay, #' infection_feedback_pmf = infection_feedback_pmf, #' params = params #' ), -#' fit_opts = get_mcmc_options( +#' fit_opts = list( #' iter_warmup = 250, #' iter_sampling = 250, -#' n_chains = 2 +#' chains = 2 #' ) #' ) #' } +#' #' @rdname wwinference #' @aliases wwinference_fit wwinference <- function(ww_data, @@ -158,36 +163,80 @@ wwinference <- function(ww_data, calibration_time = 90, forecast_horizon = 28, model_spec = get_model_spec(), - fit_opts = get_mcmc_options(), + fit_opts = list(), generate_initial_values = TRUE, initial_values_seed = NULL, compiled_model = compile_model(), dist_matrix = NULL, corr_structure_switch = 0) { + include_ww <- as.integer(model_spec$include_ww) + if (is.null(forecast_date)) { cli::cli_abort( "The user must specify a forecast date" ) } - # Check that data is compatible with specifications - assert_no_dates_after_max(ww_data$date, forecast_date) + # If there is no wastewater data, set include_ww to 0 + if (is.null(ww_data) || nrow(ww_data) == 0) { + cli::cli_warn( + c( + "No wastewater data was passed to the model.", + "The model will default to fitting only to the count data" + ) + ) + include_ww <- 0 + } + # If include_ww == 0, we will specify an empty dataset + if (include_ww == 0) { + ww_data <- NULL + } + + + fit_opts_use <- get_mcmc_options() # get defaults + # this overwrites defaults with all and only the values the user sets in + # `fit_opts` + fit_opts_use[names(fit_opts)] <- fit_opts + + # Check that the fit options passed to wwinference are valid cmdstanr::sample + # arguments + checkmate::assert_names(names(fit_opts), + subset.of = formalArgs(compiled_model$sample) + ) + + + ## Check that data is compatible with specifications + if (!is.null(ww_data)) { + assert_no_dates_after_max(ww_data$date, forecast_date) + } assert_no_dates_after_max(count_data$date, forecast_date) + # Get the input count data that will get passed directly to stan input_count_data <- get_input_count_data_for_stan( count_data, calibration_time ) last_count_data_date <- max(input_count_data$date, na.rm = TRUE) first_count_data_date <- min(input_count_data$date, na.rm = TRUE) + + # Get the input wastewater data that will be passed directly to stan input_ww_data <- get_input_ww_data_for_stan( ww_data, first_count_data_date, last_count_data_date, calibration_time ) - raw_input_data <- list( + # Get the table that maps 1-indexed time to dates + date_time_spine <- get_date_time_spine( + forecast_date = forecast_date, input_count_data = input_count_data, + last_count_data_date = last_count_data_date, + forecast_horizon = forecast_horizon, + calibration_time = calibration_time + ) + + # Get lab_site_site_spine + lab_site_site_spine <- get_lab_site_site_spine( input_ww_data = input_ww_data ) @@ -198,11 +247,37 @@ wwinference <- function(ww_data, ) } + # Get site to subpop spine + site_subpop_spine <- get_site_subpop_spine( + input_ww_data = input_ww_data, + input_count_data = input_count_data + ) + + lab_site_subpop_spine <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine + ) + + + raw_input_data <- list( + input_count_data = input_count_data, + input_ww_data = input_ww_data, + date_time_spine = date_time_spine, + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine, + lab_site_subpop_spine = lab_site_subpop_spine + ) # If checks pass, create stan data object stan_data_list <- get_stan_data( input_count_data = input_count_data, input_ww_data = input_ww_data, + date_time_spine = date_time_spine, + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine, + lab_site_subpop_spine = lab_site_subpop_spine, + last_count_data_date = last_count_data_date, + first_count_data_date = first_count_data_date, forecast_date = forecast_date, calibration_time = calibration_time, forecast_horizon = forecast_horizon, @@ -210,7 +285,7 @@ wwinference <- function(ww_data, inf_to_count_delay = model_spec$inf_to_count_delay, infection_feedback_pmf = model_spec$infection_feedback_pmf, params = model_spec$params, - include_ww = as.numeric(model_spec$include_ww), + include_ww = include_ww, compute_likelihood = as.integer(model_spec$compute_likelihood), dist_matrix = dist_matrix, corr_structure_switch = corr_structure_switch @@ -224,7 +299,7 @@ wwinference <- function(ww_data, if (generate_initial_values) { withr::with_seed(initial_values_seed, { init_lists <- lapply( - 1:fit_opts$n_chains, + 1:fit_opts_use$chains, \(x) { get_inits_for_one_chain(stan_data_list) } @@ -240,7 +315,7 @@ wwinference <- function(ww_data, fit <- safe_fit_model( compiled_model = compiled_model, stan_data_list = stan_data_list, - fit_opts = fit_opts, + fit_opts = fit_opts_use, init_lists = init_lists ) @@ -314,7 +389,7 @@ print.wwinference_fit <- function(x, ...) { cat("wwinference_fit object\n") cat("N of WW sites :", x$stan_data_list$n_ww_sites, "\n") cat("N of unique lab-site pairs :", x$stan_data_list$n_ww_lab_sites, "\n") - cat("State population :", formatC( + cat("Total population :", formatC( x$stan_data_list$state_pop, format = "d" ), "\n") @@ -349,15 +424,18 @@ fit_model <- function(compiled_model, stan_data_list, fit_opts, init_lists) { - fit <- compiled_model$sample( - data = stan_data_list, - init = init_lists, - seed = fit_opts$seed, - iter_sampling = fit_opts$iter_sampling, - iter_warmup = fit_opts$iter_warmup, - max_treedepth = fit_opts$max_treedepth, - chains = fit_opts$n_chains, - parallel_chains = fit_opts$n_chains + args_for_stan_sampling <- + c( + list( + data = stan_data_list, + init = init_lists + ), + fit_opts + ) + + fit <- do.call( + compiled_model$sample, + args_for_stan_sampling ) return(fit) @@ -368,42 +446,45 @@ fit_model <- function(compiled_model, #' #' @description #' This function returns a list of MCMC settings to pass to the -#' `cmdstanr::sample()` function to fit the model. The default settings are -#' specified for production-level runs, consider adjusting to optimize -#' for speed while iterating. +#' [`$sample()`][cmdstanr::model-method-sample] function to fit the model. +#' The default settings are specified for production-level runs. +#' All input arguments to [`$sample()`][cmdstanr::model-method-sample] +#' are configurable by the user. See +#' [`$sample()`][cmdstanr::model-method-sample] documentation +#' for details of the available arguments. #' #' #' @param iter_warmup integer indicating the number of warm-up iterations, -#' default is `750` +#' default is `750`. #' @param iter_sampling integer indicating the number of sampling iterations, -#' default is `500` -#' @param n_chains integer indicating the number of MCMC chains to run, default -#' is `4` -#' @param seed set of integers indicating the random seed of the stan sampler, -#' default is NULL +#' default is `500`. +#' @param seed integer, A seed for the (P)RNG to pass to CmdStan. In the case +#' of multi-chain sampling the single seed will automatically be augmented by +#' the the run (chain) ID so that each chain uses a different seed. +#' Default is `NULL`. +#' @param chains integer indicating the number of MCMC chains to run, default +#' is `4`. #' @param adapt_delta float between 0 and 1 indicating the average acceptance -#' probability, default is `0.95` +#' probability, default is `0.95`. #' @param max_treedepth integer indicating the maximum tree depth of the -#' sampler, default is 12 +#' sampler, default is 12. #' -#' @return a list of mcmc settings with the values given by the function +#' @return A list of MCMC settings with the values given by the function. #' arguments -#' @export #' -#' @examples -#' mcmc_settings <- get_mcmc_options() +#' @export get_mcmc_options <- function( iter_warmup = 750, iter_sampling = 500, - n_chains = 4, seed = NULL, + chains = 4, adapt_delta = 0.95, max_treedepth = 12) { mcmc_settings <- list( iter_warmup = iter_warmup, iter_sampling = iter_sampling, - n_chains = n_chains, seed = seed, + chains = chains, adapt_delta = adapt_delta, max_treedepth = max_treedepth ) diff --git a/README.md b/README.md index 74d51618..6ce78bb5 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,20 @@ -# `wwinference`: joint inference and forecasting from wastewater and epidemiological indicators +# `wwinference`: joint inference and forecasting
from wastewater and epidemiological count data wwinference website > [!CAUTION] -> This project is a work-in-progress. Despite this project's early stage, all development is in public as part of the Center for Forecasting and Outbreak Analytics' goals around open development. Questions and suggestions are welcome through GitHub issues or a PR. +> This package is still in development. +> Note the package is still flagged as in development, though the authors plan on using it for production work in the coming weeks. +> All development is in public as part of the Center for Forecasting and Outbreak Analytics' goals around open development. +> Questions and suggestions are welcome through GitHub issues or a PR. > ## Overview -This project is an in-development R package, `{wwinference}` that estimates latent incident infections from wastewater concentration data and data on epidemiological indicators, with an initial assumed structure that the wastewater concentration data comes from subsets of the population contributing to the "global" epidemiological indicator data, such as hospital admissions. +This project is an in-development R package, `{wwinference}` that estimates latent incident infections from wastewater concentration data and data on epidemiological count data, with an initial assumed structure that the wastewater concentration data comes from subsets of the population contributing to the "global" epidemiological count data, such as hospital admissions. In brief, our model builds upon [EpiNow2](https://github.com/epiforecasts/EpiNow2/tree/main), a widely used [R](https://www.r-project.org/) and [Stan](https://mc-stan.org/) package for Bayesian epidemiological inference. -We modify EpiNow2 to add model for the observed viral RNA concentration in wastewater, adding hierarchical structure to link the subpopulations represented by the osberved wastewater concentrations in each wastewater catchment area. -See our Model Definition page for a mathematical description of the generative model, and the Getting Stated vignette to see an example of how to run the inference model on simulated data. +We modify EpiNow2 to add a model for the observed viral RNA concentration in wastewater, adding hierarchical structure to link the subpopulations represented by the observed wastewater concentrations in each wastewater catchment area. -The intention is for {wwinference} to provide a user-friendly R-package interface for running forecasting models that use wastewater concentrations combined with other more traditional epidemiological signals such as cases or hospital admissions. It aims to be a re-implementation of the modeling components contained in the [wastewater-informed-covid-forecasting](https://github.com/CDCgov/wastewater-informed-covid-forecasting) project repository, with +The intention is for {wwinference} to provide a user-friendly R-package interface for running forecasting models that use wastewater concentrations combined with other more traditional epidemiological signals such as cases or hospital admissions. +It aims to be a re-implementation of the modeling components contained in the [wastewater-informed-covid-forecasting](https://github.com/CDCgov/wastewater-informed-covid-forecasting) project repository, with an emphasis here on making it easier for users to supply their own data. We recommend reading the [model definition](model_definition.md) to learn more about how the model is structured and running the ["Getting Started" vignette](vignettes/wwinference.Rmd) for an example of how to fit the model to simulated data of COVID-19 hospital admissions and wastewater concentrations. @@ -20,13 +23,14 @@ This will help make clear the data requirements and how to structure this data t ## Project Admins - Kaitlyn Johnson (kaitejohnson) - Dylan Morris (dylanhmorris) +- George Vega Yon (gvegayon) - Sam Abbott (seabbs) - Damon Bayer (damonbayer) # Installing and running code ## Install R -To run our code, you will need a working installation of [R](https://www.r-project.org/) (version `4.3.0` or later). You can find instructions for installing R on the official [R project website](https://www.r-project.org/). +To run our code, you will need a working installation of [R](https://www.r-project.org/) (version `4.1.0` or later). You can find instructions for installing R on the official [R project website](https://www.r-project.org/). ## Install `cmdstanr` and `CmdStan` We do inference from our models using [`CmdStan`](https://mc-stan.org/users/interfaces/cmdstan) (version `2.35.0` or later) via its R interface [`cmdstanr`](https://mc-stan.org/cmdstanr/) (version `0.8.0` or later). @@ -74,6 +78,10 @@ Confirm that package installation has succeeded by running the following within library(wwinference) ``` +## Contributing to this package +We welcome and encourage contributions. Open an issue in the repository to request changes. +To contribute, fork the repository locally and open a pull request into the `main` branch. + ## Public Domain Standard Notice This repository constitutes a work of the United States Government and is not subject to domestic copyright protection under 17 USC § 105. This repository is in @@ -83,6 +91,18 @@ All contributions to this repository will be released under the CC0 dedication. submitting a pull request you are agreeing to comply with this waiver of copyright interest. +## Contributing Standard Notice +Anyone is encouraged to contribute to the repository by [forking](https://help.github.com/articles/fork-a-repo) +and submitting a pull request. (If you are new to GitHub, you might start with a +[basic tutorial](https://help.github.com/articles/set-up-git).) By contributing +to this project, you grant a world-wide, royalty-free, perpetual, irrevocable, +non-exclusive, transferable license to all users under the terms of the +[Apache Software License v2](http://www.apache.org/licenses/LICENSE-2.0.html) or +later. + +All comments, messages, pull requests, and other submissions received through +CDC including this GitHub page may be subject to applicable federal law, including but not limited to the Federal Records Act, and may be archived. Learn more at [http://www.cdc.gov/other/privacy.html](http://www.cdc.gov/other/privacy.html). + ## License Standard Notice The repository utilizes code licensed under the terms of the Apache Software License and therefore is licensed under ASL v2 or later. @@ -107,18 +127,6 @@ information. All material and community participation is covered by the and [Code of Conduct](code-of-conduct.md). For more information about CDC's privacy policy, please visit [http://www.cdc.gov/other/privacy.html](https://www.cdc.gov/other/privacy.html). -## Contributing Standard Notice -Anyone is encouraged to contribute to the repository by [forking](https://help.github.com/articles/fork-a-repo) -and submitting a pull request. (If you are new to GitHub, you might start with a -[basic tutorial](https://help.github.com/articles/set-up-git).) By contributing -to this project, you grant a world-wide, royalty-free, perpetual, irrevocable, -non-exclusive, transferable license to all users under the terms of the -[Apache Software License v2](http://www.apache.org/licenses/LICENSE-2.0.html) or -later. - -All comments, messages, pull requests, and other submissions received through -CDC including this GitHub page may be subject to applicable federal law, including but not limited to the Federal Records Act, and may be archived. Learn more at [http://www.cdc.gov/other/privacy.html](http://www.cdc.gov/other/privacy.html). - ## Records Management Standard Notice This repository is not a source of government records, but is a copy to increase collaboration and collaborative potential. All government records will be diff --git a/_pkgdown.yml b/_pkgdown.yml index 347325a6..cbc2ae70 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -1,4 +1,4 @@ url: https://cdcgov.github.io/ww-inference-model/ template: bootstrap: 5 - math-rendering: katex + math-rendering: mathjax diff --git a/data-raw/test_data.R b/data-raw/test_data.R deleted file mode 100644 index 25478349..00000000 --- a/data-raw/test_data.R +++ /dev/null @@ -1,101 +0,0 @@ -############ -# Make entirely fake stan input data via prior-predictive generated quantities -############ - -hosp_data <- wwinference::hosp_data -ww_data <- wwinference::ww_data -params <- wwinference::get_params( - fs::path_package("extdata", "example_params.toml", - package = "wwinference" - ) -) - - -# Data pre-processing -------------------------------------------------------- -ww_data_preprocessed <- wwinference::preprocess_ww_data( - ww_data, - conc_col_name = "log_genome_copies_per_ml", - lod_col_name = "log_lod" -) - -hosp_data_preprocessed <- wwinference::preprocess_count_data( - hosp_data, - count_col_name = "daily_hosp_admits", - pop_size_col_name = "state_pop" -) - -ww_data_to_fit <- wwinference::indicate_ww_exclusions( - ww_data_preprocessed, - outlier_col_name = "flag_as_ww_outlier", - remove_outliers = TRUE -) - -forecast_date <- "2023-12-06" -calibration_time <- 90 -forecast_horizon <- 28 -generation_interval <- wwinference::default_covid_gi -inf_to_hosp <- wwinference::default_covid_inf_to_hosp - -dist_matrix <- as.matrix( - dist( - data.frame( - x = c(85, 37, 36, 7), - y = c(12, 75, 75, 96) - ), - diag = TRUE, - upper = TRUE - ) -) - -# Assign infection feedback equal to the generation interval -infection_feedback_pmf <- generation_interval - -model_spec <- wwinference::get_model_spec( - generation_interval = generation_interval, - inf_to_count_delay = inf_to_hosp, - infection_feedback_pmf = infection_feedback_pmf, - params = params -) - -mcmc_options <- wwinference::get_mcmc_options( - seed = 55, - iter_warmup = 25, - iter_sampling = 25, - n_chains = 1 -) - -generate_initial_values <- TRUE - -model_test_data <- list( - ww_data = ww_data_to_fit, - count_data = hosp_data_preprocessed, - forecast_date = forecast_date, - calibration_time = calibration_time, - forecast_horizon = forecast_horizon, - dist_matrix = dist_matrix, - model_spec = model_spec, - fit_opts = mcmc_options, - generate_initial_values = generate_initial_values -) - -withr::with_seed(5, { - fit <- do.call( - wwinference::wwinference, - model_test_data - ) -}) - - -# Generate the last draw of a very short run for testing -test_fit_last_draw <- posterior::subset_draws( - fit$fit$result$draws(), - draw = 25 -) -# Save the data as internal data. Every time the model changes, will need -# to regenerate this testing data. -usethis::use_data( - model_test_data, - test_fit_last_draw, - internal = TRUE, - overwrite = TRUE -) diff --git a/data-raw/vignette_data.R b/data-raw/vignette_data.R index cf3a7869..f918250c 100644 --- a/data-raw/vignette_data.R +++ b/data-raw/vignette_data.R @@ -1,7 +1,18 @@ set.seed(1) simulated_data <- wwinference::generate_simulated_data() -hosp_data <- simulated_data$hosp_data -ww_data <- simulated_data$ww_data +hosp_data_from_sim <- simulated_data$hosp_data +ww_data_from_sim <- simulated_data$ww_data +# Add some columns and reorder sites to ensure package works as expected +# even if sites are not in order +ww_data <- ww_data_from_sim |> + dplyr::mutate( + "location" = "example state", + "site" = .data$site + 1 + ) |> + dplyr::ungroup() |> + dplyr::arrange(desc(.data$site)) +hosp_data <- hosp_data_from_sim |> + dplyr::mutate("location" = "example state") hosp_data_eval <- simulated_data$hosp_data_eval rt_site_data <- simulated_data$rt_site_data rt_global_data <- simulated_data$rt_global_data @@ -12,6 +23,7 @@ simulated_data_ind <- wwinference::generate_simulated_data( use_spatial_corr = FALSE, aux_site_bool = FALSE ) + hosp_data_ind <- simulated_data_ind$hosp_data ww_data_ind <- simulated_data_ind$ww_data hosp_data_eval_ind <- simulated_data_ind$hosp_data_eval diff --git a/data/hosp_data.rda b/data/hosp_data.rda index 96261a49..fdde74ca 100644 Binary files a/data/hosp_data.rda and b/data/hosp_data.rda differ diff --git a/data/hosp_data_eval.rda b/data/hosp_data_eval.rda index 119fd25f..705b474f 100644 Binary files a/data/hosp_data_eval.rda and b/data/hosp_data_eval.rda differ diff --git a/data/hosp_data_eval_ind.rda b/data/hosp_data_eval_ind.rda index bffbd6e3..d5ecacaf 100644 Binary files a/data/hosp_data_eval_ind.rda and b/data/hosp_data_eval_ind.rda differ diff --git a/data/hosp_data_ind.rda b/data/hosp_data_ind.rda index b2590a6a..24161d97 100644 Binary files a/data/hosp_data_ind.rda and b/data/hosp_data_ind.rda differ diff --git a/data/rt_global_data.rda b/data/rt_global_data.rda index b5227057..896ed055 100644 Binary files a/data/rt_global_data.rda and b/data/rt_global_data.rda differ diff --git a/data/rt_global_data_ind.rda b/data/rt_global_data_ind.rda index 5e5a01fe..275b2980 100644 Binary files a/data/rt_global_data_ind.rda and b/data/rt_global_data_ind.rda differ diff --git a/data/rt_site_data.rda b/data/rt_site_data.rda index bae45017..74f5fa35 100644 Binary files a/data/rt_site_data.rda and b/data/rt_site_data.rda differ diff --git a/data/rt_site_data_ind.rda b/data/rt_site_data_ind.rda index 9375f36c..2673af7e 100644 Binary files a/data/rt_site_data_ind.rda and b/data/rt_site_data_ind.rda differ diff --git a/data/ww_data.rda b/data/ww_data.rda index f3e6b925..eadf5bc5 100644 Binary files a/data/ww_data.rda and b/data/ww_data.rda differ diff --git a/data/ww_data_ind.rda b/data/ww_data_ind.rda index 766aa522..94871746 100644 Binary files a/data/ww_data_ind.rda and b/data/ww_data_ind.rda differ diff --git a/inst/extdata/example_params.toml b/inst/extdata/example_params.toml index 1b254b1c..1975fc60 100644 --- a/inst/extdata/example_params.toml +++ b/inst/extdata/example_params.toml @@ -30,10 +30,24 @@ sigma_initial_exp_growth_rate_prior_sd = 0.05 autoreg_rt_a = 2 # shape1 parameter of autoreg term on Rt trend autoreg_rt_b = 40 # shape2 parameter of autoreg on Rt trend # mean = a/(a+b) = 0.05, stdv = sqrt(a)/b = sqrt(2)/40 = 0.035 -autoreg_rt_site_a = 1 # shape1 parameter of autoreg term on difference between - # R(t) state and R(t) site -autoreg_rt_site_b = 4 # shape2 parameter of autoreg term on difference between -# R(t) state and R(t) site +autoreg_rt_subpop_a = 1 # shape1 parameter of autoreg term on difference between + # R(t) ref and R(t) subpop +autoreg_rt_subpop_b = 4 # shape2 parameter of autoreg term on difference between +# R(t) ref and R(t) subpop + +# Normal prior on fixed offset between central log scale R(t) and reference pop +offset_ref_log_r_t_prior_mean = 0 +offset_ref_log_r_t_prior_sd = 0.2 + +# Normal prior on fixed offset between central logit scale i_first_obs/n and reference pop i_first_obs/n +offset_ref_logit_i_first_obs_prior_mean = 0 +offset_ref_logit_i_first_obs_prior_sd = 0.25 + +# Normal prior on fixed offset between central initial exponential growth rate +# and reference population initial exponential growth rate +offset_ref_initial_exp_growth_rate_prior_mean = 0 +offset_ref_initial_exp_growth_rate_prior_sd = 0.025 + autoreg_p_hosp_a = 1 # shape1 parameter of autoreg term on IHR(t) trend autoreg_p_hosp_b = 100 # shape2 parameter of autoreg term on IHR(t) trend eta_sd_sd = 0.01 diff --git a/inst/stan/wwinference.stan b/inst/stan/wwinference.stan index d1b56bdd..41030f22 100644 --- a/inst/stan/wwinference.stan +++ b/inst/stan/wwinference.stan @@ -25,7 +25,7 @@ data { vector[if_l] infection_feedback_pmf; // infection feedback pmf int ot; // maximum time index for the hospital admissions (max number of days we could have observations) int oht; // number of days that we have hospital admissions observations - int n_subpops; // number of WW sites + int n_subpops; // number of modeled subpopulations int n_ww_lab_sites; // number of unique ww-lab combos int n_censored; // numer of observed WW data points that are below the LOD int n_uncensored; //number not below LOD @@ -41,15 +41,14 @@ data { vector[n_subpops] subpop_size; // the population sizes for each subpopulation real norm_pop; array[owt] int ww_sampled_times; // a list of all of the days on which WW is sampled - // will be mapped to the corresponding sites (ww_sampled_sites) + // will be mapped to the corresponding subpops (ww_sampled_subpops) array[oht] int hosp_times; // the days on which hospital admissions are observed - array[owt] int ww_sampled_sites; // vector of unique sites in order of the sampled times - array[owt] int ww_sampled_lab_sites; // vector of unique lab-site combos i - // n order of the sampled times + array[owt] int ww_sampled_subpops; // vector of unique subpops in order of the sampled times + array[owt] int ww_sampled_lab_sites; // vector mapping the subpops to lab-site combos array[n_censored] int ww_censored; // times that the WW data is below the LOD array[n_uncensored] int ww_uncensored; // time that WW data is above LOD vector[owt] ww_log_lod; // The limit of detection in that site at that time point - array[n_ww_lab_sites] int lab_site_to_site_map; // which lab sites correspond to which sites + array[n_ww_lab_sites] int lab_site_to_subpop_map; // which lab sites correspond to which subpops array[oht] int hosp; // observed hospital admissions array[ot + ht] int day_of_week; // integer vector with 1-7 corresponding to the weekday vector[owt] log_conc; // observed concentration of viral genomes in WW @@ -57,10 +56,17 @@ data { int include_ww; // 1= include wastewater data in likelihood calculation int include_hosp; // 1 = fit to hosp, 0 = only fit wastewater model vector[6] viral_shedding_pars;// tpeak, viral peak, shedding duration mean and sd + real offset_ref_log_r_t_prior_mean; + real offset_ref_log_r_t_prior_sd; + real offset_ref_logit_i_first_obs_prior_mean; + real offset_ref_logit_i_first_obs_prior_sd; + real offset_ref_initial_exp_growth_rate_prior_mean; + real offset_ref_initial_exp_growth_rate_prior_sd; + real autoreg_rt_a; real autoreg_rt_b; - real autoreg_rt_site_a; - real autoreg_rt_site_b; + real autoreg_rt_subpop_a; + real autoreg_rt_subpop_b; real autoreg_p_hosp_a; real autoreg_p_hosp_b; real inv_sqrt_phi_prior_mean; @@ -128,29 +134,51 @@ transformed data { // The parameters accepted by the model. parameters { - vector[n_weeks-1] w; // weekly random walk of state-level mean baseline R(t) (log scale) + vector[n_weeks-1] w; // Normal(0,1) noise for the weekly random + // walk on reference subpopulation log R(t) real eta_sd; - real autoreg_rt;// coefficient on AR process in R(t) - real log_r_mu_intercept; // state-level mean baseline reproduction number estimate (log) at t=0 - real sigma_rt; // magnitude of site level variation from state level - real autoreg_rt_site; + vector[n_subpops > 1 ? 1 : 0] offset_ref_log_r_t; + // offset of reference population log R(t) from central dynamic + vector[n_subpops > 1 ? 1 : 0] offset_ref_logit_i_first_obs; + // offset of reference population per capita infections + // at the time of first observation from central value + vector[n_subpops > 1 ? 1 : 0] offset_ref_initial_exp_growth_rate; + // offset of reference population initial exponential growth rate + // from central value + real autoreg_rt; // autoregressive coefficient for + // AR process on first differences in log R(t) + real log_r_t_first_obs; // central log R(t) at the time of + // the first observation + real sigma_rt; // magnitude of subpopulation level + // R(t) heterogeneity + real autoreg_rt_subpop; real autoreg_p_hosp; - real i_first_obs_over_n; // per capita - // infection incidence on the day of the first observed infection - vector[n_subpops] eta_i_first_obs; // z-score on logit scale of site - // initial per capita infection incidence relative to state value - real sigma_i_first_obs; // stdev between logit state and site initial - // per capita infection incidence - vector[n_subpops] eta_initial_exp_growth_rate; // z scores of individual site level initial exponential growth rates - real sigma_initial_exp_growth_rate; // sd of distribution of site level initial exp growth rates - real mean_initial_exp_growth_rate; // mean of distribution of site level initial exp growth rates + matrix[n_subpops-1, n_subpops > 1 ? n_weeks : 0] error_rt_subpop; + real i_first_obs_over_n; // mean per capita + // infection incidence on the day of the first observation + vector[n_subpops - 1] eta_i_first_obs; // z-score on logit scale + // of subpopulation per capita infection incidences + // on the day of the first observation + real sigma_i_first_obs; // logit scale variability + // in per capita incidence at time of first observation + real mean_initial_exp_growth_rate; // central initial exponential growth + // rate across all subpopulations + real sigma_initial_exp_growth_rate; // variability of + // subpopulation level initial exponential growth rates + vector[n_subpops - 1] eta_initial_exp_growth_rate; // z scores of + // individual subpopulation-level initial exponential growth rates real inv_sqrt_phi_h; - real mode_sigma_ww_site; //mode of site level stdev - real sd_log_sigma_ww_site; // stdev of the log site level stdev - vector[n_ww_lab_sites] eta_log_sigma_ww_site; // let each lab-site combo have its own observation error + real mode_sigma_ww_site; // mode of site level wastewater + // observation error standard deviations + real sd_log_sigma_ww_site; // sd of the log site level + // wastewater observation error standard deviations + vector[n_ww_lab_sites] eta_log_sigma_ww_site; // z-scores + // of the log site level wastewater observation error standard + // deviations real p_hosp_mean; // Estimated mean IHR - vector[tot_weeks] p_hosp_w; // weekly random walk for IHR - real p_hosp_w_sd; // Estimated IHR sd + vector[tot_weeks] p_hosp_w; // weekly Normal(0, 1) + // stochastic process noise for IHR + real p_hosp_w_sd; // Estimated IHR stochasti cprocess sd real t_peak; // time to viral load peak in shedding real viral_peak; // log10 peak viral load shed /mL real dur_shed; // duration of detectable viral shedding @@ -167,7 +195,7 @@ parameters { real log_sigma_generalized; real log_phi; real log_scaling_factor; - matrix[n_subpops-1,n_weeks] non_cent_spatial_dev_ns_mat; + matrix[n_subpops-1, n_subpops > 1 ? n_weeks: 0] non_cent_spatial_dev_ns_mat; vector[n_weeks] norm_vec_aux_site; cholesky_factor_corr[corr_structure_switch == 2 ? n_subpops-1 : 2] L_Omega; //---------------------------------------------------------------------------- @@ -185,18 +213,19 @@ transformed parameters { row_vector [ot + uot + ht] model_net_i; // number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) real phi_h = inv_square(inv_sqrt_phi_h); vector[n_ww_lab_sites] sigma_ww_site; - vector[n_weeks] log_r_mu_t_in_weeks; // log of state level mean R(t) in weeks - vector[ot + ht] unadj_r; // state level R(t) before damping - matrix[n_subpops, ot+ht] r_site_t; // site_level R(t) - row_vector[ot + ht] unadj_r_site_t; // site_level R(t) before damping - row_vector[ot + uot + ht] new_i_site; // site level incident infections per capita + vector[n_weeks] log_r_t_in_weeks; // global unadjusted weekly log R(t) + matrix[n_subpops, ot+ht] r_subpop_t; // matrix of subpopulation level R(t) + row_vector[ot + ht] unadj_r_subpop_t; // subpopulation level R(t) before damping -- temp vector + vector[n_weeks] log_r_subpop_t_in_weeks; // subpop level R(t) in weeks-- temp vector + real log_i0_subpop; // subpop level log i0/n -- temp var + row_vector[ot + uot + ht] new_i_subpop; // subpopulation level incident infections per capita -- temp vector real pop_fraction; // proportion of state population that the subpopulation represents vector[ot + uot + ht] state_inf_per_capita = rep_vector(0, uot + ot + ht); // state level incident infections per capita matrix[n_subpops, ot + ht] model_log_v_ot; // expected observed viral genomes/mL at all observed and forecasted times real g = pow(log10_g, 10); // Estimated genomes shed per infected individual - vector[n_subpops] i_first_obs_over_n_site; + vector[n_subpops] i_first_obs_over_n_subpop; // per capita infection incidence at the first observed time - vector[n_subpops] initial_exp_growth_rate_site; + vector[n_subpops] initial_exp_growth_rate_subpop; // site level unobserved period growth rate // Site spatial trans params-------------------------------------------------- @@ -207,25 +236,18 @@ transformed parameters { matrix[n_subpops-1,n_subpops-1] norm_omega; matrix[n_subpops-1,n_subpops-1] sigma_mat; matrix[n_subpops-1,n_weeks] spatial_dev_ns_mat; - matrix[n_subpops-1,n_weeks] log_r_site_t_in_weeks_matrix; - vector[n_weeks] log_r_aux_site_t_in_weeks; - matrix[n_subpops, n_weeks] combined_log_r_site_t_in_weeks; - vector[n_weeks] log_r_site_t_in_weeks_vector; + matrix[n_subpops-1,n_weeks] log_r_subpop_t_in_weeks_matrix; + //---------------------------------------------------------------------------- - // State-leve R(t) AR + RW implementation: - log_r_mu_t_in_weeks = diff_ar1(log_r_mu_intercept, - autoreg_rt, - eta_sd, - w, - 0); - unadj_r = ind_m*log_r_mu_t_in_weeks; - unadj_r = exp(unadj_r); + // AR(1) process on first differences in "global" + // (central) R(t) + log_r_t_in_weeks = diff_ar1(log_r_t_first_obs, + autoreg_rt, eta_sd, w, 0); // Shedding kinetics trajectory s = get_vl_trajectory(t_peak, viral_peak, dur_shed, gt_max); - // Site level spatial Rt------------------------------------------------------ if (corr_structure_switch == 0){ // If no dist matrix given, use n_sites + 1 = n_subpops were all ind. @@ -243,66 +265,74 @@ transformed parameters { else { reject("Model should not reach this point. Invalid corr_structure_switch value. Check model code"); } - sigma_mat = pow(sigma_generalized, 1.0 / (n_subpops - 1)) * norm_omega; - for (i in 1:n_weeks) { - spatial_dev_ns_mat[,i] = cholesky_decompose(sigma_mat) * non_cent_spatial_dev_ns_mat[,i]; + + if(n_subpops > 1){ + sigma_mat = pow(sigma_generalized, 1.0 / (n_subpops - 1)) * norm_omega; + for (i in 1:n_weeks) { + spatial_dev_ns_mat[,i] = cholesky_decompose(sigma_mat) * non_cent_spatial_dev_ns_mat[,i]; + } + + log_r_subpop_t_in_weeks_matrix = construct_spatial_rt( + log_r_t_in_weeks, + autoreg_rt_subpop, + spatial_dev_ns_mat + ); } - log_r_site_t_in_weeks_matrix = construct_spatial_rt( - log_r_mu_t_in_weeks, - autoreg_rt_site, - spatial_dev_ns_mat - ); - //---------------------------------------------------------------------------- - // AUX site Rt---------------------------------------------------------------- - log_r_aux_site_t_in_weeks = construct_aux_rt( - log_r_mu_t_in_weeks, - autoreg_rt_site, - scaling_factor, - sqrt(sigma_generalized), - norm_vec_aux_site, - 0 - ); - //---------------------------------------------------------------------------- - // Site Comb with AUX--------------------------------------------------------- - combined_log_r_site_t_in_weeks = append_row(log_r_site_t_in_weeks_matrix, log_r_aux_site_t_in_weeks'); + //---------------------------------------------------------------------------- // Site level disease dynamics - i_first_obs_over_n_site = inv_logit(logit(i_first_obs_over_n) + + // initial conditions + i_first_obs_over_n_subpop[1] = inv_logit(logit(i_first_obs_over_n) + + (n_subpops > 1 ? offset_ref_logit_i_first_obs[1] : 0)); + initial_exp_growth_rate_subpop[1] = mean_initial_exp_growth_rate + + (n_subpops > 1 ? offset_ref_initial_exp_growth_rate[1] : 0); + i_first_obs_over_n_subpop[2:n_subpops] = inv_logit(logit(i_first_obs_over_n) + sigma_i_first_obs * eta_i_first_obs); - initial_exp_growth_rate_site = mean_initial_exp_growth_rate + + initial_exp_growth_rate_subpop[2:n_subpops] = mean_initial_exp_growth_rate + sigma_initial_exp_growth_rate * eta_initial_exp_growth_rate; + // Loop over n_subpops to estimate deviations from reference subpop and + // generate infections and wastewater concentrations for (i in 1:n_subpops) { - real log_i0_site = log(i_first_obs_over_n_site[i]) - uot * initial_exp_growth_rate_site[i]; + + log_i0_subpop = log(i_first_obs_over_n_subpop[i]) - uot * initial_exp_growth_rate_subpop[i]; + + // Let site-level R(t) vary around the reference subpopulation R(t) + // log(R(t)subpop) ~ log(R(t)sref) + autoreg*(log(R(t)ref-log(R(t)subpop)) + eta_subpop + if(i == 1) { + log_r_subpop_t_in_weeks = log_r_t_in_weeks + + (n_subpops > 1 ? offset_ref_log_r_t[1] : 0); + } else { + log_r_subpop_t_in_weeks = to_vector(log_r_subpop_t_in_weeks_matrix[i-1, :]); + } + //convert from weekly to daily - log_r_site_t_in_weeks_vector = to_vector(combined_log_r_site_t_in_weeks[i, :]); - unadj_r_site_t = exp(to_row_vector(ind_m*(log_r_site_t_in_weeks_vector))); + unadj_r_subpop_t = exp(to_row_vector(ind_m*(log_r_subpop_t_in_weeks))); { - tuple(vector[num_elements(state_inf_per_capita)], vector[num_elements(unadj_r)]) output; + tuple(vector[num_elements(state_inf_per_capita)], vector[num_elements(unadj_r_subpop_t)]) output; output = generate_infections( - to_vector(unadj_r_site_t), + to_vector(unadj_r_subpop_t), uot, gt_rev_pmf, - log_i0_site , - initial_exp_growth_rate_site[i], + log_i0_subpop , + initial_exp_growth_rate_subpop[i], ht, infection_feedback, infection_feedback_rev_pmf ); - new_i_site = to_row_vector(output.1); - r_site_t[i] = to_row_vector(output.2); + new_i_subpop = to_row_vector(output.1); + r_subpop_t[i] = to_row_vector(output.2); } - // For each site, tack on number of state infections - // site level infection dynamics sum to the total state infections: - pop_fraction = subpop_size[i] / norm_pop; - state_inf_per_capita += pop_fraction * to_vector(new_i_site); + // For each subpopulation, tack on number of infections + // subpopulation level infection dynamics sum to the total infections: + pop_fraction = subpop_size[i] / norm_pop; // first subpop is ref subpop + state_inf_per_capita += pop_fraction * to_vector(new_i_subpop); - model_net_i = to_row_vector(convolve_dot_product(to_vector(new_i_site), + model_net_i = to_row_vector( + convolve_dot_product(to_vector(new_i_subpop), reverse(s), (uot + ot + ht))); - - model_log_v_ot[i] = log(10) * log10_g + log(model_net_i[(uot+1):(uot + ot + ht) ] + 1e-8) - log(mwpd); @@ -334,7 +364,7 @@ transformed parameters { // These are the true expected genomes at the site level before observation error // (which is at the lab-site level) for (i in 1:owt) { - exp_obs_log_v_true[i] = model_log_v_ot[ww_sampled_sites[i], ww_sampled_times[i]]; + exp_obs_log_v_true[i] = model_log_v_ot[ww_sampled_subpops[i], ww_sampled_times[i]]; } // modify by lab-site specific variation (multiplier!) @@ -366,12 +396,18 @@ model { //-------------------------------------------------------------------------- w ~ std_normal(); + offset_ref_log_r_t ~ normal(offset_ref_log_r_t_prior_mean, offset_ref_log_r_t_prior_sd); + offset_ref_logit_i_first_obs ~ normal(offset_ref_logit_i_first_obs_prior_mean, + offset_ref_logit_i_first_obs_prior_sd); + offset_ref_initial_exp_growth_rate ~ normal(offset_ref_initial_exp_growth_rate_prior_mean, + offset_ref_initial_exp_growth_rate_prior_sd); eta_sd ~ normal(0, eta_sd_sd); - autoreg_rt_site ~ beta(autoreg_rt_site_a, autoreg_rt_site_b); + autoreg_rt_subpop ~ beta(autoreg_rt_subpop_a, autoreg_rt_subpop_b); autoreg_rt ~ beta(autoreg_rt_a, autoreg_rt_b); autoreg_p_hosp ~ beta(autoreg_p_hosp_a, autoreg_p_hosp_b); - log_r_mu_intercept ~ normal(r_logmean, r_logsd); + log_r_t_first_obs ~ normal(r_logmean, r_logsd); + to_vector(error_rt_subpop) ~ std_normal(); sigma_rt ~ normal(0, sigma_rt_prior); i_first_obs_over_n ~ beta(i_first_obs_over_n_prior_a, i_first_obs_over_n_prior_b); @@ -438,7 +474,7 @@ generated quantities { // Here need to iterate through each lab-site, find the corresponding site // and apply the expected lab-site error for(i in 1:n_ww_lab_sites) { - pred_ww[i] = normal_rng(model_log_v_ot[lab_site_to_site_map[i], 1 : ot + ht] + ww_site_mod[i], + pred_ww[i] = normal_rng(model_log_v_ot[lab_site_to_subpop_map[i], 1 : ot + ht] + ww_site_mod[i], sigma_ww_site[i]); } diff --git a/man/.DS_Store b/man/.DS_Store new file mode 100644 index 00000000..15e9aa79 Binary files /dev/null and b/man/.DS_Store differ diff --git a/man/assert_cols_det_unique_row.Rd b/man/assert_cols_det_unique_row.Rd new file mode 100644 index 00000000..be4afc63 --- /dev/null +++ b/man/assert_cols_det_unique_row.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/checkers.R +\name{assert_cols_det_unique_row} +\alias{assert_cols_det_unique_row} +\title{Check a set of columns in a data frame uniquely identify +data frame rows.} +\usage{ +assert_cols_det_unique_row( + df, + unique_key_columns, + arg = "x", + call = rlang::caller_env(), + add_err_msg = "" +) +} +\arguments{ +\item{df}{the dataframe to check} + +\item{unique_key_columns}{Columns that, taken together, should +uniquely identify a row in the data frame.} + +\item{arg}{the name of the unique grouping to check} + +\item{call}{Calling environment to be passed to \code{\link[cli:cli_abort]{cli::cli_abort()}} for +traceback.} + +\item{add_err_msg}{string containing an additional error message, +default is the empty string (\code{""})} +} +\value{ +NULL, invisibly +} +\description{ +Equivalently, this checks that when grouping by the columns in question, +each group has a single entry +} diff --git a/man/assert_no_dates_after_max.Rd b/man/assert_no_dates_after_max.Rd index a59a791c..7dc13a02 100644 --- a/man/assert_no_dates_after_max.Rd +++ b/man/assert_no_dates_after_max.Rd @@ -4,7 +4,13 @@ \alias{assert_no_dates_after_max} \title{Check that all dates in dataframe passed in are before a specified date} \usage{ -assert_no_dates_after_max(date_vector, max_date, call = rlang::caller_env()) +assert_no_dates_after_max( + date_vector, + max_date, + arg_dates = "y", + arg_max_date = "x", + call = rlang::caller_env() +) } \arguments{ \item{date_vector}{vector of dates} @@ -12,6 +18,12 @@ assert_no_dates_after_max(date_vector, max_date, call = rlang::caller_env()) \item{max_date}{string indicating the maximum date in ISO8601 convention e.g. YYYY-MM-DD} +\item{arg_dates}{string to print the name of the data you are checking the +dates for} + +\item{arg_max_date}{string to print the name of the maximum date you are +checkign the data for} + \item{call}{Calling environment to be passed to \code{\link[cli:cli_abort]{cli::cli_abort()}} for traceback.} } diff --git a/man/check_req_count_cols_present.Rd b/man/assert_req_count_cols_present.Rd similarity index 92% rename from man/check_req_count_cols_present.Rd rename to man/assert_req_count_cols_present.Rd index 7088e534..8f6bdc78 100644 --- a/man/check_req_count_cols_present.Rd +++ b/man/assert_req_count_cols_present.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/checkers.R -\name{check_req_count_cols_present} -\alias{check_req_count_cols_present} +\name{assert_req_count_cols_present} +\alias{assert_req_count_cols_present} \title{Check that the input count data contains all the required column names} \usage{ -check_req_count_cols_present( +assert_req_count_cols_present( count_data, count_col_name, pop_size_col_name, diff --git a/man/assert_sufficient_days_of_data.Rd b/man/assert_sufficient_days_of_data.Rd index bc4a95a6..5f12bb78 100644 --- a/man/assert_sufficient_days_of_data.Rd +++ b/man/assert_sufficient_days_of_data.Rd @@ -7,6 +7,7 @@ calibration time} \usage{ assert_sufficient_days_of_data( date_vector, + data_name, calibration_time, call = rlang::caller_env(), add_err_msg = "" @@ -15,6 +16,10 @@ assert_sufficient_days_of_data( \arguments{ \item{date_vector}{the vector of dates to check, must be of Date type} +\item{data_name}{What data correspond to the dates in \code{date_vector}. +Used to make the error message informative (e.g. +"hospital admissions data")} + \item{calibration_time}{integer indicating the number of days that the dates must span} diff --git a/man/figures/.DS_Store b/man/figures/.DS_Store new file mode 100644 index 00000000..5008ddfc Binary files /dev/null and b/man/figures/.DS_Store differ diff --git a/man/figures/logo.svg b/man/figures/logo.svg new file mode 100644 index 00000000..30660d2c --- /dev/null +++ b/man/figures/logo.svg @@ -0,0 +1,211 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/man/format_ww_data.Rd b/man/format_ww_data.Rd index acc22f22..c70c4613 100644 --- a/man/format_ww_data.Rd +++ b/man/format_ww_data.Rd @@ -35,7 +35,8 @@ site combination} } \value{ a tidy dataframe containing observed wastewater concentrations -in log genome copies per mL for each site and lab at each time point +in log estimated genome copies per mL for each site and lab at each time +point } \description{ Format the wastewater data as a tidy data frame diff --git a/man/get_date_time_spine.Rd b/man/get_date_time_spine.Rd new file mode 100644 index 00000000..c00d125b --- /dev/null +++ b/man/get_date_time_spine.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_stan_data.R +\name{get_date_time_spine} +\alias{get_date_time_spine} +\title{Get date time spine to map to model output} +\usage{ +get_date_time_spine( + forecast_date, + input_count_data, + last_count_data_date, + calibration_time, + forecast_horizon +) +} +\arguments{ +\item{forecast_date}{a character string in ISO8601 format (YYYY-MM-DD) +indicating the date that the forecast is to be made.} + +\item{input_count_data}{a dataframe of the count data to be passed +directly to stan, , must have the following columns: date, count, total_pop} + +\item{last_count_data_date}{string indicating the date of the last observed +count data point in 1SO8601 format (YYYY-MM-DD)} + +\item{calibration_time}{integer indicating the number of days to calibrate +the model for, default is \code{90}} + +\item{forecast_horizon}{integer indicating the number of days, including the +forecast date, to produce forecasts for, default is \code{28}} +} +\value{ +a tibble containing an integer for time mapped to the corresponding +date, for the entire calibration and forecast period +} +\description{ +Get date time spine to map to model output +} diff --git a/man/get_draws.Rd b/man/get_draws.Rd new file mode 100644 index 00000000..7e5d91c7 --- /dev/null +++ b/man/get_draws.Rd @@ -0,0 +1,89 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_draws.R +\name{get_draws} +\alias{get_draws} +\alias{get_draws_df} +\alias{get_draws.wwinference_fit} +\alias{get_draws.default} +\alias{get_draws.data.frame} +\alias{plot.wwinference_fit_draws} +\title{Postprocess to generate a draws dataframe} +\usage{ +get_draws(x, ..., what = "all") + +get_draws_df(x, ...) + +\method{get_draws}{wwinference_fit}(x, ..., what = "all") + +\method{get_draws}{default}(x, ..., what = "all") + +\method{get_draws}{data.frame}( + x, + count_data, + date_time_spine, + site_subpop_spine, + lab_site_subpop_spine, + stan_data_list, + fit_obj, + ..., + what = "all" +) + +\method{plot}{wwinference_fit_draws}(x, y = NULL, what, ...) +} +\arguments{ +\item{x}{An object of class \code{get_draws}.} + +\item{...}{additional arguments} + +\item{what}{Character vector. Specifies the variables to extract from the +draws. It could be any from \code{"all"} \code{"predicted_counts"}, \code{"predicted_ww"}, +\code{"global_rt"}, or \code{"subpop_rt"}. When \code{what = "all"} (the default), +the function will extract all four variables.} + +\item{count_data}{A dataframe of the preprocessed daily count data (e.g. +hospital admissions) from the "global" population} + +\item{date_time_spine}{tibble mapping dates to time in days} + +\item{site_subpop_spine}{tibble mapping sites to subpopulations} + +\item{lab_site_subpop_spine}{tibble mapping lab-sites to subpopulations} + +\item{stan_data_list}{A list containing all the data passed to stan for +fitting the model} + +\item{fit_obj}{a CmdStan object that is the output of fitting the model to +\code{x} and \code{count_data}} + +\item{y}{Ignored in the the case of \code{plot}.} +} +\value{ +A tibble containing the full set of posterior draws of the +estimated, nowcasted, and forecasted: counts, site-level wastewater +concentrations, "global"(e.g. state) R(t) estimate, and the "local" (site + +the one auxiliary subpopulation) R(t) estimates. In the instance where there +are observations, the data will be joined to each draw of the predicted +observation to facilitate plotting. +} +\description{ +This function takes in the two input data sources, the CmdStan fit object, +and the 3 relevant mappings from stan indices to the real data, in order +to generate a dataframe containing the posterior draws of the counts (e.g. +hospital admissions), the wastewater concentration values, the "global" R(t), +and the "local" R(t) estimates + the critical metadata in the data. +This funtion has a default method that takes the two sets of input data, +the last of stan arguments, and the CmdStan fitting object, as well as an S3 +method for objects of class 'wwinference_fit' + +This method overloads the generic \code{get_draws} function specifically +for objects of type 'wwinference_fit'. +} +\details{ +The function \code{get_draws_df()} has been deprecated in favor of \code{get_draws()}. + +The plot method for \code{wwinference_fit_draws} is a wrapper of +\code{get_plot_forecasted_counts}, \code{get_plot_ww_conc}, \code{get_plot_global_rt}, +and \code{get_plot_subpop_rt}. Depending on the value of \code{what}, the function +will call the appropriate method. +} diff --git a/man/get_draws_df.Rd b/man/get_draws_df.Rd deleted file mode 100644 index ee6ec13a..00000000 --- a/man/get_draws_df.Rd +++ /dev/null @@ -1,53 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_draws_df.R -\name{get_draws_df} -\alias{get_draws_df} -\alias{get_draws_df.wwinference_fit} -\alias{get_draws_df.default} -\alias{get_draws_df.data.frame} -\title{Postprocess to generate a draws dataframe} -\usage{ -get_draws_df(x, ...) - -\method{get_draws_df}{wwinference_fit}(x, ...) - -\method{get_draws_df}{default}(x, ...) - -\method{get_draws_df}{data.frame}(x, count_data, stan_data_list, fit_obj, ...) -} -\arguments{ -\item{x}{Either a dataframe of wastewater observations, or an object of -class wwinference_fit} - -\item{...}{additional arguments} - -\item{count_data}{A dataframe of the preprocessed daily count data (e.g. -hospital admissions) from the "global" population} - -\item{stan_data_list}{A list containing all the data passed to stan for -fitting the model} - -\item{fit_obj}{a CmdStan object that is the output of fitting the model to -\code{x} and \code{count_data}} -} -\value{ -A tibble containing the full set of posterior draws of the -estimated, nowcasted, and forecasted: counts, site-level wastewater -concentrations, "global"(e.g. state) R(t) estimate, and the "local" (site + -the one auxiliary subpopulation) R(t) estimates. In the instance where there -are observations, the data will be joined to each draw of the predicted -observation to facilitate plotting. -} -\description{ -This function takes in the two input data sources, the CmdStan fit object, -and the 3 relevant mappings from stan indices to the real data, in order -to generate a dataframe containing the posterior draws of the counts (e.g. -hospital admissions), the wastewater concentration values, the "global" R(t), -and the "local" R(t) estimates + the critical metadata in the data. -This funtion has a default method that takes the two sets of input data, -the last of stan arguments, and the CmdStan fitting object, as well as an S3 -method for objects of class 'wwinference_fit' - -This method overloads the generic get_draws_df function specifically -for objects of type 'wwinference_fit'. -} diff --git a/man/get_lab_site_site_spine.Rd b/man/get_lab_site_site_spine.Rd new file mode 100644 index 00000000..decd26b8 --- /dev/null +++ b/man/get_lab_site_site_spine.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_stan_data.R +\name{get_lab_site_site_spine} +\alias{get_lab_site_site_spine} +\title{Get mapping from lab-site to site} +\usage{ +get_lab_site_site_spine(input_ww_data) +} +\arguments{ +\item{input_ww_data}{a dataframe of the wastewater data to be passed +directly to stan, must have the following columns: date, site, lab, +genome_copies_per_ml, site_pop, below_lod, and exclude} +} +\value{ +a dataframe mapping the unique combinations of sites and labs +to their indices in the model and the population of the site in that +observation unit (lab_site) +} +\description{ +Get mapping from lab-site to site +} diff --git a/man/get_lab_site_subpop_spine.Rd b/man/get_lab_site_subpop_spine.Rd new file mode 100644 index 00000000..6c2caefa --- /dev/null +++ b/man/get_lab_site_subpop_spine.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_stan_data.R +\name{get_lab_site_subpop_spine} +\alias{get_lab_site_subpop_spine} +\title{Get lab-site subpopulation spine} +\usage{ +get_lab_site_subpop_spine(lab_site_site_spine, site_subpop_spine) +} +\arguments{ +\item{lab_site_site_spine}{tibble mapping lab-sites to sites} + +\item{site_subpop_spine}{tibble mapping sites to subpopulations} +} +\value{ +a tibble mapping lab-sites to subpopulations +} +\description{ +Get lab-site subpopulation spine +} diff --git a/man/get_mcmc_options.Rd b/man/get_mcmc_options.Rd index 454b2c9a..193bb6f1 100644 --- a/man/get_mcmc_options.Rd +++ b/man/get_mcmc_options.Rd @@ -7,41 +7,43 @@ get_mcmc_options( iter_warmup = 750, iter_sampling = 500, - n_chains = 4, seed = NULL, + chains = 4, adapt_delta = 0.95, max_treedepth = 12 ) } \arguments{ \item{iter_warmup}{integer indicating the number of warm-up iterations, -default is \code{750}} +default is \code{750}.} \item{iter_sampling}{integer indicating the number of sampling iterations, -default is \code{500}} +default is \code{500}.} -\item{n_chains}{integer indicating the number of MCMC chains to run, default -is \code{4}} +\item{seed}{integer, A seed for the (P)RNG to pass to CmdStan. In the case +of multi-chain sampling the single seed will automatically be augmented by +the the run (chain) ID so that each chain uses a different seed. +Default is \code{NULL}.} -\item{seed}{set of integers indicating the random seed of the stan sampler, -default is NULL} +\item{chains}{integer indicating the number of MCMC chains to run, default +is \code{4}.} \item{adapt_delta}{float between 0 and 1 indicating the average acceptance -probability, default is \code{0.95}} +probability, default is \code{0.95}.} \item{max_treedepth}{integer indicating the maximum tree depth of the -sampler, default is 12} +sampler, default is 12.} } \value{ -a list of mcmc settings with the values given by the function +A list of MCMC settings with the values given by the function. arguments } \description{ This function returns a list of MCMC settings to pass to the -\code{cmdstanr::sample()} function to fit the model. The default settings are -specified for production-level runs, consider adjusting to optimize -for speed while iterating. -} -\examples{ -mcmc_settings <- get_mcmc_options() +\code{\link[cmdstanr:model-method-sample]{$sample()}} function to fit the model. +The default settings are specified for production-level runs. +All input arguments to \code{\link[cmdstanr:model-method-sample]{$sample()}} +are configurable by the user. See +\code{\link[cmdstanr:model-method-sample]{$sample()}} documentation +for details of the available arguments. } diff --git a/man/get_model_diagnostic_flags.Rd b/man/get_model_diagnostic_flags.Rd index bb4ebda5..6fbaee5c 100644 --- a/man/get_model_diagnostic_flags.Rd +++ b/man/get_model_diagnostic_flags.Rd @@ -60,6 +60,7 @@ specifically for objects of type 'wwinference_fit'. \seealso{ Other diagnostics: \code{\link{parameter_diagnostics}()}, +\code{\link{summary_diagnostics}()}, \code{\link{wwinference}()} } \concept{diagnostics} diff --git a/man/get_site_subpop_spine.Rd b/man/get_site_subpop_spine.Rd new file mode 100644 index 00000000..a77617f9 --- /dev/null +++ b/man/get_site_subpop_spine.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_stan_data.R +\name{get_site_subpop_spine} +\alias{get_site_subpop_spine} +\title{Get site to subpopulation map} +\usage{ +get_site_subpop_spine(input_ww_data, input_count_data) +} +\arguments{ +\item{input_ww_data}{a dataframe of the wastewater data to be passed +directly to stan, must have the following columns: date, site, lab, +genome_copies_per_ml, site_pop, below_lod, and exclude} + +\item{input_count_data}{a dataframe of the count data to be passed +directly to stan, , must have the following columns: date, count, total_pop} +} +\value{ +a dataframe mapping the sites to the corresponding subpopulation and +subpopulation index, plus the population in each subpopulation. Imposes +the logic to add a subpopulation if the total population is greater than +the sum of the site populations in the input wastewater data +} +\description{ +Get site to subpopulation map +} diff --git a/man/get_stan_data.Rd b/man/get_stan_data.Rd index 1ee9c60d..8a4b7d0a 100644 --- a/man/get_stan_data.Rd +++ b/man/get_stan_data.Rd @@ -7,6 +7,12 @@ get_stan_data( input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -21,12 +27,24 @@ get_stan_data( ) } \arguments{ -\item{input_count_data}{a dataframe of the count data to be passed -directly to stan, , must have the following columns: date, count, total_pop} +\item{input_count_data}{tibble with the input count data needed for stan} + +\item{input_ww_data}{tibble with the input wastewater data and indices +needed for stan} + +\item{date_time_spine}{tibble mapping dates to time in days} + +\item{lab_site_site_spine}{tibble mapping lab-sites to sites} + +\item{site_subpop_spine}{tibble mapping sites to subpopulations} + +\item{lab_site_subpop_spine}{tibble mapping lab-sites to subpopulations} -\item{input_ww_data}{a dataframe of the wastewater data to be passed -directly to stan, must have the following columns: date, site, lab, -genome_copies_per_ml, site_pop, below_lod, and exclude} +\item{last_count_data_date}{string indicating the date of the last data +point in the count dataset in ISO8601 convention e.g. YYYY-MM-DD} + +\item{first_count_data_date}{string indicating the date of the first data +point in the count dataset in ISO8601 convention e.g. YYYY-MM-DD} \item{forecast_date}{string indicating the forecast date in ISO8601 convention e.g. YYYY-MM-DD} @@ -134,9 +152,33 @@ input_ww_data_for_stan <- get_input_ww_data_for_stan( last_count_data_date, calibration_time ) +date_time_spine <- get_date_time_spine( + forecast_date = forecast_date, + input_count_data = input_count_data_for_stan, + last_count_data_date = last_count_data_date, + forecast_horizon = forecast_horizon, + calibration_time = calibration_time +) +lab_site_site_spine <- get_lab_site_site_spine( + input_ww_data = input_ww_data_for_stan +) +site_subpop_spine <- get_site_subpop_spine( + input_ww_data = input_ww_data_for_stan, + input_count_data = input_count_data_for_stan +) +lab_site_subpop_spine <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine +) stan_data_list <- get_stan_data( input_count_data_for_stan, input_ww_data_for_stan, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, diff --git a/man/get_subpop_data.Rd b/man/get_subpop_data.Rd deleted file mode 100644 index ed5600a9..00000000 --- a/man/get_subpop_data.Rd +++ /dev/null @@ -1,28 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_stan_data.R -\name{get_subpop_data} -\alias{get_subpop_data} -\title{Get subpopulation data} -\usage{ -get_subpop_data(add_auxiliary_site, state_pop, pop_ww, n_ww_sites) -} -\arguments{ -\item{add_auxiliary_site}{Boolean indicating whether to add another -subpopulation in addition to the wastewater sites to estimate R(t) of} - -\item{state_pop}{The state population size} - -\item{pop_ww}{The population size in each of the wastewater sites} - -\item{n_ww_sites}{The number of wastewater sites} -} -\value{ -A list containing the necessary integers and vectors that stan -needs to estiamte infection dynamics for each subpopulation -} -\description{ -Get subpopulation data -} -\examples{ -subpop_data <- get_subpop_data(TRUE, 100000, c(1000, 500), 2) -} diff --git a/man/get_ww_data_indices.Rd b/man/get_ww_data_indices.Rd deleted file mode 100644 index 1ecebfd6..00000000 --- a/man/get_ww_data_indices.Rd +++ /dev/null @@ -1,45 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_stan_data.R -\name{get_ww_data_indices} -\alias{get_ww_data_indices} -\title{Get wastewater data indices} -\usage{ -get_ww_data_indices( - ww_data, - first_count_data_date, - owt, - lod_col_name = "below_lod" -) -} -\arguments{ -\item{ww_data}{Input wastewater dataframe containing one row -per observation, with outliers already removed} - -\item{first_count_data_date}{The earliest day with an observation in the ' -count dataset, in ISO8601 format YYYY-MM-DD} - -\item{owt}{number of wastewater observations} - -\item{lod_col_name}{A string representing the name of the -column in the input_ww_data that provides a 0 if the data point is not above -the LOD and a 1 if the data is below the LOD, default value is \code{below_LOD}} -} -\value{ -A list containing the necessary vectors of indices that -the stan model requires: -ww_censored: the vector of time points that the wastewater observations are -censored (below the LOD) in order of the date and the site index -ww_uncensored: the vector of time points that the wastewater observations are -uncensored (above the LOD) in order of the date and the site index -ww_sampled_times: the vector of time points that the wastewater observations -are passed in in log_conc in order of the date and the site index -ww_sampled_sites: the vector of sites that correspond to the observations -passed in in log_conc in order of the date and the site index -ww_sampled_lab_sites: the vector of unique combinations of site and labs -that correspond to the observations passed in in log_conc in order of the -date and the site index -lab_site_to_site_map: the vector of sites that correspond to each lab-site -} -\description{ -Get wastewater data indices -} diff --git a/man/get_ww_indices_and_values.Rd b/man/get_ww_indices_and_values.Rd new file mode 100644 index 00000000..b50f1a65 --- /dev/null +++ b/man/get_ww_indices_and_values.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_stan_data.R +\name{get_ww_indices_and_values} +\alias{get_ww_indices_and_values} +\title{Get wastewater indices and values for stan} +\usage{ +get_ww_indices_and_values( + input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine +) +} +\arguments{ +\item{input_ww_data}{tibble with the input wastewater data and indices +needed for stan} + +\item{date_time_spine}{tibble mapping dates to time in days} + +\item{lab_site_site_spine}{tibble mapping lab-sites to sites} + +\item{site_subpop_spine}{tibble mapping sites to subpopulations} + +\item{lab_site_subpop_spine}{tibble mapping lab-sites to subpopulations} +} +\value{ +a list of the vectors needed for stan +} +\description{ +Get wastewater indices and values for stan +} diff --git a/man/get_ww_values.Rd b/man/get_ww_values.Rd deleted file mode 100644 index 4498a6b3..00000000 --- a/man/get_ww_values.Rd +++ /dev/null @@ -1,50 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_stan_data.R -\name{get_ww_values} -\alias{get_ww_values} -\title{Get wastewater data values} -\usage{ -get_ww_values( - ww_data, - ww_measurement_col_name = "log_genome_copies_per_ml", - ww_lod_value_col_name = "log_lod", - ww_site_pop_col_name = "site_pop", - one_pop_per_site = TRUE, - padding_value = 1e-08 -) -} -\arguments{ -\item{ww_data}{Input wastewater dataframe containing one row -per observation, with outliers already removed} - -\item{ww_measurement_col_name}{A string representing the name of the column -in the input_ww_data that indicates the wastewater measurement value in -log scale, default is \code{log_genome_copies_per_ml}} - -\item{ww_lod_value_col_name}{A string representing the name of the column -in the ww_data that indicates the value of the LOD in log scale, -default is \code{log_lod}} - -\item{ww_site_pop_col_name}{A string representing the name of the column in -the ww_data that indicates the number of people represented by that -wastewater catchment, default is \code{site_pop}} - -\item{one_pop_per_site}{a boolean variable indicating if there should only -be on catchment area population per site, default is \code{TRUE} because this is -what the stan model expects} - -\item{padding_value}{an smaller numeric value to add to the the -concentration measurements to ensure that log transformation will produce -real numbers, default value is \code{1e-8}} -} -\value{ -A list containing the necessary vectors of values that -the stan model requires: -ww_lod: a vector of the LODs of the corresponding wastewater measurement -pop_ww: a vector of the population sizes of the wastewater catchment areas -in order of the sites by site_index -log_conc: a vector of the log of the wastewater concentration observation -} -\description{ -Get wastewater data values -} diff --git a/man/hosp_data.Rd b/man/hosp_data.Rd index 76f09b0f..0463d14d 100644 --- a/man/hosp_data.Rd +++ b/man/hosp_data.Rd @@ -40,6 +40,9 @@ standatds as YYYY-MM-DD} hospital on that date, available as of the forecast date} \item{state_pop}{The number of people contributing to the daily hospital admissions} +\item{location}{ A string indicating the location that all of the +data is coming from. This is not a necessary column, but instead is +included to more realistically mirror a typical workflow} } } \keyword{datasets} diff --git a/man/parameter_diagnostics.Rd b/man/parameter_diagnostics.Rd index ffbc6404..db50149b 100644 --- a/man/parameter_diagnostics.Rd +++ b/man/parameter_diagnostics.Rd @@ -19,6 +19,7 @@ wwinference_fit_object \seealso{ Other diagnostics: \code{\link{get_model_diagnostic_flags}()}, +\code{\link{summary_diagnostics}()}, \code{\link{wwinference}()} } \concept{diagnostics} diff --git a/man/summary_diagnostics.Rd b/man/summary_diagnostics.Rd new file mode 100644 index 00000000..228e34ea --- /dev/null +++ b/man/summary_diagnostics.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model_diagnostics.R +\name{summary_diagnostics} +\alias{summary_diagnostics} +\title{Method for printing the CmdStan summary diagnostics for +wwinference_fit_object} +\usage{ +summary_diagnostics(ww_fit, ...) +} +\arguments{ +\item{ww_fit}{An object of class wwinference_fit} + +\item{...}{additional arguments} +} +\description{ +Method for printing the CmdStan summary diagnostics for +wwinference_fit_object +} +\seealso{ +Other diagnostics: +\code{\link{get_model_diagnostic_flags}()}, +\code{\link{parameter_diagnostics}()}, +\code{\link{wwinference}()} +} +\concept{diagnostics} diff --git a/man/validate_both_datasets.Rd b/man/validate_both_datasets.Rd index c71dafe5..8224586b 100644 --- a/man/validate_both_datasets.Rd +++ b/man/validate_both_datasets.Rd @@ -8,6 +8,10 @@ compatible with one another and the the user-specified parameters} validate_both_datasets( input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, calibration_time, forecast_date ) @@ -19,6 +23,14 @@ been filtered and is ready to be passed into stan} \item{input_ww_data}{tibble containing the input wastewater data that has been filtered and is ready to be passed into stan} +\item{date_time_spine}{tibble mapping dates to time in days} + +\item{lab_site_site_spine}{tibble mapping lab-sites to sites} + +\item{site_subpop_spine}{tibble mapping sites to subpopulations} + +\item{lab_site_subpop_spine}{tibble mapping lab-sites to subpopulations} + \item{calibration_time}{integer indicating the calibration time} \item{forecast_date}{IS08 formatted date indicating the forecast date} diff --git a/man/validate_pmf.Rd b/man/validate_pmf.Rd index 20fb0362..4e84ce76 100644 --- a/man/validate_pmf.Rd +++ b/man/validate_pmf.Rd @@ -10,6 +10,7 @@ validate_pmf( pmf, calibration_time, count_data, + tolerance = 1e-06, arg = "x", call = rlang::caller_env() ) @@ -23,6 +24,9 @@ each day} \item{count_data}{tibble containing the input count data ready to be passed to stan} +\item{tolerance}{numeric indicating the allowable difference between the +sum of the pmf and 1, default is \code{1e-6}} + \item{arg}{name of the argument supplying the object} \item{call}{The calling environment to be reflected in the error message} diff --git a/man/ww_data_ind.Rd b/man/ww_data_ind.Rd index b64ed8f6..103b3342 100644 --- a/man/ww_data_ind.Rd +++ b/man/ww_data_ind.Rd @@ -13,15 +13,18 @@ A tibble with 102 rows and 6 columns YYYY-MM-DD} \item{site}{The wastewater treatment plant where the sample was collected} \item{lab}{The lab where the sample was processed} -\item{genome_copies_per_ml}{The wastewater concentration measured on the -date specified, collected in the site specified, and processed in the lab -specified. The default parameters assume that this quantity is reported -as the genome copies per mL, on a natural scale.} -\item{lod}{The limit of detection in the site and lab on a particular day -of the quantification device (e.g. PCR). This is also by default reported -in terms of the genome copies per mL.} +\item{log_genome_copies_per_ml}{The natural log of the wastewater +concentration measured on the date specified, collected in the site +specified, and processed in the lab specified. The package expects +this quantity in units of log estimated genome copies per mL.} +\item{log_lod}{The log of the limit of detection in the site and lab on a +particular day of the quantification device (e.g. PCR). This should be in +units of log estimated genome copies per mL.} \item{site_pop}{The population size of the wastewater catchment area represented by the site variable} +\item{location}{ A string indicating the location that all of the +data is coming from. This is not a necessary column, but instead is +included to more realistically mirror a typical workflow} } } } @@ -33,12 +36,13 @@ ww_data_ind } \description{ A dataset containing the simulated wastewater concentrations -(labeled here as \code{genome_copies_per_ml}) by sample collection date (\code{date}), -the site where the sample was collected (\code{site}) and the lab where the -samples were processed (\code{lab}). Additional columns that are required -attributes needed for the model are the limit of detection for that lab on -each day (labeled here as \code{lod}) and the population size of the wastewater -catchment area represented by the wastewater concentrations in each \code{site}. +(labeled here as \code{log_genome_copies_per_ml}) by sample collection date +(\code{date}), the site where the sample was collected (\code{site}) and the lab +where the samples were processed (\code{lab}). Additional columns that are +required attributes needed for the model are the limit of detection for +that lab on each day (labeled here as \code{log_lod}) and the population size of +the wastewater catchment area represented by the wastewater concentrations +in each \code{site}. } \details{ This data is generated via the default values in the diff --git a/man/wwinference-package.Rd b/man/wwinference-package.Rd index 9ff7fd04..7f92dd92 100644 --- a/man/wwinference-package.Rd +++ b/man/wwinference-package.Rd @@ -24,6 +24,7 @@ Authors: \item Dylan Morris \email{dylan@dylanhmorris.com} (\href{https://orcid.org/0000-0002-3655-406X}{ORCID}) \item Sam Abbott \email{contact@samabbott.co.uk} (\href{https://orcid.org/0000-0001-8057-8037}{ORCID}) \item Christian Bernal Zelaya \email{xuk0@cdc.gov} + \item George Vega Yon \email{g.vegayon@gmail.com} (\href{https://orcid.org/0000-0002-3171-0844}{ORCID}) \item Damon Bayer \email{xum8@cdc.gov} \item Andrew Magee \email{rzg0@cdc.gov} \item Scott Olesen \email{ulp7@cdc.gov} @@ -31,7 +32,10 @@ Authors: Other contributors: \itemize{ - \item George Vega Yon \email{g.vegayon@gmail.com} (\href{https://orcid.org/0000-0002-3171-0844}{ORCID}) [contributor] + \item Adam Howes \email{adamthowes@gmail.com} (\href{https://orcid.org/0000-0003-2386-4031}{ORCID}) [contributor] + \item Chirag Kumar \email{kzs9@cdc.gov} [contributor] + \item Alexander Keyel \email{alexander.keyel@health.ny.gov} (\href{https://orcid.org/000-0001-5256-6274}{ORCID}) [contributor] + \item Hannah Cohen \email{llg4@cdc.gov} [contributor] } } diff --git a/man/wwinference.Rd b/man/wwinference.Rd index 61fae5d0..274b47f6 100644 --- a/man/wwinference.Rd +++ b/man/wwinference.Rd @@ -15,7 +15,7 @@ wwinference( calibration_time = 90, forecast_horizon = 28, model_spec = get_model_spec(), - fit_opts = get_mcmc_options(), + fit_opts = list(), generate_initial_values = TRUE, initial_values_seed = NULL, compiled_model = compile_model(), @@ -31,7 +31,7 @@ wwinference( \item{ww_data}{A dataframe containing the pre-processed, site-level wastewater concentration data for a model run. The dataframe must contain the following columns: \code{date}, \code{site}, \code{lab}, \code{log_genome_copies_per_ml}, -\code{lab_site_index}, \code{log_lod}, \code{below_lod}, \code{site_pop} \code{exclude}} +\code{lab_site_index}, \code{log_lod}, \code{below_lod}, \code{site_pop} \code{exclude}.} \item{count_data}{A dataframe containing the pre-procssed, "global" (e.g. state) daily count data, pertaining to the number of events that are being @@ -52,13 +52,15 @@ forecast date, to produce forecasts for, default is \code{28}} example data provided by the package, but this should be specified by the user based on the date they are producing a forecast} -\item{fit_opts}{The fit options, which in this case default to the -MCMC parameters as defined using \code{get_mcmc_options()}. This includes -the following arguments, which are passed to -\code{\link[cmdstanr:model-method-sample]{$sample()}}: -the number of chains, the number of warmup -and sampling iterations, the maximum tree depth, the average acceptance -probability, and the stan PRNG seed} +\item{fit_opts}{MCMC fitting options, as a list of keys and values. +These are passed as keyword arguments to +\code{\link[cmdstanr:model-method-sample]{compiled_model$sample()}}. +Where no option is specified, \code{\link[=wwinference]{wwinference()}} will fall back first on a +package-specific default value given by \code{\link[=get_mcmc_options]{get_mcmc_options()}}, if one exists. +If no package-specific default exists, \code{\link[=wwinference]{wwinference()}} will fall back on +the default value defined in \code{\link[cmdstanr:model-method-sample]{$sample()}}. +See the documentation for \code{\link[cmdstanr:model-method-sample]{$sample()}} for +details on available options.} \item{generate_initial_values}{Boolean indicating whether or not to specify the initialization of the sampler, default is \code{TRUE}, meaning that @@ -182,28 +184,32 @@ forecast_date <- "2023-11-06" calibration_time <- 90 forecast_horizon <- 28 include_ww <- 1 -ww_fit <- wwinference(input_ww_data, - input_count_data, + +ww_fit <- wwinference( + ww_data = input_ww_data, + count_data = input_count_data, + forecast_date = forecast_date, + calibration_time = calibration_time, + forecast_horizon = forecast_horizon, model_spec = get_model_spec( - forecast_date = forecast_date, - calibration_time = calibration_time, - forecast_horizon = forecast_horizon, generation_interval = generation_interval, - inf_to_count_delay = inf_to_coutn_delay, + inf_to_count_delay = inf_to_count_delay, infection_feedback_pmf = infection_feedback_pmf, params = params ), - fit_opts = get_mcmc_options( + fit_opts = list( iter_warmup = 250, iter_sampling = 250, - n_chains = 2 + chains = 2 ) ) } + } \seealso{ Other diagnostics: \code{\link{get_model_diagnostic_flags}()}, -\code{\link{parameter_diagnostics}()} +\code{\link{parameter_diagnostics}()}, +\code{\link{summary_diagnostics}()} } \concept{diagnostics} diff --git a/model_definition.md b/model_definition.md index 88af1da4..36646fb4 100644 --- a/model_definition.md +++ b/model_definition.md @@ -65,10 +65,15 @@ The total population consists of $K_\mathrm{total}$ subpopulations $k$ with corr Whenever the sum of the wastewater catchment population sizes $\sum\nolimits_{k=1}^{K_\mathrm{sites}} n_k$ is less than the total population size $n$, we use an additional subpopulation of size $n - \sum\nolimits_{k=1}^{K_\mathrm{sites}} n_k$ to model individuals in the population who are not covered by wastewater sampling. The total number of subpopulations is then $K_\mathrm{total} = K_\mathrm{sites} + 1$: the $K_\mathrm{sites}$ subpopulations with sampled wastewater, and the final subpopulation to account for individuals not covered by wastewater sampling. +The model without wastewater (hospital admissions only model) is therefore a special case of the model where $K_\mathrm{sites} = 0$ and $K_\mathrm{total} = 1$, with subpopulation size $n_k = n$, the total population. +In the case where the sum of the wastewater site catchment populations meets or exceeds the total population ($\sum\nolimits_{k=1}^{K_\mathrm{sites}} n_k \ge n$) the model does not use a final subpopulation without sampled wastewater. In that case, the total number of subpopulations $K_\mathrm{total} = K_\mathrm{sites}$. This amounts to modeling the wastewater catchments populations as approximately non-overlapping; every infected individual either does not contribute to measured wastewater or contributes principally to one wastewater catchment. This approximation is reasonable if we restrict our analyses to primary wastewaster treatment plants, which avoids the possibility that an individual might be sampled once in a sample taken upstream and then sampled again in a more aggregated sample taken further downstream. +<<<<<<< HEAD If the sum of the wastewater site catchment populations meets or exceeds the reported jurisdiction population ($\sum\nolimits_{k=1}^{K_\mathrm{sites}} n_k \ge n$) the model does not use a final subpopulation without sampled wastewater. In that case, the total number of subpopulations $K_\mathrm{total} = K_\mathrm{sites}$. +======= +>>>>>>> main When converting from predicted per capita incident hospital admissions $H(t)$ to predicted hospitalization counts, we use the jurisdiction population size $n$, even in the case where $\sum n_k > n$. @@ -76,21 +81,26 @@ This amounts to making two key additional modeling assumptions: - Any individuals who contribute to wastewaster measurements but are not part of the total population are distributed among the catchment populations approximately proportional to catchment population size. - Whenever $\sum n_k \ge n$, the fraction of individuals in the jurisdiction not covered by wastewater is small enough to have minimal impact on the jurisdiction-wide per capita infection dynamics. +The hierarchical subpopulation structure linking infection dynamics in each subpopulation to a central or "global" dynamic is implemented using a reference subpopulation. +The reference subpopulation is by default the subpopulation not covered by wastewater, or in the case where the sum of the wastewater site catchment populations meet or exceed the total population ($\sum\nolimits_{k=1}^{K_\mathrm{sites}} n_k \ge n$), the reference subpopulation is by default the wastewater catchment area with the largest population size. + #### Subpopulation-level infections -We couple the subpopulation and total population infection dynamics at the level of the un-damped instantaneous reproduction number $\mathcal{R}^\mathrm{u}(t)$. +We couple the subpopulation and total population infection dynamics at the level of the un-damped instantaneous reproduction number in the reference subpopulation, $\mathcal{R}^\mathrm{u}_ {0}(t)$. -We model the subpopulations as having infection dynamics that are _similar_ to one another but can differ from the overall "global" dynamic. +We model the subpopulations as having infection dynamics that are _similar_ to one another but can differ from the reference subpopulation dynamic. -We represent this with a hierarchical model where we first model a "global" un-damped effective reproductive number $\mathcal{R}^\mathrm{u}(t)$, but then allow individual subpopulations $k$ to have individual subpopulation values of $\mathcal{R}^\mathrm{u}_{k}(t)$ +We represent this with a hierarchical model where we estimate the reference subpopulation's un-damped effective reproductive number $\mathcal{R}^\mathrm{u}_ {0}(t)$ and then estimate the individual subpopulations $k$ deviations from the reference value, $\mathcal{R}^\mathrm{u}_{k}(t)$ -The "global" model for the undamped instantaneous reproductive number $\mathcal{R}^\mathrm{u}(t)$ follows the time-evolution described above. -Subpopulation deviations from the "global" reproduction number are modeled via a log-scale AR(1) process. Specifically, for subpopulation $k$: +The refrence value for the undamped instantaneous reproductive number $\mathcal{R}^\mathrm{u}(t)$ follows the time-evolution described above. +Subpopulation deviations from the reference reproduction number are modeled via a log-scale AR(1) process. Specifically, for subpopulation $k$: $$ -\log[\mathcal{R}^\mathrm{u}_{k}(t)] = \log[\mathcal{R}^\mathrm{u}(t)] + \delta_k(t) +\log[\mathcal{R}^\mathrm{u}_{k}(t)] = \log[\mathcal{R}^\mathrm{u}_0(t)] + m +\delta_k(t) $$ -where $\delta_k(t)$ is the time-varying subpopulation effect on $\mathcal{R}(t)$, modeled as, +where $m$ is an "intercept" for the reference subpopulation, which is a fixed parameter and allows for the fact that $\log[\mathcal{R}^\mathrm{u}_ {0}(t)]$ may differ from the central dynamic by $m$. + +The time-varying subpopulation effect on $log[\mathcal{R}_ {0}(t)]$, $\delta_k(t)$ is modeled as: $$\delta_k(t) = \varphi_{R(t)} \delta_k(t-1) + \epsilon_{kt}$$ diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R index 59e37f77..8264c677 100644 --- a/tests/testthat/helper.R +++ b/tests/testthat/helper.R @@ -131,3 +131,10 @@ diff_ar1_from_z_scores_alt <- function(x0, ar, sd, z, stationary = FALSE) { return(x) } + +silent_wwinference <- function(...) { + utils::capture.output( + fit <- suppressMessages(wwinference(...)) + ) + return(fit) +} diff --git a/tests/testthat/test_checkers.R b/tests/testthat/test_checkers.R index b1d97ba3..f4960309 100644 --- a/tests/testthat/test_checkers.R +++ b/tests/testthat/test_checkers.R @@ -12,11 +12,23 @@ test_that( max_date <- lubridate::ymd("2024-01-02") - expect_error(assert_no_dates_after_max(date_vector, max_date)) + expect_error( + assert_no_dates_after_max(date_vector, max_date, + arg_dates = "example data", + arg_max_date = "maximum date" + ), + regexp = "The example data passed in has observations" + ) - max_date <- "character" + max_date <- as.character("2024-01-02") - expect_error(assert_no_dates_after_max(date_vector, max_date)) + expect_error( + assert_no_dates_after_max(date_vector, max_date, + arg_dates = "example data", + arg_max_date = "maximum date" + ), + regexp = "The example data passed in has observations" + ) } ) @@ -176,7 +188,7 @@ test_that( ) count_col_name <- "hosp" pop_size_col_name <- "pop" - expect_no_error(check_req_count_cols_present( + expect_no_error(assert_req_count_cols_present( x, count_col_name, pop_size_col_name @@ -190,7 +202,7 @@ test_that( ) count_col_name <- "count" pop_size_col_name <- "pop" - expect_error(check_req_hosp_columns_present( + expect_error(assert_req_count_columns_present( x, count_col_name, pop_size_col_name @@ -204,7 +216,7 @@ test_that( ) count_col_name <- "hosp" pop_size_col_name <- "pop" - expect_error(check_req_hosp_columns_present( + expect_error(assert_req_count_columns_present( x, count_col_name, pop_size_col_name @@ -269,44 +281,31 @@ test_that( } ) +test_that( + "Test that validate pmfs returns the expected error message.", + { + invalid_pmf <- c(0.4, 0.4, 0.4) + expect_error(validate_pmf(invalid_pmf), + regexp = "does not sum to 1" + ) + } +) + test_that( "Test that assert dates in range function works as expected.", { dates1 <- lubridate::ymd(c("2023-01-01", "2023-01-02")) dates2 <- lubridate::ymd(c("2023-01-01", "2023-01-04")) - max_date <- "2023-01-05" expect_no_error(assert_dates_within_frame( dates1, - dates2, - max_date - )) - - - dates1 <- lubridate::ymd(c("2023-01-01", "2023-01-02")) - dates2 <- lubridate::ymd(c("2023-01-03", "2023-01-04")) - max_date <- "2023-01-05" - expect_no_error(assert_dates_within_frame( - dates1, - dates2, - max_date + dates2 )) dates1 <- lubridate::ymd(c("2023-01-01", "2023-01-02")) dates2 <- lubridate::ymd(c("2024-01-03", "2024-01-04")) - max_date <- "2023-01-05" - expect_error(assert_dates_within_frame( - dates1, - dates2, - max_date - )) - - dates1 <- lubridate::ymd(c("2023-01-01", "2023-01-02")) - dates2 <- lubridate::ymd(c("2023-01-03", "2023-01-04")) - max_date <- "2022-01-05" expect_error(assert_dates_within_frame( dates1, - dates2, - max_date + dates2 )) } ) diff --git a/tests/testthat/test_get_stan_data.R b/tests/testthat/test_get_stan_data.R index 300d879d..d71638ff 100644 --- a/tests/testthat/test_get_stan_data.R +++ b/tests/testthat/test_get_stan_data.R @@ -63,6 +63,184 @@ input_ww_data <- get_input_ww_data_for_stan( last_count_data_date, calibration_time ) +date_time_spine <- get_date_time_spine( + forecast_date = forecast_date, + input_count_data = input_count_data, + last_count_data_date = last_count_data_date, + forecast_horizon = forecast_horizon, + calibration_time = calibration_time +) + +lab_site_site_spine <- get_lab_site_site_spine( + input_ww_data = input_ww_data +) + +site_subpop_spine <- get_site_subpop_spine( + input_ww_data = input_ww_data, + input_count_data = input_count_data +) + +lab_site_subpop_spine <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine +) + + +test_that(paste0( + "Test that the number of subpopulations is correct for the", + "standard case where sum(site_pops) < total_pop" +), { + stan_data <- get_stan_data( + input_count_data, + input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, + forecast_date, + forecast_horizon, + calibration_time, + generation_interval, + inf_to_count_delay, + infection_feedback_pmf, + params, + include_ww, + dist_matrix = NULL, + corr_structure_switch = 0 + ) + + expect_equal(stan_data$n_subpop, (stan_data$n_ww_sites + 1)) + expect_equal(length(stan_data$subpop_size), stan_data$n_subpops) +}) + +test_that(paste0( + "Test that the number of subpopulations is correct for the ", + "standard case where sum(site_pops) > total_pop" +), { + input_count_data_mod <- input_count_data + input_count_data_mod$total_pop <- sum(unique(input_ww_data$site_pop) - 100) + site_subpop_spine_mod <- get_site_subpop_spine( + input_ww_data = input_ww_data, + input_count_data = input_count_data_mod + ) + + lab_site_subpop_spine_mod <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine_mod + ) + + expect_warning({ + stan_data_mod <- get_stan_data( + input_count_data_mod, + input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine_mod, + lab_site_subpop_spine_mod, + last_count_data_date, + first_count_data_date, + forecast_date, + forecast_horizon, + calibration_time, + generation_interval, + inf_to_count_delay, + infection_feedback_pmf, + params, + include_ww, + dist_matrix = NULL, + corr_structure_switch = 0 + ) + }) + + expect_equal(stan_data_mod$n_subpop, (stan_data_mod$n_ww_sites)) + expect_equal(length(stan_data_mod$subpop_size), stan_data_mod$n_subpops) + expect_equal(stan_data_mod$norm_pop, sum(stan_data_mod$subpop_size)) +}) + +test_that(paste0( + "Test that the model handles include_ww = 0 ", + "appropriately by only estimating one subpopulation" +), { + # This happens upstream in wwinference + input_ww_data_mod <- NULL + site_subpop_spine_mod <- get_site_subpop_spine( + input_ww_data = input_ww_data_mod, + input_count_data = input_count_data + ) + + lab_site_subpop_spine_mod <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine_mod + ) + + stan_data_ho <- get_stan_data( + input_count_data, + input_ww_data_mod, + date_time_spine, + lab_site_site_spine, + site_subpop_spine_mod, + lab_site_subpop_spine_mod, + last_count_data_date, + first_count_data_date, + forecast_date, + forecast_horizon, + calibration_time, + generation_interval, + inf_to_count_delay, + infection_feedback_pmf, + params, + include_ww = 0, + dist_matrix = NULL, + corr_structure_switch = 0 + ) + + expect_equal(stan_data_ho$n_subpops, 1) + expect_equal(length(stan_data_ho$subpop_size), 1) +}) + +test_that(paste0( + "Test that the model handles include_ww = 0 ", + "and no data appropriately" +), { + null_ww_data <- NULL + + site_subpop_spine_mod <- get_site_subpop_spine( + input_ww_data = null_ww_data, + input_count_data = input_count_data + ) + + lab_site_subpop_spine_mod <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine_mod + ) + + stan_data_ho <- get_stan_data( + input_count_data, + input_ww_data = null_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine_mod, + lab_site_subpop_spine_mod, + last_count_data_date, + first_count_data_date, + forecast_date, + forecast_horizon, + calibration_time, + generation_interval, + inf_to_count_delay, + infection_feedback_pmf, + params, + include_ww = 0, + dist_matrix = NULL, + corr_structure_switch = 0 + ) + + expect_equal(stan_data_ho$n_subpops, 1) + expect_equal(length(stan_data_ho$subpop_size), 1) +}) + test_that(paste0( @@ -76,6 +254,7 @@ test_that(paste0( expect_true(nrow(result) == 80) }) + test_that(paste0( "Test that things not flagged for removal don't get removed ", "and things that are flagged for removal do get removed" @@ -137,10 +316,38 @@ test_that(paste0( last_count_data_date, calibration_time ) + date_time_spine <- get_date_time_spine( + forecast_date = forecast_date, + input_count_data = input_count_data, + last_count_data_date = last_count_data_date, + forecast_horizon = forecast_horizon, + calibration_time = calibration_time + ) + + lab_site_site_spine_od <- get_lab_site_site_spine( + input_ww_data = recent_input_ww_data_for_stan + ) + + site_subpop_spine_od <- get_site_subpop_spine( + input_ww_data = recent_input_ww_data_for_stan, + input_count_data = input_count_data + ) + + lab_site_subpop_spine_od <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine, + site_subpop_spine = site_subpop_spine_od + ) + expect_error(get_stan_data( input_count_data, recent_input_ww_data_for_stan, + date_time_spine, + lab_site_site_spine_od, + site_subpop_spine_od, + lab_site_subpop_spine_od, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -181,11 +388,37 @@ test_that(paste0( last_count_data_date, calibration_time ) + date_time_spine <- get_date_time_spine( + forecast_date = forecast_date, + input_count_data = input_count_data, + last_count_data_date = last_count_data_date, + forecast_horizon = forecast_horizon, + calibration_time = calibration_time + ) + lab_site_site_spine_old <- get_lab_site_site_spine( + input_ww_data = old_input_ww_data_for_stan + ) + + site_subpop_spine_old <- get_site_subpop_spine( + input_ww_data = old_input_ww_data_for_stan, + input_count_data = input_count_data + ) + + lab_site_subpop_spine_old <- get_lab_site_subpop_spine( + lab_site_site_spine = lab_site_site_spine_old, + site_subpop_spine = site_subpop_spine_old + ) expect_error(get_stan_data( input_count_data, old_input_ww_data, + date_time_spine, + lab_site_site_spine_od, + site_subpop_spine_od, + lab_site_subpop_spine_od, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -203,6 +436,12 @@ test_that("Test that pmf check works as expected", { expect_warning(get_stan_data( input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -218,6 +457,12 @@ test_that("Test that pmf check works as expected", { expect_warning(get_stan_data( input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -233,6 +478,12 @@ test_that("Test that pmf check works as expected", { expect_warning(get_stan_data( input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, @@ -248,6 +499,12 @@ test_that("Test that pmf check works as expected", { expect_error(get_stan_data( input_count_data, input_ww_data, + date_time_spine, + lab_site_site_spine, + site_subpop_spine, + lab_site_subpop_spine, + last_count_data_date, + first_count_data_date, forecast_date, forecast_horizon, calibration_time, diff --git a/tests/testthat/test_helper.R b/tests/testthat/test_helper.R index 8b6b9480..84673e18 100644 --- a/tests/testthat/test_helper.R +++ b/tests/testthat/test_helper.R @@ -1,13 +1,4 @@ test_that("Make sure we can find and load files we need for other tests.", { - testthat::expect_true( - exists("model_test_data") - ) - - testthat::expect_true( - exists("test_fit_last_draw") - ) - - # Compiled model object should exist in the workspace, with functions exposed testthat::expect_true( exists("compiled_site_inf_model") diff --git a/tests/testthat/test_models_run_without_ww.R b/tests/testthat/test_models_run_without_ww.R new file mode 100644 index 00000000..50cdd5de --- /dev/null +++ b/tests/testthat/test_models_run_without_ww.R @@ -0,0 +1,122 @@ +options(cmdstanr_warn_inits = FALSE) + +hosp_data <- wwinference::hosp_data +ww_data <- wwinference::ww_data +params <- wwinference::get_params( + fs::path_package("extdata", "example_params.toml", + package = "wwinference" + ) +) + + +# Data pre-processing -------------------------------------------------------- +ww_data_preprocessed <- wwinference::preprocess_ww_data( + ww_data, + conc_col_name = "log_genome_copies_per_ml", + lod_col_name = "log_lod" +) + +hosp_data_preprocessed <- wwinference::preprocess_count_data( + hosp_data, + count_col_name = "daily_hosp_admits", + pop_size_col_name = "state_pop" +) + +ww_data_to_fit <- wwinference::indicate_ww_exclusions( + ww_data_preprocessed, + outlier_col_name = "flag_as_ww_outlier", + remove_outliers = TRUE +) + +forecast_date <- "2023-12-06" +calibration_time <- 90 +forecast_horizon <- 28 +generation_interval <- wwinference::default_covid_gi +inf_to_hosp <- wwinference::default_covid_inf_to_hosp + +# Assign infection feedback equal to the generation interval +infection_feedback_pmf <- generation_interval + +model_spec <- wwinference::get_model_spec( + generation_interval = generation_interval, + inf_to_count_delay = inf_to_hosp, + infection_feedback_pmf = infection_feedback_pmf, + params = params +) + +mcmc_options <- list( + seed = 5, + iter_warmup = 500, + iter_sampling = 250, + chains = 2, + show_messages = FALSE, + show_exceptions = FALSE +) + +generate_initial_values <- TRUE + +model_test_data <- list( + ww_data = ww_data_to_fit, + count_data = hosp_data_preprocessed, + forecast_date = forecast_date, + calibration_time = calibration_time, + forecast_horizon = forecast_horizon, + model_spec = model_spec, + fit_opts = mcmc_options, + generate_initial_values = generate_initial_values, + compiled_model = compiled_site_inf_model +) + + +test_that("Test that the model runs on simulated data when include_ww=0.", { + ####### + # run model briefly on the simulated data + ####### + model_test_data_no_ww <- model_test_data + model_test_data_no_ww$model_spec$include_ww <- 0 + + expect_no_error(withr::with_seed(55, { + fit <- do.call( + wwinference::wwinference, + model_test_data_no_ww + ) + })) +}) + +test_that("Test that the model runs without wastewater, include_ww=0.", { + ####### + # run model briefly on the simulated data + ####### + model_test_data_no_ww <- model_test_data + model_test_data_no_ww$model_spec$include_ww <- 0 + model_test_data_no_ww$ww_data <- tibble::tibble() + + expect_warning( + withr::with_seed(55, { + fit <- do.call( + wwinference::wwinference, + model_test_data_no_ww + ) + }), + regex = "No wastewater data was passed to the model." + ) +}) + +test_that("Test that the model runs without wastewater, include_ww=1.", { + ####### + # run model briefly on the simulated data + ####### + model_test_data_no_ww <- model_test_data + model_test_data_no_ww$model_spec$include_ww <- 1 + model_test_data_no_ww$ww_data <- tibble::tibble() + + expect_warning( + withr::with_seed(55, { + fit <- do.call( + wwinference::wwinference, + model_test_data_no_ww + ) + }), + regex = "No wastewater data was passed to the model." + ) +}) diff --git a/tests/testthat/test_preprocess_ww_data.R b/tests/testthat/test_preprocess_ww_data.R index 9da1e350..39d47c44 100644 --- a/tests/testthat/test_preprocess_ww_data.R +++ b/tests/testthat/test_preprocess_ww_data.R @@ -1,13 +1,29 @@ # Test data setup ww_data <- tibble::tibble( date = lubridate::ymd(rep(c("2023-11-01", "2023-11-02"), 2)), - site = c(rep(1, 2), rep(2, 2)), + site = c("1", "1", "2", "2"), lab = c(1, 1, 1, 1), - conc = c(345.2, 784.1, 401.5, 681.8), - lod = c(20, 20, 15, 15), - site_pop = c(rep(1e6, 2), rep(3e5, 2)) + conc = log(c(345.2, 784.1, 401.5, 681.8)), + lod = log(c(20, 20, 15, 15)), + site_pop = c(rep(3e5, 2), rep(1e6, 2)), + location = c(rep("MA", 4)) ) +# Test that function returns a dataframe with site indices ordered by +# population size (with first index at highest pop) +test_that("Function returns site indices in order of largest site pop", { + processed <- preprocess_ww_data(ww_data, + conc_col_name = "conc", + lod_col_name = "lod" + ) + + spine <- processed |> distinct(site_pop, site_index) + + + expect_true(spine$site_pop[spine$site_index == 1] == max(spine$site_pop)) +}) + + # Test that function returns a dataframe with correct columns test_that("Function returns dataframe with correct columns", { processed <- preprocess_ww_data(ww_data, @@ -24,6 +40,87 @@ test_that("Function returns dataframe with correct columns", { checkmate::expect_names(names(processed), must.include = expected_cols) }) +# Test that can pass either integer or character site names +ww_data_char <- tibble::tibble( + date = lubridate::ymd(rep(c("2023-11-01", "2023-11-02"), 2)), + site = c("1", "1", "2", "2"), + lab = c(1, 1, 1, 1), + conc = log(c(345.2, 784.1, 401.5, 681.8)), + lod = log(c(20, 20, 15, 15)), + site_pop = c(rep(1e6, 2), rep(3e5, 2)), + location = c(rep("MA", 4)) +) + +ww_data_int <- tibble::tibble( + date = lubridate::ymd(rep(c("2023-11-01", "2023-11-02"), 2)), + site = c(1, 1, 2, 2), + lab = c(1, 1, 1, 1), + conc = log(c(345.2, 784.1, 401.5, 681.8)), + lod = log(c(20, 20, 15, 15)), + site_pop = c(rep(1e6, 2), rep(3e5, 2)), + location = c(rep("MA", 4)) +) + +ww_data_int_alt <- tibble::tibble( + date = lubridate::ymd(rep(c("2023-11-01", "2023-11-02"), 2)), + site = c(5, 5, 1, 1), + lab = c(1, 1, 1, 1), + conc = log(c(345.2, 784.1, 401.5, 681.8)), + lod = log(c(20, 20, 15, 15)), + site_pop = c(rep(1e6, 2), rep(3e5, 2)), + location = c(rep("MA", 4)) +) + +test_that("Function returns dataframe with correct site indices", { + processed_int <- preprocess_ww_data(ww_data_int, + conc_col_name = "conc", + lod_col_name = "lod" + ) + processed_char <- preprocess_ww_data(ww_data_char, + conc_col_name = "conc", + lod_col_name = "lod" + ) + processed_int_alt <- preprocess_ww_data(ww_data_int_alt, + conc_col_name = "conc", + lod_col_name = "lod" + ) + # site indices should be the same even if sites are not ordered or are + # characters + expect_equal(processed_int$site_index, processed_char$site_index) + expect_equal(processed_int_alt$site_index, processed_int$site_index) + expect_equal(processed_int_alt$site_index, processed_char$site_index) + # Ordering shouldn't change even if site integers not in order + expect_equal(processed_int_alt$site, ww_data_int_alt$site) +}) + +ww_data_w_repeats <- tibble::tibble( + date = lubridate::ymd( + rep(c("2023-11-01", "2023-11-02"), 2), + "2023-11-02" + ), + site = c("1", "1", "2", "2", "2"), + lab = c(1, 1, 1, 1, 1), + conc = log(c(345.2, 784.1, 401.5, 681.8, 681.8)), + lod = log(c(20, 20, 15, 15, 15)), + site_pop = c(rep(1e6, 2), rep(3e5, 3)), + location = c(rep("MA", 5)) +) + +test_that("Function returns an error if there are repeated values", { + msg <- expect_error( + preprocess_ww_data(ww_data_w_repeats, + conc_col_name = "conc", + lod_col_name = "lod" + ), "The data has more than one observation per `lab-site-day`" + ) + + expect_no_error(preprocess_ww_data(ww_data, + conc_col_name = "conc", + lod_col_name = "lod" + )) +}) + + # Test that concentration column is renamed correctly test_that("Concentration column is renamed correctly", { processed <- preprocess_ww_data(ww_data, @@ -192,8 +289,8 @@ test_that("lab_site_name is constructed properly", { ) expected_lab_site_names <- c( - "Site: 1, Lab: 1", "Site: 1, Lab: 1", - "Site: 2, Lab: 1", "Site: 2, Lab: 1" + "Site: 2, Lab: 1", "Site: 2, Lab: 1", + "Site: 1, Lab: 1", "Site: 1, Lab: 1" ) expect_equal(processed$lab_site_name, expected_lab_site_names) diff --git a/tests/testthat/test_wwinference.R b/tests/testthat/test_wwinference.R index 6abf6ab5..bcd266ae 100644 --- a/tests/testthat/test_wwinference.R +++ b/tests/testthat/test_wwinference.R @@ -59,12 +59,9 @@ test_that("wwinference model can compile", { test_that("Function to get mcmc options produces the expected outputs", { mcmc_options <- get_mcmc_options() expected_names <- c( - "iter_warmup", "iter_sampling", - "n_chains", "seed", "adapt_delta", "max_treedepth", - "compute_likelihood" + "iter_warmup", "iter_sampling", "seed", "adapt_delta", "max_treedepth" ) - # Checkmade doesn't work here for a list, says it must be a character vector - expect_true(all(names(mcmc_options) %in% expected_names)) + checkmate::expect_names(names(mcmc_options), must.include = expected_names) }) test_that("Function to get model specs produces expected outputs", { @@ -77,3 +74,16 @@ test_that("Function to get model specs produces expected outputs", { # Checkmade doesn't work here for a list, says it must be a character vector expect_true(all(names(model_spec) %in% expected_names)) }) + +test_that("Passing invalid args to fit_opts throws an error ", { + expect_error( + wwinference( + ww_data = input_ww_data, + count_data = input_count_data, + forecast_date = forecast_date, + model_spec = get_model_spec(), + fit_opts = list(not_an_arg = 4) + ), + regexp = c("Names must be a subset of ") + ) +}) diff --git a/vignettes/spatial_wwinference.Rmd b/vignettes/spatial_wwinference.Rmd index 8e0741dd..9c9936d1 100644 --- a/vignettes/spatial_wwinference.Rmd +++ b/vignettes/spatial_wwinference.Rmd @@ -1071,20 +1071,29 @@ data generation model and inference model was used.
```{r} set.seed(2024) -draws_to_keep <- sample(1:max(get_draws_df(fit_iid_to_iid)$draw), 100) +draws_to_keep <- sample(1:max(get_draws( + fit_iid_to_iid, + what = "predicted_counts" +)$predicted_counts$draw), 100) # IID data --------------------------------------------------------------------- iid_pred_draws_df <- rbind( - get_draws_df(fit_iid_to_iid) %>% + get_draws(fit_iid_to_iid, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "IID" ), - get_draws_df(fit_iid_to_exp) %>% + get_draws(fit_iid_to_exp, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "Exponential" ), - get_draws_df(fit_iid_to_unstruct) %>% + get_draws(fit_iid_to_unstruct, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "Unstructured" @@ -1096,17 +1105,23 @@ iid_pred_draws_df <- rbind( # ------------------------------------------------------------------------------ # Exponential data ------------------------------------------------------------- exp_pred_draws_df <- rbind( - get_draws_df(fit_exp_to_iid) %>% + get_draws(fit_exp_to_iid, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "IID" ), - get_draws_df(fit_exp_to_exp) %>% + get_draws(fit_exp_to_exp, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "Exponential" ), - get_draws_df(fit_exp_to_unstruct) %>% + get_draws(fit_exp_to_unstruct, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "Unstructured" @@ -1118,17 +1133,23 @@ exp_pred_draws_df <- rbind( # ------------------------------------------------------------------------------ # Rand. Corr. Matrix data ------------------------------------------------------ rand_pred_draws_df <- rbind( - get_draws_df(fit_rand_to_iid) %>% + get_draws(fit_rand_to_iid, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "IID" ), - get_draws_df(fit_rand_to_exp) %>% + get_draws(fit_rand_to_exp, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "Exponential" ), - get_draws_df(fit_rand_to_unstruct) %>% + get_draws(fit_rand_to_unstruct, + what = "predicted_counts" + )$predicted_counts %>% filter(draw %in% draws_to_keep) %>% mutate( inf_model_type = "Unstructured" @@ -1161,6 +1182,331 @@ all_pred_draws_df <- rbind( ) ) ) + +# Wastewater draws------------------------------ +iid_ww_draws_df <- rbind( + get_draws(fit_iid_to_iid, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_iid_to_exp, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_iid_to_unstruct, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "IID" + ) +# ------------------------------------------------------------------------------ +# Exponential data ------------------------------------------------------------- +exp_ww_draws_df <- rbind( + get_draws(fit_exp_to_iid, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_exp_to_exp, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_exp_to_unstruct, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "Exponential" + ) +# ------------------------------------------------------------------------------ +# Rand. Corr. Matrix data ------------------------------------------------------ +rand_ww_draws_df <- rbind( + get_draws(fit_rand_to_iid, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_rand_to_exp, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_rand_to_unstruct, + what = "predicted_ww" + )$predicted_ww %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "Rand. Corr. Matrix" + ) + +all_ww_draws_df <- rbind( + iid_ww_draws_df, + exp_ww_draws_df, + rand_ww_draws_df +) %>% + mutate( + inf_model_type = factor( + inf_model_type, + levels = c( + "Exponential", + "Unstructured", + "IID" + ) + ), + gen_model_type = factor( + gen_model_type, + levels = c( + "IID", + "Exponential", + "Rand. Corr. Matrix" + ) + ) + ) + +# Global R(t) draws------------------------------------------- +iid_rt_draws_df <- rbind( + get_draws(fit_iid_to_iid, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_iid_to_exp, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_iid_to_unstruct, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "IID" + ) +# ------------------------------------------------------------------------------ +# Exponential data ------------------------------------------------------------- +exp_rt_draws_df <- rbind( + get_draws(fit_exp_to_iid, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_exp_to_exp, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_exp_to_unstruct, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "Exponential" + ) +# ------------------------------------------------------------------------------ +# Rand. Corr. Matrix data ------------------------------------------------------ +rand_rt_draws_df <- rbind( + get_draws(fit_rand_to_iid, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_rand_to_exp, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_rand_to_unstruct, + what = "global_rt" + )$global_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "Rand. Corr. Matrix" + ) + +all_rt_draws_df <- rbind( + iid_rt_draws_df, + exp_rt_draws_df, + rand_rt_draws_df +) %>% + mutate( + inf_model_type = factor( + inf_model_type, + levels = c( + "Exponential", + "Unstructured", + "IID" + ) + ), + gen_model_type = factor( + gen_model_type, + levels = c( + "IID", + "Exponential", + "Rand. Corr. Matrix" + ) + ) + ) + +# Subpop R(t) --------------------------- + +iid_subpop_rt_draws_df <- rbind( + get_draws(fit_iid_to_iid, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_iid_to_exp, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_iid_to_unstruct, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "IID" + ) +# ------------------------------------------------------------------------------ +# Exponential data ------------------------------------------------------------- +exp_subpop_rt_draws_df <- rbind( + get_draws(fit_exp_to_iid, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_exp_to_exp, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_exp_to_unstruct, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "Exponential" + ) +# ------------------------------------------------------------------------------ +# Rand. Corr. Matrix data ------------------------------------------------------ +rand_subpop_rt_draws_df <- rbind( + get_draws(fit_rand_to_iid, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "IID" + ), + get_draws(fit_rand_to_exp, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Exponential" + ), + get_draws(fit_rand_to_unstruct, + what = "subpop_rt" + )$subpop_rt %>% + filter(draw %in% draws_to_keep) %>% + mutate( + inf_model_type = "Unstructured" + ) +) %>% + mutate( + gen_model_type = "Rand. Corr. Matrix" + ) + +all_subpop_rt_draws_df <- rbind( + iid_subpop_rt_draws_df, + exp_subpop_rt_draws_df, + rand_subpop_rt_draws_df +) %>% + mutate( + inf_model_type = factor( + inf_model_type, + levels = c( + "Exponential", + "Unstructured", + "IID" + ) + ), + gen_model_type = factor( + gen_model_type, + levels = c( + "IID", + "Exponential", + "Rand. Corr. Matrix" + ) + ) + ) ``` @@ -1449,9 +1795,6 @@ evaluation metrics will be used to quantify forecast performance. ```{r warning=FALSE} # Hospital admissions results -------------------------------------------------- hosp_ribbon_data <- all_pred_draws_df %>% - filter( - name == "predicted counts" - ) %>% group_by( date, inf_model_type, @@ -1464,10 +1807,7 @@ hosp_ribbon_data <- all_pred_draws_df %>% .groups = "drop" ) hosp_result_plot <- ggplot( - all_pred_draws_df %>% - filter( - name == "predicted counts" - ) + all_pred_draws_df ) + geom_ribbon( data = hosp_ribbon_data, @@ -1511,18 +1851,13 @@ hosp_result_plot <- ggplot( values = c("darkviolet", "deeppink3", "darksalmon") ) + theme_bw() -# ------------------------------------------------------------------------------ + + # Wastewater results ----------------------------------------------------------- -ww_ribbon_data <- all_pred_draws_df %>% - filter( - name == "predicted wastewater" - ) %>% - mutate( - site_lab_name = glue::glue("{subpop}, Lab: {lab}") - ) %>% +ww_ribbon_data <- all_ww_draws_df %>% group_by( date, - subpop, + site, inf_model_type, gen_model_type ) %>% @@ -1533,10 +1868,7 @@ ww_ribbon_data <- all_pred_draws_df %>% .groups = "drop" ) ww_result_plot <- ggplot( - all_pred_draws_df %>% - filter( - name == "predicted wastewater" - ) + all_ww_draws_df ) + geom_ribbon( data = ww_ribbon_data, @@ -1556,7 +1888,7 @@ ww_result_plot <- ggplot( ) + xlab("") + ylab("Genome copies/mL on Log Scale") + - facet_grid(subpop ~ gen_model_type, scales = "free_y") + + facet_grid(site ~ gen_model_type, scales = "free_y") + guides( fill = guide_legend( title = "Assumed Corr. Structure" @@ -1597,10 +1929,9 @@ partially or not at all informed by recent data.
```{r} # Global Rt results ------------------------------------------------------------ -global_rt_ribbon_data <- all_pred_draws_df %>% - filter( - name == "global R(t)" - ) %>% + + +global_rt_ribbon_data <- all_rt_draws_df %>% group_by( date, inf_model_type, @@ -1613,10 +1944,7 @@ global_rt_ribbon_data <- all_pred_draws_df %>% .groups = "drop" ) global_rt_result_plot <- ggplot( - all_pred_draws_df %>% - filter( - name == "global R(t)" - ) + all_pred_draws_df ) + geom_ribbon( data = global_rt_ribbon_data, @@ -1667,13 +1995,10 @@ global_rt_result_plot <- ggplot( theme_bw() # ------------------------------------------------------------------------------ # Site Rt results -------------------------------------------------------------- -site_rt_ribbon_data <- all_pred_draws_df %>% - filter( - name == "subpopulation R(t)" - ) %>% +site_rt_ribbon_data <- all_subpop_rt_draws_df %>% group_by( date, - subpop, + subpop_name, inf_model_type, gen_model_type ) %>% @@ -1682,20 +2007,6 @@ site_rt_ribbon_data <- all_pred_draws_df %>% median = median(pred_value), upper = quantile(pred_value, 0.975, na.rm = TRUE), .groups = "drop" - ) %>% - mutate( - subpop = sub( - pattern = "Site: (\\d+)", - replacement = "Site \\1", - x = subpop, - ignore.case = "remainder of pop" - ) - ) %>% - mutate( - subpop = case_when( - subpop == "remainder of pop" ~ "Aux", - .default = subpop - ) ) site_rt_result_plot <- ggplot() + geom_ribbon( @@ -1727,7 +2038,7 @@ site_rt_result_plot <- ggplot() + ) + xlab("") + ylab("Site Rt") + - facet_grid(subpop ~ gen_model_type, scales = "free_y") + + facet_grid(subpop_name ~ gen_model_type, scales = "free_y") + guides( fill = guide_legend( title = "Assumed Corr. Structure" @@ -2043,7 +2354,6 @@ period. ```{r} hosp_obj_for_eval_forcast <- all_pred_draws_df %>% filter( - name == "predicted counts", date > forecast_date ) %>% inner_join( @@ -2241,7 +2551,6 @@ make two plots one for metrics by date, and another across all dates. ```{r} hosp_obj_for_eval_nowcast <- all_pred_draws_df %>% filter( - name == "predicted counts", date > max(hosp_data$date), date <= forecast_date ) %>% diff --git a/vignettes/wwinference.Rmd b/vignettes/wwinference.Rmd index 9c13c740..e8ebd442 100644 --- a/vignettes/wwinference.Rmd +++ b/vignettes/wwinference.Rmd @@ -17,6 +17,7 @@ vignette: > ```{r setup, echo=FALSE} knitr::opts_chunk$set(dev = "svg") +options(mc.cores = 4) # This tells cmdstan to run the 4 chains in parallel ``` # Quick start @@ -31,7 +32,7 @@ subset of that population, e.g. a municipality within that state. This is intended to be used as a reference for those interested in fitting the `wwinference` model to their own data. -# Package +# Packages In this quick start, we also use `dplyr` `tidybayes` and `ggplot2` packages. These are installed as dependencies when `wwinference` is installed. @@ -59,8 +60,9 @@ from September 1, 2023 to December 1, 2023, with varying sampling frequencies. We will be using this data to produce a forecast of COVID-19 hospital admissions as of December 6, 2023. These data are provided as part of the package data. -These data are already in a format that can be used for `wwinference`. For the -hospital admissions data, it contains: +These data are already in a format that can be used for the `wwinference` package. +For the hospital admissions data, it contains: + - a date (column `date`): the date of the observation, in this case, the date the hospital admissions occurred - a count (column `daily_hosp_admits`): the number of hospital admissions @@ -72,8 +74,7 @@ Additionally, we provide the `hosp_data_eval` dataset which contains the simulated hospital admissions 28 days ahead of the forecast date, which can be used to evaluate the model. -For the wastewater data, the expcted format is a table of observations with the -following columns. The wastewater data should not contain `NA` values for days with +For the wastewater data, the expcted format is a table of observations with the following columns. The wastewater data should not contain `NA` values for days with missing observations, instead these should be excluded: - a date (column `date`): the date the sample was collected - a site indicator (column `site`): the unique identifier for the wastewater treatment plant @@ -100,6 +101,7 @@ head(ww_data) head(hosp_data) ``` + # Pre-processing The user will need to provide data that is in a similar format to the package @@ -126,7 +128,7 @@ params <- get_params( ## Wastewater data pre-processing -The `preprocess_ww_data` function adds the following variables to the original +The `preprocess_ww_data()` function adds the following variables to the original dataset. First, it assigns a unique identifier the unique combinations of labs and sites, since this is the unit we will use for estimating the observation error in the reported measurements. @@ -145,7 +147,7 @@ and `lab`, and will return a dataframe with the column names needed to pass to the downstream model fitting functions. ```{r preprocess-ww-data} -ww_data_preprocessed <- wwinference::preprocess_ww_data( +ww_data_preprocessed <- preprocess_ww_data( ww_data, conc_col_name = "log_genome_copies_per_ml", lod_col_name = "log_lod" @@ -153,13 +155,12 @@ ww_data_preprocessed <- wwinference::preprocess_ww_data( ``` Note that this function assumes that there are no missing values in the concentration column. The package expects observations below the LOD will -be replaced with a numeric value below the LOD. If there are `NA` values in your dataset -when observations are below the LOD, we suggest replacing them with a value +be replaced with a numeric value below the LOD. If there are NAs in your dataset when observations are below the LOD, we suggest replacing them with a value below the LOD in upstream pre-processing. ## Hospital admissions data pre-processing -The `preprocess_hosp_data` function standardizes the column names of the +The `preprocess_count_data()` function standardizes the column names of the resulting datafame. The user must specify the name of the column containing the daily hospital admissions counts and the population size that the hospital admissions are coming from (from in this case, a hypothetical US state). The @@ -168,7 +169,7 @@ return a dataframe with the column names needed to pass to the downstream model fitting functions. ```{r preprocess-hosp-data} -hosp_data_preprocessed <- wwinference::preprocess_count_data( +hosp_data_preprocessed <- preprocess_count_data( hosp_data, count_col_name = "daily_hosp_admits", pop_size_col_name = "state_pop" @@ -184,21 +185,41 @@ ggplot(ww_data_preprocessed) + x = date, y = log_genome_copies_per_ml, color = as.factor(lab_site_name) ), - show.legend = FALSE + show.legend = FALSE, + size = 0.5 ) + geom_point( data = ww_data_preprocessed |> filter( log_genome_copies_per_ml <= log_lod ), aes(x = date, y = log_genome_copies_per_ml, color = "red"), - show.legend = FALSE + show.legend = FALSE, size = 0.5 + ) + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") ) + geom_hline(aes(yintercept = log_lod), linetype = "dashed") + facet_wrap(~lab_site_name, scales = "free") + xlab("") + ylab("Genome copies/mL") + ggtitle("Lab-site level wastewater concentration") + - theme_bw() + theme_bw() + + theme( + axis.text.x = element_text( + size = 5, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 12), + axis.text.y = element_text(size = 5), + strip.text = element_text(size = 5), + axis.title.y = element_text(size = 12), + plot.title = element_text( + size = 10, + vjust = 0.5, hjust = 0.5 + ) + ) + ggplot(hosp_data_preprocessed) + # Plot the hospital admissions data that we will evaluate against in white @@ -211,10 +232,26 @@ ggplot(hosp_data_preprocessed) + ) + # Plot the data we will calibrate to geom_point(aes(x = date, y = count)) + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") + ) + xlab("") + ylab("Daily hospital admissions") + ggtitle("State level hospital admissions") + - theme_bw() + theme_bw() + + theme( + axis.text.x = element_text( + size = 8, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 12), + axis.title.y = element_text(size = 12), + plot.title = element_text( + size = 10, + vjust = 0.5, hjust = 0.5 + ) + ) ``` The closed circles indicate the data the model will be calibrated to, while @@ -229,7 +266,7 @@ we will use the `indicate_ww_exclusions()` function, which will add the flagged outliers to the exclude column where indicated. ```{r indicate-ww-exclusions} -ww_data_to_fit <- wwinference::indicate_ww_exclusions( +ww_data_to_fit <- indicate_ww_exclusions( ww_data_preprocessed, outlier_col_name = "flag_as_ww_outlier", remove_outliers = TRUE @@ -238,7 +275,8 @@ ww_data_to_fit <- wwinference::indicate_ww_exclusions( # Model specification: -We will need to set some metadata to facilitate model specification. This includes: +We will need to set some metadata to facilitate model specification. +This includes: - forecast date (the date we are making a forecast) - number of days to calibrate the model for - number of days to forecast beyond the forecast date @@ -286,17 +324,20 @@ inf_to_hosp <- wwinference::default_covid_inf_to_hosp infection_feedback_pmf <- generation_interval ``` -We will pass these to the `model_spec()` function of the `wwinference()` model, +We will pass these to the `get_model_spec()` function of the `wwinference()` model, along with the other specified parameters above. # Precompiling the model As `wwinference` uses `cmdstan` to fit its models, it is necessary to first -compile the model. This can be done using the compile_model() function. +compile the model. This can be done using the `compile_model()` function + ```{r compile-model} +# temporarily compile from local to make troubleshooting faster/easier model <- wwinference::compile_model() ``` +``` # Fitting the model @@ -317,12 +358,12 @@ to achieve improved model convergence and/or faster model fitting times. See the We also pass our preprocessed datasets (`ww_data_to_fit` and `hosp_data_preprocessed`), specify our model using `get_model_spec()`, -set the MCMC settings using `get_mcmc_options()`, and pass in our +set the MCMC settings by passing a list of arguments to `fit_opts` that will be passed to the `cmdstanr::sample()` function, and pass in our pre-compiled model(`model`) to `wwinference()` where they are combined and used to fit the model. ```{r fitting-model, warning=FALSE, message=FALSE} -ww_fit <- wwinference::wwinference( +ww_fit <- wwinference( ww_data = ww_data_to_fit, count_data = hosp_data_preprocessed, forecast_date = forecast_date, @@ -334,7 +375,7 @@ ww_fit <- wwinference::wwinference( infection_feedback_pmf = infection_feedback_pmf, params = params ), - fit_opts = get_mcmc_options(seed = 123), + fit_opts = list(seed = 123), compiled_model = model ) ``` @@ -369,25 +410,33 @@ Working with the posterior predictions alongside the input data can be useful to check that your model is fitting the data well and that the nowcasted/forecast quantities look reasonable. -We will generate a dataframe that we'll call `draws_df`, that contains -the posterior draws of the estimated, nowcasted, and forecasted expected -observed hospital admissions and wastewater concentrations, as well as the -latent variables of interest including the site-level $\mathcal{R}(t)$ estimates and the -state-level $\mathcal{R}(t)$ estimate. +We can use the `get_draws()` function to generate dataframes that contain +the posterior draws of the estimated, nowcasted, and forecasted quantities, +joined to the relevant data. We can generate this directly on the output of `wwinference()` using: ```{r extracting-draws} -draws_df <- get_draws_df(ww_fit) +draws <- get_draws(ww_fit) -cat( - "Variables in dataframe: ", - sprintf("%s", paste(unique(draws_df$name), collapse = ", ")) -) +print(draws) +``` + +Note that by default the `get_draws()` function will return a list of class `wwinference_fit_draws` +which contains separate dataframes of the posterior draws for predicted counts (`"predicted_counts"`), +wastewater concentrations (`"predicted_ww"`), global $\mathcal{R}(t)$ (`"global_rt"`) estimates, and +subpopulation-level $\mathcal{R}(t)$ estimates ("`subpop_rt"`). +To examine a particular variable (e.g. `"predicted_counts"` for posterior +predicted hospital admissions in this case), access the corresponding tibble using the `$` operator. + + +You can also specify which outputs to return using the `what` argument. +```{r example subset draws} +hosp_draws <- get_draws(ww_fit, what = "predicted_counts") +hosp_draws_df <- hosp_draws$predicted_counts +head(hosp_draws_df) ``` -Note that by default the `get_draws_df()` function will return a tidy long -dataframe with all of the posterior draws joined to applicable data for each of -the included variables. To examine a particular variable (e.g. `"predicted counts"` for posterior -predicted hospital admissions), filter the data frame based on the `name` column. + + ### Using explicit passed arguments rather than S3 methods @@ -395,10 +444,13 @@ Rather than using S3 methods supplied for `wwinference()`, the elements in the `wwinference_fit` object can also be used directly to create this dataframe. This is demonstrated below: -```{r extracting-draws-explicit} -draws_df_explicit <- get_draws_df( +```{r extracting-draws-explicit, eval = FALSE} +draws_explicit <- get_draws( x = ww_fit$raw_input_data$input_ww_data, count_data = ww_fit$raw_input_data$input_count_data, + date_time_spine = ww_fit$raw_input_data$date_time_spine, + site_subpop_spine = ww_fit$raw_input_data$site_subpop_spine, + lab_site_subpop_spine = ww_fit$raw_input_data$lab_site_subpop_spine, stan_data_list = ww_fit$stan_data_list, fit_obj = ww_fit$fit ) @@ -407,39 +459,53 @@ draws_df_explicit <- get_draws_df( ## Plotting the outputs -We can create plots of the outputs using `draws_df` and -the fitting wrapper functions. Note that by default, these plots will not -visualize data that was below the LOD (even though the fit incorporated -them via the censored observation process.) +We can create plots of the outputs using corresponding dataframes in the `draws` +object and the fitting wrapper functions. Note that by default, these plots +will not include outliers that were flagged for exclusion. Data points +that are below the LOD will be plotted in blue. ```{r generating-figures, out.width='100%'} -draws_df <- get_draws_df(ww_fit) - plot_hosp <- get_plot_forecasted_counts( - draws = draws_df, + draws = draws$predicted_counts, count_data_eval = hosp_data_eval, count_data_eval_col_name = "daily_hosp_admits_for_eval", forecast_date = forecast_date ) plot_hosp -plot_ww <- get_plot_ww_conc(draws_df, forecast_date) +plot_ww <- get_plot_ww_conc(draws$predicted_ww, forecast_date) plot_ww -plot_state_rt <- get_plot_global_rt(draws_df, forecast_date) +plot_state_rt <- get_plot_global_rt(draws$global_rt, forecast_date) plot_state_rt -plot_subpop_rt <- get_plot_subpop_rt(draws_df, forecast_date) +plot_subpop_rt <- get_plot_subpop_rt(draws$subpop_rt, forecast_date) plot_subpop_rt ``` +The previous three are equivalent to calling the `plot` method of `wwinference_fit_draws` using the `what` argument: + +```{r, out.width='100%'} +plot( + x = draws, + what = "predicted_counts", + count_data_eval = hosp_data_eval, + count_data_eval_col_name = "daily_hosp_admits_for_eval", + forecast_date = forecast_date +) +plot(draws, what = "predicted_ww", forecast_date = forecast_date) +plot(draws, what = "global_rt", forecast_date = forecast_date) +plot(draws, what = "subpop_rt", forecast_date = forecast_date) +``` + ## Diagnostics We strongly recommend running diagnostics as a post-processing step on the model outputs. This can be done by passing the output of -`wwinference()` into the `get_model_diagnostic_flags()`, `parameter_diagnostics()`, + +`wwinference()` into the `get_model_diagnostic_flags()`, `summary_diagnostics()` and `parameter_diagnostics()` functions. `get_model_diagnostic_flags()` will print out a table of any flags, if any of @@ -448,13 +514,21 @@ We have set default thresholds on the model diagnostics for production-level runs, we recommend adjusting as needed (see below) To further troubleshoot, you can look at -the diagnostic summary and the diagnostics of the individual parameters using +the summary diagnostics using the `summary_diagnostics()` function +and the diagnostics of the individual parameters using the `parameter_diagnostics()` function. +For further information on troubleshooting the model diagnostics, +we recommend the (bayesplot tutorial)[https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html]. + +You can access the CmdStan object directly using `ww_fit$fit$result` + ```{r diagnostics-using-S3-methods} convergence_flag_df <- get_model_diagnostic_flags(ww_fit) print(convergence_flag_df) -parameter_diagnostics(ww_fit) +summary_diagnostics(ww_fit) +param_diagnostics <- parameter_diagnostics(ww_fit) +head(param_diagnostics) ``` This can also be done explicitly by parsing the elements of the @@ -471,7 +545,7 @@ to identify which components of the model might be driving the convergence issues. For further information on troubleshooting the model diagnostics, -we recommend the (bayesplot tutorial)[https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html]. +we recommend the [bayesplot tutorial](https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html). ```{r diagnostics-explicit} convergence_flag_df <- get_model_diagnostic_flags( @@ -497,7 +571,7 @@ rely on the admissions only model if there are covergence or known data issues with the wastewater data. ```{r fit-hosp-only, warning=FALSE, message=FALSE} -fit_hosp_only <- wwinference::wwinference( +fit_hosp_only <- wwinference( ww_data = ww_data_to_fit, count_data = hosp_data_preprocessed, forecast_date = forecast_date, @@ -510,18 +584,18 @@ fit_hosp_only <- wwinference::wwinference( include_ww = FALSE, params = params ), - fit_opts = get_mcmc_options(), + fit_opts = list(seed = 123), compiled_model = model ) ``` ```{r plot-hosp-only, out.width='100%'} -draws_df_hosp_only <- get_draws_df(fit_hosp_only) -plot_hosp_hosp_only <- get_plot_forecasted_counts( - draws = draws_df_hosp_only, +draws_hosp_only <- get_draws(fit_hosp_only) +plot(draws_hosp_only, + what = "predicted_counts", count_data_eval = hosp_data_eval, count_data_eval_col_name = "daily_hosp_admits_for_eval", forecast_date = forecast_date ) -plot_hosp_hosp_only +plot(draws_hosp_only, what = "global_rt", forecast_date = forecast_date) ```