```{r}
set.seed(2024)
-draws_to_keep <- sample(1:max(get_draws_df(fit_iid_to_iid)$draw), 100)
+draws_to_keep <- sample(1:max(get_draws(
+ fit_iid_to_iid,
+ what = "predicted_counts"
+)$predicted_counts$draw), 100)
# IID data ---------------------------------------------------------------------
iid_pred_draws_df <- rbind(
- get_draws_df(fit_iid_to_iid) %>%
+ get_draws(fit_iid_to_iid,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "IID"
),
- get_draws_df(fit_iid_to_exp) %>%
+ get_draws(fit_iid_to_exp,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "Exponential"
),
- get_draws_df(fit_iid_to_unstruct) %>%
+ get_draws(fit_iid_to_unstruct,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "Unstructured"
@@ -1096,17 +1105,23 @@ iid_pred_draws_df <- rbind(
# ------------------------------------------------------------------------------
# Exponential data -------------------------------------------------------------
exp_pred_draws_df <- rbind(
- get_draws_df(fit_exp_to_iid) %>%
+ get_draws(fit_exp_to_iid,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "IID"
),
- get_draws_df(fit_exp_to_exp) %>%
+ get_draws(fit_exp_to_exp,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "Exponential"
),
- get_draws_df(fit_exp_to_unstruct) %>%
+ get_draws(fit_exp_to_unstruct,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "Unstructured"
@@ -1118,17 +1133,23 @@ exp_pred_draws_df <- rbind(
# ------------------------------------------------------------------------------
# Rand. Corr. Matrix data ------------------------------------------------------
rand_pred_draws_df <- rbind(
- get_draws_df(fit_rand_to_iid) %>%
+ get_draws(fit_rand_to_iid,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "IID"
),
- get_draws_df(fit_rand_to_exp) %>%
+ get_draws(fit_rand_to_exp,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "Exponential"
),
- get_draws_df(fit_rand_to_unstruct) %>%
+ get_draws(fit_rand_to_unstruct,
+ what = "predicted_counts"
+ )$predicted_counts %>%
filter(draw %in% draws_to_keep) %>%
mutate(
inf_model_type = "Unstructured"
@@ -1161,6 +1182,331 @@ all_pred_draws_df <- rbind(
)
)
)
+
+# Wastewater draws------------------------------
+iid_ww_draws_df <- rbind(
+ get_draws(fit_iid_to_iid,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_iid_to_exp,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_iid_to_unstruct,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "IID"
+ )
+# ------------------------------------------------------------------------------
+# Exponential data -------------------------------------------------------------
+exp_ww_draws_df <- rbind(
+ get_draws(fit_exp_to_iid,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_exp_to_exp,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_exp_to_unstruct,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "Exponential"
+ )
+# ------------------------------------------------------------------------------
+# Rand. Corr. Matrix data ------------------------------------------------------
+rand_ww_draws_df <- rbind(
+ get_draws(fit_rand_to_iid,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_rand_to_exp,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_rand_to_unstruct,
+ what = "predicted_ww"
+ )$predicted_ww %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "Rand. Corr. Matrix"
+ )
+
+all_ww_draws_df <- rbind(
+ iid_ww_draws_df,
+ exp_ww_draws_df,
+ rand_ww_draws_df
+) %>%
+ mutate(
+ inf_model_type = factor(
+ inf_model_type,
+ levels = c(
+ "Exponential",
+ "Unstructured",
+ "IID"
+ )
+ ),
+ gen_model_type = factor(
+ gen_model_type,
+ levels = c(
+ "IID",
+ "Exponential",
+ "Rand. Corr. Matrix"
+ )
+ )
+ )
+
+# Global R(t) draws-------------------------------------------
+iid_rt_draws_df <- rbind(
+ get_draws(fit_iid_to_iid,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_iid_to_exp,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_iid_to_unstruct,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "IID"
+ )
+# ------------------------------------------------------------------------------
+# Exponential data -------------------------------------------------------------
+exp_rt_draws_df <- rbind(
+ get_draws(fit_exp_to_iid,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_exp_to_exp,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_exp_to_unstruct,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "Exponential"
+ )
+# ------------------------------------------------------------------------------
+# Rand. Corr. Matrix data ------------------------------------------------------
+rand_rt_draws_df <- rbind(
+ get_draws(fit_rand_to_iid,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_rand_to_exp,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_rand_to_unstruct,
+ what = "global_rt"
+ )$global_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "Rand. Corr. Matrix"
+ )
+
+all_rt_draws_df <- rbind(
+ iid_rt_draws_df,
+ exp_rt_draws_df,
+ rand_rt_draws_df
+) %>%
+ mutate(
+ inf_model_type = factor(
+ inf_model_type,
+ levels = c(
+ "Exponential",
+ "Unstructured",
+ "IID"
+ )
+ ),
+ gen_model_type = factor(
+ gen_model_type,
+ levels = c(
+ "IID",
+ "Exponential",
+ "Rand. Corr. Matrix"
+ )
+ )
+ )
+
+# Subpop R(t) ---------------------------
+
+iid_subpop_rt_draws_df <- rbind(
+ get_draws(fit_iid_to_iid,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_iid_to_exp,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_iid_to_unstruct,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "IID"
+ )
+# ------------------------------------------------------------------------------
+# Exponential data -------------------------------------------------------------
+exp_subpop_rt_draws_df <- rbind(
+ get_draws(fit_exp_to_iid,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_exp_to_exp,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_exp_to_unstruct,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "Exponential"
+ )
+# ------------------------------------------------------------------------------
+# Rand. Corr. Matrix data ------------------------------------------------------
+rand_subpop_rt_draws_df <- rbind(
+ get_draws(fit_rand_to_iid,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "IID"
+ ),
+ get_draws(fit_rand_to_exp,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Exponential"
+ ),
+ get_draws(fit_rand_to_unstruct,
+ what = "subpop_rt"
+ )$subpop_rt %>%
+ filter(draw %in% draws_to_keep) %>%
+ mutate(
+ inf_model_type = "Unstructured"
+ )
+) %>%
+ mutate(
+ gen_model_type = "Rand. Corr. Matrix"
+ )
+
+all_subpop_rt_draws_df <- rbind(
+ iid_subpop_rt_draws_df,
+ exp_subpop_rt_draws_df,
+ rand_subpop_rt_draws_df
+) %>%
+ mutate(
+ inf_model_type = factor(
+ inf_model_type,
+ levels = c(
+ "Exponential",
+ "Unstructured",
+ "IID"
+ )
+ ),
+ gen_model_type = factor(
+ gen_model_type,
+ levels = c(
+ "IID",
+ "Exponential",
+ "Rand. Corr. Matrix"
+ )
+ )
+ )
```
@@ -1449,9 +1795,6 @@ evaluation metrics will be used to quantify forecast performance.
```{r warning=FALSE}
# Hospital admissions results --------------------------------------------------
hosp_ribbon_data <- all_pred_draws_df %>%
- filter(
- name == "predicted counts"
- ) %>%
group_by(
date,
inf_model_type,
@@ -1464,10 +1807,7 @@ hosp_ribbon_data <- all_pred_draws_df %>%
.groups = "drop"
)
hosp_result_plot <- ggplot(
- all_pred_draws_df %>%
- filter(
- name == "predicted counts"
- )
+ all_pred_draws_df
) +
geom_ribbon(
data = hosp_ribbon_data,
@@ -1511,18 +1851,13 @@ hosp_result_plot <- ggplot(
values = c("darkviolet", "deeppink3", "darksalmon")
) +
theme_bw()
-# ------------------------------------------------------------------------------
+
+
# Wastewater results -----------------------------------------------------------
-ww_ribbon_data <- all_pred_draws_df %>%
- filter(
- name == "predicted wastewater"
- ) %>%
- mutate(
- site_lab_name = glue::glue("{subpop}, Lab: {lab}")
- ) %>%
+ww_ribbon_data <- all_ww_draws_df %>%
group_by(
date,
- subpop,
+ site,
inf_model_type,
gen_model_type
) %>%
@@ -1533,10 +1868,7 @@ ww_ribbon_data <- all_pred_draws_df %>%
.groups = "drop"
)
ww_result_plot <- ggplot(
- all_pred_draws_df %>%
- filter(
- name == "predicted wastewater"
- )
+ all_ww_draws_df
) +
geom_ribbon(
data = ww_ribbon_data,
@@ -1556,7 +1888,7 @@ ww_result_plot <- ggplot(
) +
xlab("") +
ylab("Genome copies/mL on Log Scale") +
- facet_grid(subpop ~ gen_model_type, scales = "free_y") +
+ facet_grid(site ~ gen_model_type, scales = "free_y") +
guides(
fill = guide_legend(
title = "Assumed Corr. Structure"
@@ -1597,10 +1929,9 @@ partially or not at all informed by recent data.
```{r}
# Global Rt results ------------------------------------------------------------
-global_rt_ribbon_data <- all_pred_draws_df %>%
- filter(
- name == "global R(t)"
- ) %>%
+
+
+global_rt_ribbon_data <- all_rt_draws_df %>%
group_by(
date,
inf_model_type,
@@ -1613,10 +1944,7 @@ global_rt_ribbon_data <- all_pred_draws_df %>%
.groups = "drop"
)
global_rt_result_plot <- ggplot(
- all_pred_draws_df %>%
- filter(
- name == "global R(t)"
- )
+ all_pred_draws_df
) +
geom_ribbon(
data = global_rt_ribbon_data,
@@ -1667,13 +1995,10 @@ global_rt_result_plot <- ggplot(
theme_bw()
# ------------------------------------------------------------------------------
# Site Rt results --------------------------------------------------------------
-site_rt_ribbon_data <- all_pred_draws_df %>%
- filter(
- name == "subpopulation R(t)"
- ) %>%
+site_rt_ribbon_data <- all_subpop_rt_draws_df %>%
group_by(
date,
- subpop,
+ subpop_name,
inf_model_type,
gen_model_type
) %>%
@@ -1682,20 +2007,6 @@ site_rt_ribbon_data <- all_pred_draws_df %>%
median = median(pred_value),
upper = quantile(pred_value, 0.975, na.rm = TRUE),
.groups = "drop"
- ) %>%
- mutate(
- subpop = sub(
- pattern = "Site: (\\d+)",
- replacement = "Site \\1",
- x = subpop,
- ignore.case = "remainder of pop"
- )
- ) %>%
- mutate(
- subpop = case_when(
- subpop == "remainder of pop" ~ "Aux",
- .default = subpop
- )
)
site_rt_result_plot <- ggplot() +
geom_ribbon(
@@ -1727,7 +2038,7 @@ site_rt_result_plot <- ggplot() +
) +
xlab("") +
ylab("Site Rt") +
- facet_grid(subpop ~ gen_model_type, scales = "free_y") +
+ facet_grid(subpop_name ~ gen_model_type, scales = "free_y") +
guides(
fill = guide_legend(
title = "Assumed Corr. Structure"
@@ -2043,7 +2354,6 @@ period.
```{r}
hosp_obj_for_eval_forcast <- all_pred_draws_df %>%
filter(
- name == "predicted counts",
date > forecast_date
) %>%
inner_join(
@@ -2241,7 +2551,6 @@ make two plots one for metrics by date, and another across all dates.
```{r}
hosp_obj_for_eval_nowcast <- all_pred_draws_df %>%
filter(
- name == "predicted counts",
date > max(hosp_data$date),
date <= forecast_date
) %>%
diff --git a/vignettes/wwinference.Rmd b/vignettes/wwinference.Rmd
index 9c13c740..e8ebd442 100644
--- a/vignettes/wwinference.Rmd
+++ b/vignettes/wwinference.Rmd
@@ -17,6 +17,7 @@ vignette: >
```{r setup, echo=FALSE}
knitr::opts_chunk$set(dev = "svg")
+options(mc.cores = 4) # This tells cmdstan to run the 4 chains in parallel
```
# Quick start
@@ -31,7 +32,7 @@ subset of that population, e.g. a municipality within that state.
This is intended to be used as a reference for those
interested in fitting the `wwinference` model to their own data.
-# Package
+# Packages
In this quick start, we also use `dplyr` `tidybayes` and `ggplot2` packages.
These are installed as dependencies when `wwinference` is installed.
@@ -59,8 +60,9 @@ from September 1, 2023 to December 1, 2023, with varying sampling frequencies.
We will be using this data to produce a forecast of COVID-19 hospital admissions
as of December 6, 2023. These data are provided as part of the package data.
-These data are already in a format that can be used for `wwinference`. For the
-hospital admissions data, it contains:
+These data are already in a format that can be used for the `wwinference` package.
+For the hospital admissions data, it contains:
+
- a date (column `date`): the date of the observation, in this case, the date
the hospital admissions occurred
- a count (column `daily_hosp_admits`): the number of hospital admissions
@@ -72,8 +74,7 @@ Additionally, we provide the `hosp_data_eval` dataset which contains the
simulated hospital admissions 28 days ahead of the forecast date, which can be
used to evaluate the model.
-For the wastewater data, the expcted format is a table of observations with the
-following columns. The wastewater data should not contain `NA` values for days with
+For the wastewater data, the expcted format is a table of observations with the following columns. The wastewater data should not contain `NA` values for days with
missing observations, instead these should be excluded:
- a date (column `date`): the date the sample was collected
- a site indicator (column `site`): the unique identifier for the wastewater treatment plant
@@ -100,6 +101,7 @@ head(ww_data)
head(hosp_data)
```
+
# Pre-processing
The user will need to provide data that is in a similar format to the package
@@ -126,7 +128,7 @@ params <- get_params(
## Wastewater data pre-processing
-The `preprocess_ww_data` function adds the following variables to the original
+The `preprocess_ww_data()` function adds the following variables to the original
dataset. First, it assigns a unique identifier
the unique combinations of labs and sites, since this is the unit we will
use for estimating the observation error in the reported measurements.
@@ -145,7 +147,7 @@ and `lab`, and will return a dataframe with the column names needed to
pass to the downstream model fitting functions.
```{r preprocess-ww-data}
-ww_data_preprocessed <- wwinference::preprocess_ww_data(
+ww_data_preprocessed <- preprocess_ww_data(
ww_data,
conc_col_name = "log_genome_copies_per_ml",
lod_col_name = "log_lod"
@@ -153,13 +155,12 @@ ww_data_preprocessed <- wwinference::preprocess_ww_data(
```
Note that this function assumes that there are no missing values in the
concentration column. The package expects observations below the LOD will
-be replaced with a numeric value below the LOD. If there are `NA` values in your dataset
-when observations are below the LOD, we suggest replacing them with a value
+be replaced with a numeric value below the LOD. If there are NAs in your dataset when observations are below the LOD, we suggest replacing them with a value
below the LOD in upstream pre-processing.
## Hospital admissions data pre-processing
-The `preprocess_hosp_data` function standardizes the column names of the
+The `preprocess_count_data()` function standardizes the column names of the
resulting datafame. The user must specify the name of the column containing
the daily hospital admissions counts and the population size that the hospital
admissions are coming from (from in this case, a hypothetical US state). The
@@ -168,7 +169,7 @@ return a dataframe with the column names needed to pass to the downstream model
fitting functions.
```{r preprocess-hosp-data}
-hosp_data_preprocessed <- wwinference::preprocess_count_data(
+hosp_data_preprocessed <- preprocess_count_data(
hosp_data,
count_col_name = "daily_hosp_admits",
pop_size_col_name = "state_pop"
@@ -184,21 +185,41 @@ ggplot(ww_data_preprocessed) +
x = date, y = log_genome_copies_per_ml,
color = as.factor(lab_site_name)
),
- show.legend = FALSE
+ show.legend = FALSE,
+ size = 0.5
) +
geom_point(
data = ww_data_preprocessed |> filter(
log_genome_copies_per_ml <= log_lod
),
aes(x = date, y = log_genome_copies_per_ml, color = "red"),
- show.legend = FALSE
+ show.legend = FALSE, size = 0.5
+ ) +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
) +
geom_hline(aes(yintercept = log_lod), linetype = "dashed") +
facet_wrap(~lab_site_name, scales = "free") +
xlab("") +
ylab("Genome copies/mL") +
ggtitle("Lab-site level wastewater concentration") +
- theme_bw()
+ theme_bw() +
+ theme(
+ axis.text.x = element_text(
+ size = 5, vjust = 1,
+ hjust = 1, angle = 45
+ ),
+ axis.title.x = element_text(size = 12),
+ axis.text.y = element_text(size = 5),
+ strip.text = element_text(size = 5),
+ axis.title.y = element_text(size = 12),
+ plot.title = element_text(
+ size = 10,
+ vjust = 0.5, hjust = 0.5
+ )
+ )
+
ggplot(hosp_data_preprocessed) +
# Plot the hospital admissions data that we will evaluate against in white
@@ -211,10 +232,26 @@ ggplot(hosp_data_preprocessed) +
) +
# Plot the data we will calibrate to
geom_point(aes(x = date, y = count)) +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ ) +
xlab("") +
ylab("Daily hospital admissions") +
ggtitle("State level hospital admissions") +
- theme_bw()
+ theme_bw() +
+ theme(
+ axis.text.x = element_text(
+ size = 8, vjust = 1,
+ hjust = 1, angle = 45
+ ),
+ axis.title.x = element_text(size = 12),
+ axis.title.y = element_text(size = 12),
+ plot.title = element_text(
+ size = 10,
+ vjust = 0.5, hjust = 0.5
+ )
+ )
```
The closed circles indicate the data the model will be calibrated to, while
@@ -229,7 +266,7 @@ we will use the `indicate_ww_exclusions()` function, which will add the
flagged outliers to the exclude column where indicated.
```{r indicate-ww-exclusions}
-ww_data_to_fit <- wwinference::indicate_ww_exclusions(
+ww_data_to_fit <- indicate_ww_exclusions(
ww_data_preprocessed,
outlier_col_name = "flag_as_ww_outlier",
remove_outliers = TRUE
@@ -238,7 +275,8 @@ ww_data_to_fit <- wwinference::indicate_ww_exclusions(
# Model specification:
-We will need to set some metadata to facilitate model specification. This includes:
+We will need to set some metadata to facilitate model specification.
+This includes:
- forecast date (the date we are making a forecast)
- number of days to calibrate the model for
- number of days to forecast beyond the forecast date
@@ -286,17 +324,20 @@ inf_to_hosp <- wwinference::default_covid_inf_to_hosp
infection_feedback_pmf <- generation_interval
```
-We will pass these to the `model_spec()` function of the `wwinference()` model,
+We will pass these to the `get_model_spec()` function of the `wwinference()` model,
along with the other specified parameters above.
# Precompiling the model
As `wwinference` uses `cmdstan` to fit its models, it is necessary to first
-compile the model. This can be done using the compile_model() function.
+compile the model. This can be done using the `compile_model()` function
+
```{r compile-model}
+# temporarily compile from local to make troubleshooting faster/easier
model <- wwinference::compile_model()
```
+```
# Fitting the model
@@ -317,12 +358,12 @@ to achieve improved model convergence and/or faster model fitting times. See the
We also pass our preprocessed datasets (`ww_data_to_fit` and
`hosp_data_preprocessed`), specify our model using `get_model_spec()`,
-set the MCMC settings using `get_mcmc_options()`, and pass in our
+set the MCMC settings by passing a list of arguments to `fit_opts` that will be passed to the `cmdstanr::sample()` function, and pass in our
pre-compiled model(`model`) to `wwinference()` where they are combined and
used to fit the model.
```{r fitting-model, warning=FALSE, message=FALSE}
-ww_fit <- wwinference::wwinference(
+ww_fit <- wwinference(
ww_data = ww_data_to_fit,
count_data = hosp_data_preprocessed,
forecast_date = forecast_date,
@@ -334,7 +375,7 @@ ww_fit <- wwinference::wwinference(
infection_feedback_pmf = infection_feedback_pmf,
params = params
),
- fit_opts = get_mcmc_options(seed = 123),
+ fit_opts = list(seed = 123),
compiled_model = model
)
```
@@ -369,25 +410,33 @@ Working with the posterior predictions alongside the input data can be useful
to check that your model is fitting the data well and that the
nowcasted/forecast quantities look reasonable.
-We will generate a dataframe that we'll call `draws_df`, that contains
-the posterior draws of the estimated, nowcasted, and forecasted expected
-observed hospital admissions and wastewater concentrations, as well as the
-latent variables of interest including the site-level $\mathcal{R}(t)$ estimates and the
-state-level $\mathcal{R}(t)$ estimate.
+We can use the `get_draws()` function to generate dataframes that contain
+the posterior draws of the estimated, nowcasted, and forecasted quantities,
+joined to the relevant data.
We can generate this directly on the output of `wwinference()` using:
```{r extracting-draws}
-draws_df <- get_draws_df(ww_fit)
+draws <- get_draws(ww_fit)
-cat(
- "Variables in dataframe: ",
- sprintf("%s", paste(unique(draws_df$name), collapse = ", "))
-)
+print(draws)
+```
+
+Note that by default the `get_draws()` function will return a list of class `wwinference_fit_draws`
+which contains separate dataframes of the posterior draws for predicted counts (`"predicted_counts"`),
+wastewater concentrations (`"predicted_ww"`), global $\mathcal{R}(t)$ (`"global_rt"`) estimates, and
+subpopulation-level $\mathcal{R}(t)$ estimates ("`subpop_rt"`).
+To examine a particular variable (e.g. `"predicted_counts"` for posterior
+predicted hospital admissions in this case), access the corresponding tibble using the `$` operator.
+
+
+You can also specify which outputs to return using the `what` argument.
+```{r example subset draws}
+hosp_draws <- get_draws(ww_fit, what = "predicted_counts")
+hosp_draws_df <- hosp_draws$predicted_counts
+head(hosp_draws_df)
```
-Note that by default the `get_draws_df()` function will return a tidy long
-dataframe with all of the posterior draws joined to applicable data for each of
-the included variables. To examine a particular variable (e.g. `"predicted counts"` for posterior
-predicted hospital admissions), filter the data frame based on the `name` column.
+
+
### Using explicit passed arguments rather than S3 methods
@@ -395,10 +444,13 @@ Rather than using S3 methods supplied for `wwinference()`, the elements in the
`wwinference_fit` object can also be used directly to create this dataframe.
This is demonstrated below:
-```{r extracting-draws-explicit}
-draws_df_explicit <- get_draws_df(
+```{r extracting-draws-explicit, eval = FALSE}
+draws_explicit <- get_draws(
x = ww_fit$raw_input_data$input_ww_data,
count_data = ww_fit$raw_input_data$input_count_data,
+ date_time_spine = ww_fit$raw_input_data$date_time_spine,
+ site_subpop_spine = ww_fit$raw_input_data$site_subpop_spine,
+ lab_site_subpop_spine = ww_fit$raw_input_data$lab_site_subpop_spine,
stan_data_list = ww_fit$stan_data_list,
fit_obj = ww_fit$fit
)
@@ -407,39 +459,53 @@ draws_df_explicit <- get_draws_df(
## Plotting the outputs
-We can create plots of the outputs using `draws_df` and
-the fitting wrapper functions. Note that by default, these plots will not
-visualize data that was below the LOD (even though the fit incorporated
-them via the censored observation process.)
+We can create plots of the outputs using corresponding dataframes in the `draws`
+object and the fitting wrapper functions. Note that by default, these plots
+will not include outliers that were flagged for exclusion. Data points
+that are below the LOD will be plotted in blue.
```{r generating-figures, out.width='100%'}
-draws_df <- get_draws_df(ww_fit)
-
plot_hosp <- get_plot_forecasted_counts(
- draws = draws_df,
+ draws = draws$predicted_counts,
count_data_eval = hosp_data_eval,
count_data_eval_col_name = "daily_hosp_admits_for_eval",
forecast_date = forecast_date
)
plot_hosp
-plot_ww <- get_plot_ww_conc(draws_df, forecast_date)
+plot_ww <- get_plot_ww_conc(draws$predicted_ww, forecast_date)
plot_ww
-plot_state_rt <- get_plot_global_rt(draws_df, forecast_date)
+plot_state_rt <- get_plot_global_rt(draws$global_rt, forecast_date)
plot_state_rt
-plot_subpop_rt <- get_plot_subpop_rt(draws_df, forecast_date)
+plot_subpop_rt <- get_plot_subpop_rt(draws$subpop_rt, forecast_date)
plot_subpop_rt
```
+The previous three are equivalent to calling the `plot` method of `wwinference_fit_draws` using the `what` argument:
+
+```{r, out.width='100%'}
+plot(
+ x = draws,
+ what = "predicted_counts",
+ count_data_eval = hosp_data_eval,
+ count_data_eval_col_name = "daily_hosp_admits_for_eval",
+ forecast_date = forecast_date
+)
+plot(draws, what = "predicted_ww", forecast_date = forecast_date)
+plot(draws, what = "global_rt", forecast_date = forecast_date)
+plot(draws, what = "subpop_rt", forecast_date = forecast_date)
+```
+
## Diagnostics
We strongly recommend running diagnostics as a post-processing step on the
model outputs.
This can be done by passing the output of
-`wwinference()` into the `get_model_diagnostic_flags()`, `parameter_diagnostics()`,
+
+`wwinference()` into the `get_model_diagnostic_flags()`, `summary_diagnostics()`
and `parameter_diagnostics()` functions.
`get_model_diagnostic_flags()` will print out a table of any flags, if any of
@@ -448,13 +514,21 @@ We have set default thresholds on the model diagnostics for production-level
runs, we recommend adjusting as needed (see below)
To further troubleshoot, you can look at
-the diagnostic summary and the diagnostics of the individual parameters using
+the summary diagnostics using the `summary_diagnostics()` function
+and the diagnostics of the individual parameters using
the `parameter_diagnostics()` function.
+For further information on troubleshooting the model diagnostics,
+we recommend the (bayesplot tutorial)[https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html].
+
+You can access the CmdStan object directly using `ww_fit$fit$result`
+
```{r diagnostics-using-S3-methods}
convergence_flag_df <- get_model_diagnostic_flags(ww_fit)
print(convergence_flag_df)
-parameter_diagnostics(ww_fit)
+summary_diagnostics(ww_fit)
+param_diagnostics <- parameter_diagnostics(ww_fit)
+head(param_diagnostics)
```
This can also be done explicitly by parsing the elements of the
@@ -471,7 +545,7 @@ to identify which components of the model might be driving the convergence
issues.
For further information on troubleshooting the model diagnostics,
-we recommend the (bayesplot tutorial)[https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html].
+we recommend the [bayesplot tutorial](https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html).
```{r diagnostics-explicit}
convergence_flag_df <- get_model_diagnostic_flags(
@@ -497,7 +571,7 @@ rely on the admissions only model if there are covergence or known data issues
with the wastewater data.
```{r fit-hosp-only, warning=FALSE, message=FALSE}
-fit_hosp_only <- wwinference::wwinference(
+fit_hosp_only <- wwinference(
ww_data = ww_data_to_fit,
count_data = hosp_data_preprocessed,
forecast_date = forecast_date,
@@ -510,18 +584,18 @@ fit_hosp_only <- wwinference::wwinference(
include_ww = FALSE,
params = params
),
- fit_opts = get_mcmc_options(),
+ fit_opts = list(seed = 123),
compiled_model = model
)
```
```{r plot-hosp-only, out.width='100%'}
-draws_df_hosp_only <- get_draws_df(fit_hosp_only)
-plot_hosp_hosp_only <- get_plot_forecasted_counts(
- draws = draws_df_hosp_only,
+draws_hosp_only <- get_draws(fit_hosp_only)
+plot(draws_hosp_only,
+ what = "predicted_counts",
count_data_eval = hosp_data_eval,
count_data_eval_col_name = "daily_hosp_admits_for_eval",
forecast_date = forecast_date
)
-plot_hosp_hosp_only
+plot(draws_hosp_only, what = "global_rt", forecast_date = forecast_date)
```