From b3a3f1a14387420f501d27b2c411734e0fb999b7 Mon Sep 17 00:00:00 2001
From: Kaitlyn Johnson <94390107+kaitejohnson@users.noreply.github.com>
Date: Wed, 26 Jun 2024 10:42:23 -0400
Subject: [PATCH] updates to codebase to reflect recent dev (#93)
---
.gitignore | 6 -
.pre-commit-config.yaml | 21 +-
_targets_eval.R | 6 +-
_targets_eval_postprocessing.R | 195 +++-
model_definition.md | 4 +-
scratch/ar1_example.R | 24 +
scratch/debug_ww_data.R | 574 +++++++++
scratch/site_level_inf_dynamics.Rmd | 1281 +++++++++++++++++++++
scratch/sites_sum_to_states_test.Rmd | 324 ++++++
scratch/state_level_model.Rmd | 433 +++++++
scratch/test_epi_class.R | 50 +
scratch/toy_example_ind_variation.Rmd | 683 +++++++++++
scratch/ww_data_quick_figs.R | 59 +
setup_container.R | 37 +
src/setup_eval.R | 15 +-
src/write_eval_config.R | 2 +
wweval/NAMESPACE | 3 +
wweval/R/combine_outputs.R | 4 +
wweval/R/eval_post_process.R | 33 +-
wweval/R/get_epidemic_phases_from_rt.R | 84 ++
wweval/R/get_table_sufficient_ww.R | 80 ++
wweval/R/model_run_diagnostic_flags.R | 33 +
wweval/R/sample_model.R | 10 +-
wweval/man/get_convergence_df.Rd | 25 +
wweval/man/get_epidemic_phases_from_rt.Rd | 26 +
wweval/man/get_table_sufficient_ww.Rd | 53 +
26 files changed, 4000 insertions(+), 65 deletions(-)
create mode 100644 scratch/ar1_example.R
create mode 100644 scratch/debug_ww_data.R
create mode 100644 scratch/site_level_inf_dynamics.Rmd
create mode 100644 scratch/sites_sum_to_states_test.Rmd
create mode 100644 scratch/state_level_model.Rmd
create mode 100644 scratch/test_epi_class.R
create mode 100644 scratch/toy_example_ind_variation.Rmd
create mode 100644 scratch/ww_data_quick_figs.R
create mode 100644 setup_container.R
create mode 100644 wweval/R/get_epidemic_phases_from_rt.R
create mode 100644 wweval/R/get_table_sufficient_ww.R
create mode 100644 wweval/man/get_convergence_df.Rd
create mode 100644 wweval/man/get_epidemic_phases_from_rt.Rd
create mode 100644 wweval/man/get_table_sufficient_ww.Rd
diff --git a/.gitignore b/.gitignore
index 1b2525c8..a1b5205c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,12 +15,6 @@ secrets.yaml
# Azure batch configuration
batch_config.toml
-# batch
-batch/
-Makefile
-Containerfile
-.containerignore
-
# rendered documents ignored by default
*.html
*.pdf
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 073ffca6..d80f1e46 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,7 +19,26 @@ repos:
hooks:
- id: style-files
- id: lintr
-#####
+
+########
+# Python
+- repo: https://github.com/psf/black
+ rev: 23.10.0
+ hooks:
+ - id: black
+ args: ['--line-length', '79']
+- repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ args: ['--profile', 'black',
+ '--line-length', '79']
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.1.0
+ hooks:
+ - id: ruff
+
+#########
# Secrets
- repo: https://github.com/Yelp/detect-secrets
rev: v1.4.0
diff --git a/_targets_eval.R b/_targets_eval.R
index 9d2ce92f..79489a31 100644
--- a/_targets_eval.R
+++ b/_targets_eval.R
@@ -869,7 +869,8 @@ downstream_targets <- list(
name = plot_summarized_scores_w_data,
command = get_plot_scores_w_data(
grouped_submission_scores,
- eval_hosp_data
+ eval_hosp_data,
+ eval_config$figure_dir
),
pattern = map(grouped_submission_scores),
iteration = "list",
@@ -900,7 +901,8 @@ downstream_targets <- list(
name = plot_quantile_comparison,
command = get_plot_quantile_comparison(
all_hosp_quantiles,
- eval_hosp_data
+ eval_hosp_data,
+ eval_config$figure_dir
),
pattern = map(all_hosp_quantiles),
iteration = "list"
diff --git a/_targets_eval_postprocessing.R b/_targets_eval_postprocessing.R
index c9e31f52..75d60418 100644
--- a/_targets_eval_postprocessing.R
+++ b/_targets_eval_postprocessing.R
@@ -105,6 +105,15 @@ upstream_targets <- list(
last_hosp_data_date = eval_config$eval_date,
ww_data_mapping = eval_config$ww_data_mapping
)
+ ),
+ # Returns a dataframe with each location and date and a corresponding
+ # epidemic phase
+ tar_target(
+ name = epidemic_phases,
+ command = get_epidemic_phases_from_rt(
+ locations = unique(eval_config$location_ww),
+ retro_rt_path = eval_config$retro_rt_path
+ )
)
)
@@ -141,14 +150,15 @@ combined_targets <- list(
## Flags------------------------------------------------------------------
tar_target(
name = all_flags_ww,
- command = combine_outputs(
- output_type = "flags",
- scenarios = eval_config$scenario,
- forecast_dates = eval_config$forecast_dates,
- locations = eval_config$location_ww,
- eval_output_subdir = eval_config$output_dir,
- model_type = "ww"
- )
+ command =
+ combine_outputs(
+ output_type = "flags",
+ scenarios = eval_config$scenario,
+ forecast_dates = eval_config$forecast_date_ww,
+ locations = eval_config$location_ww,
+ eval_output_subdir = eval_config$output_dir,
+ model_type = "ww"
+ )
),
tar_target(
name = all_flags_hosp,
@@ -161,6 +171,25 @@ combined_targets <- list(
model_type = "hosp"
)
),
+ tar_target(
+ name = all_flags,
+ command = dplyr::bind_rows(all_flags_ww, all_flags_hosp)
+ ),
+ tar_target(
+ name = convergence_df_ww,
+ command = get_convergence_df(
+ all_flags_ww,
+ default_scenario = "status_quo"
+ ) |>
+ dplyr::rename(any_flags_ww = any_flags)
+ ),
+ tar_target(
+ name = convergence_df_hosp,
+ command = get_convergence_df(all_flags_hosp,
+ default_scenario = "no_wastewater"
+ ) |>
+ dplyr::rename(any_flags_hosp = any_flags)
+ ),
### Scores from quantiles-------------------------------------------------
tar_target(
name = all_ww_scores_quantiles,
@@ -169,7 +198,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
@@ -180,7 +209,7 @@ combined_targets <- list(
scenarios = "no_wastewater",
forecast_dates = eval_config$forecast_date_hosp,
locations = eval_config$location_hosp,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "hosp"
)
),
@@ -192,7 +221,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
@@ -203,7 +232,7 @@ combined_targets <- list(
scenarios = "no_wastewater",
forecast_dates = eval_config$forecast_date_hosp,
locations = eval_config$location_hosp,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "hosp"
)
),
@@ -214,7 +243,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
@@ -226,7 +255,7 @@ combined_targets <- list(
scenarios = eval_config$scenario,
forecast_dates = eval_config$forecast_date_ww,
locations = eval_config$location_ww,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "ww"
)
),
@@ -237,15 +266,131 @@ combined_targets <- list(
scenarios = "no_wastewater",
forecast_dates = eval_config$forecast_date_hosp,
locations = eval_config$location_hosp,
- eval_output_subdir = file.path("output", "eval"),
+ eval_output_subdir = eval_config$output_dir,
model_type = "hosp"
)
)
)
+# Head-to-head comparison targets-------------------------------------------
+# This set of targets is conditioned on the presence of sufficient
+# wastewater, whereas the scenario targets below assume that a forecast had
+# to be submitted for every location and forecast date, and so the hospital
+# admissions only model was used whenever wastewater was missing.
+# These are only relevant for the status quo scenario.
+head_to_head_targets <- list(
+ tar_target(
+ name = all_ww_quantiles_sq,
+ command = all_ww_quantiles |>
+ dplyr::filter(scenario == "status_quo")
+ ),
+ # Get a table of locations and forecast dates with sufficient wastewater
+ tar_target(
+ name = table_of_loc_dates_w_ww,
+ command = get_table_sufficient_ww(all_ww_quantiles)
+ ),
+ # Get a table indicating whether there are locations and forecast dates with
+ # convergence issues
+ tar_target(
+ name = convergence_df,
+ command = convergence_df_hosp |>
+ dplyr::left_join(convergence_df_ww,
+ by = c("location", "forecast_date")
+ )
+ ),
+ # Get the full set of quantiles, filtered down to only states and
+ # forecast dates with sufficient wastewater for both ww model and hosp only
+ # model. Then join the convergence df
+ tar_target(
+ name = hosp_quantiles_filtered,
+ command = dplyr::bind_rows(
+ all_ww_hosp_quantiles,
+ all_hosp_model_quantiles
+ ) |>
+ dplyr::left_join(table_of_loc_dates_w_ww,
+ by = c("location", "forecast_date")
+ ) |>
+ dplyr::filter(
+        ww_sufficient # keeps location-forecast date pairs with sufficient ww
+ ) |>
+ dplyr::left_join(
+ convergence_df,
+ by = c(
+ "location",
+ "forecast_date"
+ )
+ ) |>
+ dplyr::left_join(
+ epidemic_phases,
+ by = c(
+ "location" = "state_abbr",
+ "date" = "reference_date"
+ )
+ )
+ ),
+ # Do the same thing for the sampled scores, combining ww and hosp under
+ # the status quo scenario, filtering to the locations and forecast dates
+ # with sufficient wastewater, and then joining the convergence flags
+ tar_target(
+ name = scores_filtered,
+ command = dplyr::bind_rows(
+ all_hosp_scores,
+ all_ww_scores |>
+ dplyr::filter(scenario == "status_quo")
+ ) |>
+ dplyr::left_join(table_of_loc_dates_w_ww,
+ by = c("location", "forecast_date")
+ ) |>
+ dplyr::filter(ww_sufficient) |>
+ dplyr::left_join(
+ convergence_df,
+ by = c(
+ "location",
+ "forecast_date"
+ )
+ ) |>
+ dplyr::left_join(
+ epidemic_phases,
+ by = c(
+ "location" = "state_abbr",
+ "date" = "reference_date"
+ )
+ )
+ ),
+ # Repeat for the quantile-based scores
+ tar_target(
+ name = scores_quantiles_filtered,
+ command = dplyr::bind_rows(
+ all_hosp_scores_quantiles,
+ all_ww_scores_quantiles |>
+ dplyr::filter(scenario == "status_quo")
+ ) |>
+ dplyr::left_join(table_of_loc_dates_w_ww,
+ by = c("location", "forecast_date")
+ ) |>
+      dplyr::filter(ww_sufficient) |>
+ dplyr::left_join(
+ convergence_df,
+ by = c(
+ "location",
+ "forecast_date"
+ )
+ ) |>
+ dplyr::left_join(
+ epidemic_phases,
+ by = c(
+ "location" = "state_abbr",
+ "date" = "reference_date"
+ )
+ )
+ )
+)
+
-# Head-to-head-scenario targets------------------------------------------------
-head_to_head_scenario_targets <- list(
+# Scenario targets------------------------------------------------
+scenario_targets <- list(
tar_target(
name = all_raw_scores,
command = data.table::as.data.table(
@@ -262,18 +407,7 @@ head_to_head_scenario_targets <- list(
name = all_errors,
command = dplyr::bind_rows(all_hosp_errors, all_ww_errors)
),
- tar_target(
- name = all_flags,
- command = dplyr::bind_rows(all_flags_ww, all_flags_hosp)
- ),
- tar_target(
- name = nonconverge_df,
- command = all_flags |> distinct(
- location, forecast_date, model, scenario
- ) |>
- dplyr::mutate(convergence = FALSE) # Add a column that indicates did not pass
- # convergence. We will use this in the head-to-head comparison to statify by convergence
- ),
+
## Raw scores-----------------------------------------
# These are the scores from each scenario and location without buffering
@@ -653,7 +787,8 @@ hub_comparison_plots <- list(
list(
upstream_targets,
combined_targets,
- head_to_head_scenario_targets,
+ head_to_head_targets,
+ scenario_targets,
hub_targets,
hub_comparison_plots
)
diff --git a/model_definition.md b/model_definition.md
index e2f04e3a..f1e8b178 100644
--- a/model_definition.md
+++ b/model_definition.md
@@ -119,7 +119,9 @@ where $\gamma$ is the _infection feedback term_ controlling the strength of the
Following other semi-mechanistic renewal frameworks, we model the _expected_ hospital admissions per capita $H(t)$ as a convolution of the _expected_ latent incident infections per capita $I(t)$, and a discrete infection to hospitalization distribution $d(\tau)$, scaled by the probability of being hospitalized $p_\mathrm{hosp}(t)$.
-To account for day-of-week effects in hospital reporting, we use an estimated _weekday effect_ $\omega(t)$. If $t$ and $t'$ are the same day of the week, $\omega(t) = \omega(t')$. The seven values that $\omega(t)$ takes on are constrained to have mean 1.
+To account for day-of-week effects in hospital reporting, we use an estimated _weekday effect_ $\omega(t)$. If $t$ and $t'$ are the same day of the week, $\omega(t) = \omega(t')$.
+The seven values that $\omega(t)$ takes on are constrained to be non-negative and have a mean of 1.
+This allows us to model the possibility that certain days of the week could have systematically high or low admissions reporting while holding the predicted weekly total reported admissions constant (i.e. the predicted weekly total is the same with and without these day-of-week reporting effects).
$$H(t) = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau)$$
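+
+As a minimal sketch (illustrative only, not part of the model code), a non-negative weekday effect with mean 1 can be constructed by rescaling seven positive values:
+
+```r
+# Illustrative construction of a mean-1, non-negative weekday effect.
+omega_raw <- rgamma(7, shape = 5, rate = 5)
+omega <- omega_raw / mean(omega_raw)
+mean(omega) # exactly 1 by construction
+```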
diff --git a/scratch/ar1_example.R b/scratch/ar1_example.R
new file mode 100644
index 00000000..a6607f4f
--- /dev/null
+++ b/scratch/ar1_example.R
@@ -0,0 +1,24 @@
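+# Scratch example: simulate an AR(1) deviation around a time-varying mean.
+# mu is the (logit-scale) mean vector, ac the autocorrelation, sd the
+# innovation scale, and z a vector of standard normal draws. Used below to
+# explore time variation in p_hosp on the logit scale.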
+ar1 <- function(mu, ac, sd, z) {
+ n <- length(z)
+ x <- rep(0, n)
+ tvd <- rep(0, n)
+
+ tvd[1] <- sd * z[1]
+ x[1] <- mu[1] + tvd[1]
+
+ for (i in 2:n) {
+ tvd[i] <- ac * tvd[i - 1] + sd * z[i]
+ x[i] <- mu[i] + tvd[i]
+ }
+ return(x)
+}
+
+p_hosp_mean <- rep(0.01, 26)
+p_hosp_logit <- qlogis(p_hosp_mean)
+ac <- 0.01
+sd <- 0.3
+z <- rnorm(26)
+
+p_hosp_t_logit <- ar1(p_hosp_logit, ac, sd, z)
+
+plot(plogis(p_hosp_t_logit))
diff --git a/scratch/debug_ww_data.R b/scratch/debug_ww_data.R
new file mode 100644
index 00000000..d651088d
--- /dev/null
+++ b/scratch/debug_ww_data.R
@@ -0,0 +1,574 @@
+# Take a look at the NWSS data
+library(zoo)
+library(ggridges)
+library(magrittr)
+
+nwss <- readr::read_csv(here::here("input", "ww_data", "nwss_data", "2024-01-30.csv"))
+
+test <- nwss %>% dplyr::filter(lab_id == 4702)
+
+# EDA on outliers
+nwss_subset <- init_subset_nwss_data(nwss)
+
+ww_target_type <- "pcr_target_avg_conc"
+ww_data <- nwss_subset %>%
+ ungroup() %>%
+ rename(
+ date = sample_collect_date,
+ ww = {{ ww_target_type }},
+ ww_pop = population_served
+ ) %>%
+ mutate(
+ location = toupper(wwtp_jurisdiction),
+ site = wwtp_name,
+ lab = lab_id
+ # since we might expect to see
+ ) %>%
+ select(
+ date, location, ww, site, lab, lab_wwtp_unique_id, ww_pop,
+ below_LOD, lod_sewage
+ )
+
+# Add the county names to the WW data
+site_county_map <- get_site_county_map(
+ nwss,
+ county_site_map_path = file.path("input", "ww_data", "county_site_map.csv")
+)
+ww_data2 <- ww_data %>%
+ left_join(site_county_map,
+ by = "site"
+ ) %>%
+ mutate(
+ full_county_name = ifelse(is.na(full_county_name),
+ glue::glue("{county_codes}, {location}"),
+ full_county_name
+ )
+ )
+ww_data_outliers_flagged <- flag_ww_outliers(ww_data2)
+
+ww_data_mod <- ww_data_outliers_flagged %>%
+ group_by(lab_wwtp_unique_id) %>%
+  arrange(date) %>%
+ mutate(
+ log_conc = log(ww)
+ ) %>%
+ mutate(
+ log_conc_t_min_1 = lag(log_conc, 1),
+ log_conc_t_min_2 = lag(log_conc, 2),
+ log_conc_t_plus_1 = lead(log_conc, 1),
+ log_conc_t_plus_2 = lead(log_conc, 2)
+ ) %>%
+ mutate(
+    exp_log_conc = rowMeans(
+      cbind(
+        log_conc_t_min_1, log_conc_t_min_2,
+        log_conc_t_plus_1, log_conc_t_plus_2
+      ),
+ na.rm = TRUE
+ )
+ ) %>%
+ mutate(
+ norm_dif_true_v_exp = abs((log_conc - exp_log_conc) / (exp_log_conc)),
+ dif_true_v_exp = abs(log_conc - exp_log_conc)
+ )
+
+mean_norm_dif <- mean(ww_data_mod$norm_dif_true_v_exp, na.rm = TRUE)
+median_norm_dif <- median(ww_data_mod$norm_dif_true_v_exp, na.rm = TRUE)
+
+ggplot(ww_data_mod) +
+ geom_density(
+ aes(
+ x = norm_dif_true_v_exp,
+ fill = as.factor(flag_as_ww_outlier)
+ ),
+ alpha = 0.3
+ ) +
+ geom_vline(xintercept = mean_norm_dif, linetype = "dashed") +
+ geom_vline(xintercept = median_norm_dif, linetype = "dotted") +
+ scale_x_continuous(trans = "log") +
+ theme_bw() +
+ guides(fill = guide_legend(title = "Outlier?")) +
+ xlab("Normalized difference expected and real")
+
+ggplot(ww_data_mod) +
+ geom_density(
+ aes(
+ x = dif_true_v_exp,
+ fill = as.factor(flag_as_ww_outlier)
+ ),
+ alpha = 0.3
+ ) +
+ scale_x_continuous(trans = "log") +
+ theme_bw() +
+ guides(fill = guide_legend(title = "Outlier?")) +
+ xlab("Difference expected and real")
+ggplot(ww_data_mod) +
+ geom_density(aes(x = dif_true_v_exp),
+ alpha = 0.3
+ ) +
+ scale_x_continuous(trans = "log") +
+ theme_bw() +
+ xlab("Difference expected and real")
+
+ggplot(ww_data_mod) +
+ geom_density(aes(x = norm_dif_true_v_exp)) +
+ geom_vline(xintercept = median_norm_dif, linetype = "dotted") +
+ geom_vline(xintercept = mean_norm_dif, linetype = "dashed") +
+ scale_x_continuous(trans = "log") +
+ theme_bw() +
+ xlab("Normalized difference expected and real")
+
+ggplot(ww_data_mod) +
+ geom_point(
+ aes(
+ x = log_conc, y = norm_dif_true_v_exp,
+ color = as.factor(flag_as_ww_outlier)
+ ),
+ size = 0.1, alpha = 0.3
+ ) +
+ scale_x_continuous(trans = "log") +
+ scale_y_continuous(trans = "log") +
+ theme_bw() +
+ guides(color = guide_legend(title = "Outlier?")) +
+ ylab("Normalized difference expected and real") +
+ xlab("log(conc)")
+
+ggplot(ww_data_mod) +
+ geom_point(
+ aes(
+ x = log_conc, y = dif_true_v_exp,
+ color = as.factor(flag_as_ww_outlier)
+ ),
+ size = 0.1, alpha = 0.3
+ ) +
+ scale_x_continuous(trans = "log") +
+ scale_y_continuous(trans = "log") +
+ theme_bw() +
+ guides(color = guide_legend(title = "Outlier?")) +
+ ylab("Difference expected and real") +
+ xlab("log(conc)")
+
+
+
+
+
+
+
+
+# check for duplicates
+test <- nwss %>%
+ select(
+ sample_collect_date, wwtp_name,
+ lab_id
+ ) %>%
+ unique()
+unique_wwtps <- nwss %>%
+ select(wwtp_name) %>%
+ unique()
+unique_labs <- nwss %>%
+ select(lab_id) %>%
+ unique()
+unique_combos_map <- nwss %>%
+ select(wwtp_name, lab_id) %>%
+ unique() %>%
+ mutate(lab_wwtp_unique_id = row_number())
+
+nwss_w_unique_ids <- nwss %>% left_join(unique_combos_map,
+ by = c("wwtp_name", "lab_id")
+)
+
+test2 <- nwss %>%
+ select(sample_collect_date, wwtp_name, lab_id) %>%
+ unique()
+test3 <- nwss %>%
+ select(sample_collect_date, wwtp_name) %>%
+ unique()
+
+test <- nwss %>%
+ group_by(sample_collect_date, wwtp_name) %>%
+ reframe(lab_ids = (unique(lab_id)))
+
+source("R/pre_processing.R")
+
+ww_data_summarized <- nwss %>%
+ group_by(wwtp_jurisdiction) %>%
+ select(wwtp_name) %>%
+ unique() %>%
+ summarise(n_sites_per_state = n()) %>%
+ left_join(
+ nwss %>%
+ filter(sample_collect_date >= ymd(today() - months(1))) %>%
+ group_by(wwtp_jurisdiction) %>%
+ select(wwtp_name) %>%
+ unique() %>%
+ summarise(n_recent_sites_per_state = n())
+ )
+
+median(ww_data_summarized$n_recent_sites_per_state)
+
+ggplot(ww_data_summarized) +
+ geom_histogram(aes(x = n_recent_sites_per_state), fill = "blue", alpha = 0.3) +
+ geom_vline(aes(xintercept = 14), linetype = "dashed") +
+ xlab("Number of sites per state") +
+ ylab("Number of states") +
+ theme_bw() +
+ ggtitle("Distribution of number of sites per state")
+
+ggplot(nwss %>% filter(wwtp_jurisdiction %in% c("nj"))) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ color = as.factor(wwtp_name)
+ ),
+ show.legend = FALSE
+ ) +
+ facet_wrap(~sample_matrix, nrow = 3) +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw() +
+ ggtitle("MA WW broken down by type of sampling")
+
+ggplot(nwss %>% filter(wwtp_jurisdiction %in% c("ca"))) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ color = as.factor(wwtp_name)
+ ),
+ alpha = 0.5, size = 0.5,
+ show.legend = FALSE
+ ) +
+ # facet_wrap(~wwtp_name, scales = 'free') +
+ facet_wrap(~pcr_target_units, nrow = 3) +
+ xlab("") +
+ ylab("Avg PCR concentration") + # scale_y_log10()+
+ theme_bw() +
+ ggtitle("CA WW broken down by unit type")
+
+
+
+
+
+
+
+nwss_subset_raw <- init_subset_nwss_data(nwss)
+
+
+
+
+ggplot(nwss_subset_raw %>% filter(wwtp_jurisdiction %in% c("ny"))) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ group = wwtp_name, color = wwtp_name
+ ),
+ show.legend = FALSE
+ ) +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw()
+
+ggplot(nwss_subset_raw %>% filter(wwtp_name == 2023)) +
+ geom_point(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ group = wwtp_name, color = wwtp_name
+ ),
+ show.legend = FALSE
+ ) +
+ facet_wrap(~wwtp_name, scales = "free") +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ scale_y_log10() +
+ theme_bw()
+
+states <- c("vt", "nj", "ny", "tx", "ct", "ma", "ri", "ca", "al")
+states <- "ma"
+single_state_raw <- nwss_subset_raw %>%
+ remove_outliers() %>%
+ filter(
+ wwtp_jurisdiction %in% c(states),
+ sample_collect_date >= "2022-10-10",
+ sample_collect_date <= "2022-12-26"
+ )
+state <- "ma"
+ggplot(nwss_subset_raw %>% remove_outliers() %>% filter(wwtp_jurisdiction == state)) +
+ geom_density_ridges_gradient(aes(
+ y = ymd(sample_collect_date),
+ x = pcr_target_avg_conc,
+ group = sample_collect_date
+ ), jittered_points = TRUE) +
+ scale_fill_viridis_d() +
+ coord_flip() +
+ facet_wrap(~wwtp_jurisdiction) +
+ xlab("Site specific pcr_target_avc_conc ") +
+ ylab("") +
+ theme_bw() +
+ ggtitle(paste0("Within state WW concentration distributions across sites in ", toupper(state)))
+
+
+ggplot(single_state_raw) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date),
+ y = pcr_target_avg_conc,
+ color = as.factor(wwtp_name)
+ ),
+ show.legend = FALSE, size = 0.5, alpha = 0.5
+ ) +
+ theme_bw() +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ ylab("Site specific pcr_target_avg_conc") +
+ xlab("") +
+ labs(color = "Site") +
+ ggtitle(paste0("Distribution of concentrations across days and sites "))
+
+ggplot(single_state_raw) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date),
+ y = pcr_target_flowpop_lin,
+ color = as.factor(wwtp_name)
+ ),
+ show.legend = FALSE, size = 0.5, alpha = 0.5
+ ) +
+ theme_bw() +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ ylab("Site specific pcr_target_flowpop_lin") +
+ xlab("") +
+ labs(color = "Site") +
+ ggtitle(paste0("Distribution of concentrations across days and sites"))
+
+
+
+var_by_pop <- nwss_subset_raw %>%
+ remove_outliers() %>%
+ group_by(wwtp_name) %>%
+ summarise(
+ pop = max(population_served),
+ variance = var(pcr_target_avg_conc)
+ )
+
+ggplot(var_by_pop) +
+ geom_point(aes(x = pop, y = variance)) +
+ theme_bw() +
+ xlab("N") +
+ ylab("Variance in site level concentrations across time") +
+ scale_y_log10() +
+ scale_x_log10()
+
+
+
+
+
+nwss_subset <- remove_outliers(nwss_subset_raw)
+
+omicron <- nwss_subset %>% filter(
+ sample_collect_date >= "2022-10-10",
+ sample_collect_date <= "2023-01-16"
+)
+
+ggplot(omicron %>% filter(wwtp_jurisdiction %in% c("ma"))) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ group = wwtp_name, color = as.factor(wwtp_name)
+ ),
+ show.legend = FALSE
+ ) +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ ) +
+ theme(
+ axis.text.x = element_text(
+ size = 10, vjust = 0.5,
+ hjust = 1, angle = 45
+ ),
+ axis.title.x = element_text(size = 10),
+ axis.title.y = element_text(size = 10),
+ plot.title = element_text(
+ size = 9,
+ vjust = 0.5, hjust = 0.5
+ )
+ ) +
+ ggtitle("Site-level PCR concentration in Massachusetts winter 2022-2023") +
+ theme_bw()
+
+
+
+ggplot(nwss_subset %>% filter(wwtp_name == 2023)) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ group = wwtp_name, color = wwtp_name
+ ),
+ alpha = 0.3, show.legend = FALSE
+ ) +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw()
+
+ggplot(nwss_subset %>% filter(wwtp_jurisdiction %in% c("ny"))) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ group = wwtp_name, color = wwtp_name
+ ),
+ alpha = 0.3, show.legend = FALSE
+ ) +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw()
+
+
+
+
+
+nwss_by_week <- get_weekly_summary(nwss_subset)
+
+
+
+
+nwss_by_state <- get_state_level_summary(nwss_by_week)
+# Naive population-weighted state level averaging with national level averaging
+ggplot(omicron %>% filter(wwtp_jurisdiction %in% c("ma"))) +
+ geom_line(
+ aes(
+ x = ymd(sample_collect_date), y = pcr_target_avg_conc,
+ group = wwtp_name, color = as.factor(wwtp_name)
+ ),
+ show.legend = FALSE
+ ) +
+ geom_point(
+ data = nwss_by_state %>% filter(
+ wwtp_jurisdiction == "ma",
+ midweek_date >= "2022-10-10",
+ midweek_date <= "2023-01-16"
+ ),
+ aes(x = midweek_date, y = pop_weighted_conc_w_thres),
+ shape = 24, fill = "black"
+ ) +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ ) +
+ theme(
+ axis.text.x = element_text(
+ size = 10, vjust = 0.5,
+ hjust = 1, angle = 45
+ ),
+ axis.title.x = element_text(size = 10),
+ axis.title.y = element_text(size = 10),
+ plot.title = element_text(
+ size = 9,
+ vjust = 0.5, hjust = 0.5
+ )
+ ) +
+ ggtitle("Site-level PCR concentration in Massachusetts winter 2022-2023") +
+ theme_bw()
+
+
+omicron_by_state <- nwss_by_state %>% filter(
+ midweek_date <= "2021-12-31",
+ midweek_date >= "2021-07-01"
+)
+
+ggplot(omicron_by_state %>% filter(wwtp_jurisdiction %in% c(
+ "ny", "mo", "nc", "ca",
+ "va", "ma"
+))) +
+ geom_line(aes(x = ymd(midweek_date), y = pop_weighted_conc_w_thres),
+ show.legend = FALSE
+ ) +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ xlab("") +
+ ylab("Avg PCR concentration") + # scale_y_log10()+
+ theme_bw()
+
+
+
+ggplot(nwss_by_state %>% filter(
+ wwtp_jurisdiction %in% c(
+ "ny", "va", "ca",
+ "tx", "fl", "ma"
+ ),
+ midweek_date >= "2023-01-01"
+)) +
+ geom_line(aes(x = ymd(midweek_date), y = pop_weighted_conc), color = "gray") +
+ geom_line(aes(x = ymd(midweek_date), y = unweighted_avg_conc), color = "darkblue") +
+ geom_line(aes(x = ymd(midweek_date), y = pop_weighted_conc_w_thres), color = "darkred") +
+ geom_line(aes(x = ymd(midweek_date), y = rlng_avg_pop_weighted_conc_w_thres), color = "purple4") +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ coord_cartesian(xlim = c(ymd("2023-01-01"), ymd("2023-06-28"))) +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw()
+
+ggplot(nwss_by_state %>% filter(wwtp_jurisdiction %in% c("ny"))) +
+ geom_line(aes(x = ymd(midweek_date), y = pop_weighted_conc),
+ color = "gray", alpha = 0.5
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = unweighted_avg_conc),
+ color = "darkblue", alpha = 0.5
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = pop_weighted_conc_w_thres),
+ color = "darkred", alpha = 0.5
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = rlng_avg_pop_weighted_conc_w_thres),
+ color = "purple4", alpha = 0.5
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = ntl_pop_weighted_conc),
+ color = "gray"
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = ntl_unweighted_avg_conc),
+ color = "darkblue"
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = ntl_pop_weighted_conc_w_thres),
+ color = "darkred"
+ ) +
+ geom_line(aes(x = ymd(midweek_date), y = rlng_avg_ntl_pop_weighted_conc_w_thres),
+ color = "purple4"
+ ) +
+ facet_wrap(~wwtp_jurisdiction, scales = "free") +
+ coord_cartesian(xlim = c(ymd("2021-01-01"), ymd("2023-06-28"))) +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw() +
+ ggtitle("Viral concentration in WW calculated 3 ways")
+
+ggplot(nwss_by_state) +
+ geom_line(aes(
+ x = ymd(midweek_date), y = pop_weighted_conc_w_thres,
+ group = wwtp_jurisdiction
+ ), color = "darkred", alpha = 0.1) +
+ geom_line(aes(x = ymd(midweek_date), y = ntl_pop_weighted_conc_w_thres),
+ color = "black", alpha = 1
+ ) +
+ coord_cartesian(xlim = c(ymd("2021-01-01"), ymd("2023-06-28"))) +
+ xlab("") +
+ ylab("Avg PCR concentration") +
+ theme_bw() +
+ ggtitle("Viral concentration state-level averages")
+
+# Would be cool to use ggdist to plot distributions over time within a state
+# across wwtps to see how they vary
+
+ggplot(nwss_by_week %>% filter(wwtp_jurisdiction == "nj")) +
+ geom_density_ridges_gradient(aes(
+ y = midweek_date,
+ x = site_weekly_avg_conc,
+ group = midweek_date
+ ), jittered_points = TRUE) +
+ scale_fill_viridis_d() +
+ coord_flip() +
+ xlab("Site specific concentration") +
+ ylab("") +
+ theme_bw() +
+ ggtitle("Within state WW concentration distributions across sites")
diff --git a/scratch/site_level_inf_dynamics.Rmd b/scratch/site_level_inf_dynamics.Rmd
new file mode 100644
index 00000000..79e692f4
--- /dev/null
+++ b/scratch/site_level_inf_dynamics.Rmd
@@ -0,0 +1,1281 @@
+---
+title: "Site level hierarchical model"
+author: "Kaitlyn Johnson"
+date: "2023-09-01"
+output: html_document
+---
+
+```{r setup, include=FALSE}
+library(cfaforecastrenewalww)
+library(cmdstanr)
+library(lubridate)
+library(ggplot2)
+library(dplyr)
+library(here)
+library(tidybayes)
+source(here::here("src/write_config.R"))
+knitr::opts_chunk$set(echo = TRUE)
+cfaforecastrenewalww::setup_secrets(here::here("secrets.yaml"))
+```
+
+# Motivation
+- This is an implementation of a renewal approach for inference of hospital admissions and wastewater viral concentrations by estimating a time-varying effective reproductive number $R(t)$. Specifically,
+this Rmd is focused on implementing a hierarchical site-level infectious disease
+dynamics model. This assumes that at each time step, the $R(t)$ at each site is
+drawn from a shared lognormal distribution with mean $\mu_{R(t)state}$.
+Functionally, this means that each site is allowed to have its own disease dynamics,
+informed by information from the other sites and constrained by the
+overall state-level hospital admissions. We will
+continue to allow for site-level observation error and a site-level multiplier
+on the observed wastewater concentration.
+
+# Approach
+- For each state, we will treat the combination of sites/measurements as observations
+from the underlying wastewater concentration in the whole state.
+- For the hospital admissions data, we use state-level total hospital admission data
+that was available as of the forecast date, using the covidcast package and the
+`as_of` input. We impose a 9-day reporting delay to mirror the current data scenario.
+
+
+Note: This Rmd mirrors the single model, single location, single forecast date pipeline
+set up in the `_targets.R` file. We will use the Rmd format to continue to refine the model.
+
+
+# Get the state level data from NWSS
+```{r}
+ww_data_path <- save_timestamped_nwss_data(
+ ww_path_to_save =
+ here::here(
+ "input", "ww_data",
+ "nwss_data"
+ )
+)
+
+config_written <- write_config(
+ save_config = TRUE,
+ config_path = here::here("input", "config"),
+ location = "GA",
+ prod_run = FALSE,
+ run_id = "test",
+ date_run = lubridate::today(),
+ model_type = "site-level infection dynamics",
+ forecast_date = "2024-04-22",
+ hosp_data_source = "NHSN",
+ pull_from_local = FALSE,
+ hosp_data_dir = here::here("input", "hosp_data", "vintage_datasets"),
+ population_data_path = here::here("input", "locations.csv"),
+ param_file_path = here::here("input", "params.toml"),
+ include_ww = 1,
+ hosp_reporting_delay = 5,
+ ww_data_path = here::here(
+ "input", "ww_data",
+ "nwss_data", "2024-04-21.csv"
+ )
+)
+
+
+config_vars_ss <- get_config_vals(config_written)
+# Site level WW data from NWSS
+
+ww_data_raw <- do.call(get_ww_data, config_vars_ss)
+
+ww_data <- ww_data_raw %>% filter(location == config_vars_ss$location)
+```
+# Subsample sites
+We want to be able to compare the model forecasts (R(t) and hospital admissions forecasts) for a state
+with a large number of sites when fit to all of them vs. to a subset of them, to see whether subsampling meaningfully impacts results.
+```{r}
+ww_data_subsampled <- subsample_sites(ww_data,
+ prop_sites = 0.2
+)
+```
+
+```{r}
+# Get training data
+train_data_orig <- do.call(
+ get_all_training_data,
+ c(list(ww_data_raw = ww_data), config_vars_ss,
+ subsample_sites = 1,
+ prop_sites = 0.2
+ )
+)
+train_data <- flag_ww_outliers(train_data_orig)
+train_data <- train_data %>% mutate(
+ site_lab = stringr::str_glue("Site: {site}, Lab: {lab}")
+)
+
+ggplot(train_data %>% filter(period != "forecast", !is.na(site))) +
+ geom_line(aes(x = date, y = ww, color = as.factor(site), group = site),
+ show.legend = FALSE
+ ) +
+ geom_point(aes(x = date, y = ww, color = as.factor(site), group = site),
+ show.legend = FALSE
+ ) +
+ theme_bw() +
+ facet_wrap(~site) +
+ xlab("") +
+ ylab("Genome copies per mL") +
+ scale_y_continuous(trans = "log") +
+ guides(color = guide_legend(title = "Site")) +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ )
+ggplot(train_data %>% filter(period != "forecast", !is.na(site))) +
+ geom_line(aes(x = date, y = ww, color = as.factor(site_lab), group = site_lab)) +
+ geom_point(aes(x = date, y = ww, color = as.factor(site_lab), group = site),
+ show.legend = FALSE
+ ) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ theme_bw() +
+ # facet_wrap(~site_lab, scales = "free")+
+ xlab("") +
+ ylab("Genome copies per mL") +
+ # scale_y_continuous(trans = "log") +
+ guides(color = guide_legend(title = "")) +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ ) +
+ theme(
+ axis.text.x = element_text(
+ size = 10, vjust = 1,
+ hjust = 1, angle = 45
+ ),
+ axis.title.x = element_text(size = 10),
+ axis.title.y = element_text(size = 10),
+ plot.title = element_text(
+ size = 9,
+ vjust = 0.5, hjust = 0.5
+ )
+ )
+
+
+ggplot(train_data) +
+ geom_density(aes(
+ x = log(ww),
+ fill = as.factor(flag_as_ww_outlier)
+ ), alpha = 0.3) +
+ theme_bw() +
+ guides(fill = guide_legend(title = "Outlier?"))
+
+ggplot(train_data) +
+  geom_density(aes(x = log(ww)), alpha = 0.3, fill = "gray") +
+ theme_bw()
+```
+
+# Model description
+
+We have observed data from 90 days of hospital admissions and 14 weeks of wastewater measurements for a single state. We are going to make the following simplifying assumptions:
+- the population shedding into the WW is the same as those captured by the hospital admissions
+- the IHR is constant during the period of inference
+- the shedding kinetics and shedding amount are constant during the period of inference
+- for now, we will assume that the number of genomes shed per individual is constant
+across infected individuals, since the population is large enough that we don't expect
+inter-individual variability to add too much noise to the signal
+
+This model builds heavily off of (and in some instances borrows functions from) the
+[EpiNow2 R package](https://github.com/epiforecasts/EpiNow2)
+
+## Renewal Equation Stan Model
+In brief, this model assumes that incident infections $I(t)$ are generated from a vector of $R(t)$ values and an initial number of infections $I(0)$ seeded on day 0:
+$$I(t) = R(t) \sum_{s=1}^{t}I(t-s)g(s)$$
+Where $g(s)$ is the generation interval, which describes the distribution of times
+from incident infection to secondary infection (i.e., the infectiousness profile), and $R(t)$ describes the expected number of secondary infections of an index case at time $t$. Because the data do not start at the initial seeding event, we assume that prior to the first observation, initial infections grow exponentially to give rise to the early observations, such that $I(t') = I(0)\exp(rt')$, where $t'$ is the "unobserved time" before the first observed data point.
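+
+As a minimal sketch (illustrative only, not the Stan implementation), the renewal process can be simulated directly from an $R(t)$ vector and a generation interval pmf:
+
+```{r}
+# Illustrative renewal-process simulation; all values are placeholders.
+renewal_infections <- function(rt, gi_pmf, i0) {
+  n <- length(rt)
+  infections <- numeric(n)
+  infections[1] <- i0
+  for (t in 2:n) {
+    s <- 1:min(t - 1, length(gi_pmf))
+    infections[t] <- rt[t] * sum(infections[t - s] * gi_pmf[s])
+  }
+  infections
+}
+gi_pmf <- dgamma(1:15, shape = 5, rate = 1)
+gi_pmf <- gi_pmf / sum(gi_pmf)
+example_inf <- renewal_infections(rt = rep(1.2, 50), gi_pmf = gi_pmf, i0 = 10)
+```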
+
+The model estimates $R(t)$ via a weekly random walk such that $R(k) = R(k-1)e^{\eta}$, where $\eta \sim N(0, \eta_{sd})$, $k$ is the week, and the magnitude of the step size, $\eta_{sd}$, is estimated in the model calibration. For $R(t)$ during the forecast period, we either assume that $R(t)$ is dampened in proportion to the number of cumulative infections following an SIR approximation, which is implemented when `damp_type = 1` (see the EpiNow2 documentation for details), or we assume that $R(t)$ is dampened by the current number of infections with a drift term, per Jason's Ebola model. We will likely modify both of these significantly, particularly if we observe consistent over- or underprediction in the forecasts.
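+
+A sketch of this random walk (equivalent to a random walk on $\log R(t)$), assuming a given step size $\eta_{sd}$:
+
+```{r}
+# Illustrative weekly random walk on log R(t); eta_sd is a placeholder value.
+eta_sd <- 0.1
+n_weeks_rw <- 16
+log_r <- cumsum(c(log(1.1), rnorm(n_weeks_rw - 1, mean = 0, sd = eta_sd)))
+r_weekly <- exp(log_r)
+```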
+
+
+We model the process of generating two types of observations. First, the hospital admissions: each infectee is assumed to have some (independent) probability $p_{hosp}$ of being hospitalized, and *if* they will eventually be hospitalized, then the probability distribution for days after infection before observation and reporting is $d(t)$. Therefore, conditional on some time series of state-level daily infections generated by one of the models, $I_{state\mu}(t)$, the *expected* number of state-level reported hospitalizations on each day $t$ is,
+
+$$H[t] = p_{hosp}\sum_{\tau\geq 0}d[\tau] I_{state\mu}[t-\tau]$$
+We estimate $p_{hosp}$ and, for now, fix the hospital admissions delay distribution $d(t)$. We assume that the observation process is negative binomial, such that:
+$$\overline{H}[t] \sim NegBinom(H[t], \phi_h)$$
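+
+A minimal sketch of this observation process (illustrative values, building on the renewal sketch above):
+
+```{r}
+# Convolve infections with a delay pmf, scale by p_hosp, then add
+# negative binomial observation noise. All values are placeholders.
+delay_conv <- function(infections, delay_pmf) {
+  sapply(seq_along(infections), function(t) {
+    tau <- 0:min(t - 1, length(delay_pmf) - 1)
+    sum(delay_pmf[tau + 1] * infections[t - tau])
+  })
+}
+delay_pmf <- dlnorm(1:30, meanlog = 1.8, sdlog = 0.5)
+delay_pmf <- delay_pmf / sum(delay_pmf)
+exp_hosp <- 0.01 * delay_conv(example_inf, delay_pmf)
+obs_hosp <- rnbinom(length(exp_hosp), mu = exp_hosp, size = 10)
+```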
+
+
+In a similar fashion, to generate the expected WW observations, we assume that each individual follows the same shedding kinetics profile $S(t)$. We assume the shedding kinetics follow a hinge function
+in log10 space:
+$$
+\log_{10}(S[t]) =
+  \begin{cases}
+    \frac{V_{peak}}{t_{peak}}t & t \leq t_{peak} \\
+    V_{peak} + wt_{peak} - wt & t > t_{peak}
+  \end{cases}
+$$
+
+Where $V_{peak}$ is the log10 peak viral load that occurs at $t_{peak}$ days since infection onset, and $w$ is the rate at which viral load wanes after $t_{peak}$. We convert to natural scale with $S[t] = 10^{\log_{10}(S[t])}$. The parameters of this hinge function are estimated (with relatively strong priors taken from fecal shedding literature and literature on viral loads in nasal passages). The individual level viral kinetics $S[t]$ of each infected individual are normalized to sum to 1 over the course of the infection and are multiplied by $G$, the number of genome copies shed by each infected individual throughout their infection. Therefore, conditional on some time series of daily infections in each site $I_j(t)$, the *expected* number of viral genomes shed in WW on each day $t$ is,
+
+$$V_{j}[t] = G\sum_{\tau\geq 0} S(\tau)I_j(t-\tau)$$
+Where $\sum_{\tau\geq 0} S(\tau)I_j(t-\tau)$ is the net number of infected
+individuals in site $j$ on each day $t$, which we will call $\iota_j(t)$.
+While $S[t]$ describes the kinetics of how an infected individual sheds over the course of
+their infection, there is also individual variability in the total amount of virus shed per infected individual over the course of their infection, $G$.
+
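+A sketch of this hinge function with illustrative parameter values:
+
+```{r}
+# Illustrative log10 hinge shedding kinetics; parameter values are placeholders.
+shedding_log10 <- function(t, v_peak = 5, t_peak = 5, w = 0.3) {
+  ifelse(t <= t_peak,
+    (v_peak / t_peak) * t,
+    v_peak + w * t_peak - w * t
+  )
+}
+tau_grid <- 0:30
+s_tau <- 10^shedding_log10(tau_grid)
+s_tau <- s_tau / sum(s_tau) # normalize to sum to 1 over the infection
+plot(tau_grid, s_tau, type = "l", xlab = "Days since infection", ylab = "S(t)")
+```
+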
+## Dispersion in viral genomes shed per infected individual
+We will assume that an individual sheds $G_i$ genomes per infection and that $G_i$ is negative binomially distributed with mean $\mu_G$ and dispersion $\phi$,
+$$ G_i \sim NegBinom(\mu_G, \phi) $$
+Then we can write the expected sum of the genomes shed from the infected individuals in site $j$
+as $G_j(t) = \sum_{i=1}^{\iota_j(t)} G_i$:
+$$ G_j(t) \sim NegBinom(\iota_j(t)\mu_G, \iota_j(t)\phi) $$
+where $\iota_j(t)$ is the number of infected individuals shedding in site $j$ on day $t$. This is a result of a convenient property: the sum of i.i.d. negative binomials is itself negative binomial.
+What it means is rather intuitive:
+as you have a larger population of infected individuals shedding into the WW,
+you observe lower overall dispersion ($\phi$ increases). $G_j(t)$ is the expected
+number of viral genomes shed in each site on each day. To implement this, we will
+use the Gaussian approximation described [here](https://academic.oup.com/jrsssa/article/185/Supplement_1/S65/7069481?login=true#rssa12971-sec-0020).
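+
+A quick simulation check of this sum property (illustrative values):
+
+```{r}
+# The sum of n iid NegBinom(mu, phi) draws should match NegBinom(n * mu, n * phi),
+# i.e. mean n * mu and variance n * mu + (n * mu)^2 / (n * phi).
+set.seed(42)
+n_shedders <- 50
+mu_g <- 100
+phi_g <- 2
+sums <- replicate(5000, sum(rnbinom(n_shedders, mu = mu_g, size = phi_g)))
+c(empirical_mean = mean(sums), empirical_var = var(sums))
+c(
+  theoretical_mean = n_shedders * mu_g,
+  theoretical_var = n_shedders * mu_g + (n_shedders * mu_g)^2 / (n_shedders * phi_g)
+)
+```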
+
+
+In practice, each site
+processes samples differently, which adds both a site-level scaling factor $M_j$
+and site-level variability, which we will call $\phi_j$. If we assume observations
+are also negative binomial, then the observed genomes in each site on each day
+$\overline{G}_j(t)$ can be defined as:
+$$ \overline{G}_j(t) \sim NegBinom(M_jG_j(t), \phi_j) $$
+This defines a model with a site-level scaling factor and a site-level measurement error.
+Our observations will be in terms of genomes per person per day, so $\overline{G*}_j(t) = \frac{\overline{G}_j(t)}{N_j}$, where $N_j$ is the population in the WW catchment area. Note that the observation errors don't have to be negative binomially distributed; we could also use
+something like a Student's t.
+
+
+
+
+
+## Hierarchical infectious disease dynamics
+We will assume that the site-level infections follow $I_j(t) = R_j(t) \sum_{s=1}^{t}I_j(t-s)g(s)$, where at each time step the site-level $R_j(t)$ is drawn from a shared distribution centered on the state-level $R(t)$ such that:
+$$ R_j(t) \sim logNormal(R_{state\mu}(t), \sigma_{R(t)})$$
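+
+A one-line sketch of this hierarchical draw (illustrative values; using `meanlog = log(R_state)` is an assumption about the lognormal parameterization):
+
+```{r}
+# Draw site-level R(t) values around a state-level mean; placeholder values.
+r_sites <- rlnorm(5, meanlog = log(1.2), sdlog = 0.1)
+```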
+
+
+
+# Assign parameters
+```{r}
+# This contains all the model priors and parameter settings. Informative priors
+# include the shedding viral kinetic parameters (timing of peak viral shedding,
+# magnitude of peak viral shedding, and duration of viral shedding). These
+# were informed in combination by literature on fecal shedding and viral loads
+# in the nasal passages.
+params <- get_params(config_vars_ss$param_file_path)
+print(params)
+
+params$p_hosp_w_sd_sd <- 0.05
+# Pull the prior hyperparameters into the global environment so you can create
+# the init_fun to be passed to stan
+par_names <- colnames(params)
+for (i in seq_along(par_names)) {
+ assign(par_names[i], as.double(params[i]))
+}
+
+
+stan_data <- do.call(get_stan_data_site_level_model, c(
+ config_vars_ss,
+ list(params = params),
+ list(train_data = train_data)
+))
+
+pop <- train_data %>%
+ select(pop) %>%
+ unique() %>%
+ pull(pop)
+stopifnot("More than one population size in training data" = length(pop) == 1)
+
+n_weeks <- as.numeric(stan_data$n_weeks)
+tot_weeks <- as.numeric(stan_data$tot_weeks)
+n_ww_sites <- as.numeric(stan_data$n_ww_sites)
+n_ww_lab_sites <- as.numeric(stan_data$n_ww_lab_sites)
+ot <- stan_data$ot
+ht <- stan_data$ht
+
+# Estimate of number of initial infections
+i0 <- mean(train_data$daily_hosp_admits[1:7]) / p_hosp_mean
+```
+
+
+
+
+# Initialize the parameter search using center of the priors + a bit of noise
+```{r}
+init_fun <- function() {
+ site_level_inf_inits(train_data, params, stan_data)
+}
+```
+
+# Compile the model
+```{r, echo = FALSE}
+model_file_path <- here::here(
+ "cfaforecastrenewalww", "inst",
+ "stan", "renewal_ww_hosp_site_level_inf_dynamics.stan"
+)
+
+model <- compile_model(model_filepath = model_file_path)
+```
+
+# Fit the model
+```{r, echo = FALSE}
+fit_dynamic_rt <- model$sample(
+ data = stan_data,
+ seed = 123,
+ # init = init_lists, #nolint
+ init = init_fun, # nolint
+ iter_sampling = 150,
+ iter_warmup = 250,
+ chains = 4,
+ parallel_chains = 4
+)
+```
Quick prior vs. posterior check
+```{r}
+all_draws <- fit_dynamic_rt$draws()
+
+
+autoreg_rt_site <- all_draws %>%
+ spread_draws(autoreg_rt_site) %>%
+ mutate(draw = `.draw`) %>%
+ mutate(
+ name = "autoreg_rt_site",
+ t = NA
+ ) %>%
+ rename(value = autoreg_rt_site) %>%
+ select(name, t, value, draw)
+
+prior_autoreg_rt_site <- data.frame(value = rbeta(1000, autoreg_rt_a, autoreg_rt_b))
+prior_autoreg_rt_site <- data.frame(value = rbeta(1000, 1, 4))
+
+sigma_rt <- all_draws %>%
+ spread_draws(sigma_rt) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "sigma_rt",
+ ) %>%
+ rename(value = sigma_rt) %>%
+ select(name, value, draw)
+
+sigma_rt_prior <- data.frame(value = rnorm(1000, 0, params$sigma_rt_prior))
+ggplot() +
+ geom_density(data = sigma_rt_prior, aes(x = value), fill = "blue", alpha = 0.1) +
+ geom_density(data = sigma_rt, aes(x = value), fill = "red", alpha = 0.1) +
+ theme_bw() +
+ ggtitle("Average deviation of site and state R(t)") +
+ coord_cartesian(xlim = c(0, 1))
+
+ggplot() +
+ geom_density(data = prior_autoreg_rt_site, aes(x = value), fill = "blue", alpha = 0.1) +
+ geom_density(data = autoreg_rt_site, aes(x = value), fill = "red", alpha = 0.1) +
+ theme_bw() +
+ ggtitle("Autoreg coefficient in site level R(t) deviation") +
+ coord_cartesian(xlim = c(0, 0.8))
+```
+```{r}
+r_state <- all_draws %>%
+ spread_draws(rt[t]) %>%
+ rename(value = rt) %>%
+ mutate(
+ draw = `.draw`,
+ name = "r_state"
+ ) %>%
+ select(name, t, value, draw) %>%
+ group_by(t, name) %>%
+ summarise(
+ median_Rt = quantile(value, 0.5, na.rm = TRUE),
+ lb = quantile(value, 0.025, na.rm = TRUE),
+ ub = quantile(value, 0.975, na.rm = TRUE),
+ ub_50th = quantile(value, 0.75, na.rm = TRUE),
+ lb_50th = quantile(value, 0.25, na.rm = TRUE)
+ ) %>%
+ left_join(train_data %>% select(forecast_date, t, date) %>% distinct())
+
+
+site_map <- train_data %>%
+ select(site, site_index, ww_pop) %>%
+ distinct() %>%
+ filter(!is.na(site))
+
+if (sum(stan_data$subpop_size) < stan_data$state_pop) {
+ subpop_map <- site_map %>%
+ rbind(c(1, stan_data$subpop_size[stan_data$n_subpops])) %>%
+ rename(subpop_size = ww_pop)
+} else {
+ subpop_map <- site_map %>%
+ rename(subpop_size = ww_pop)
+}
+
+r_site <- all_draws %>%
+ spread_draws(r_site_t[site_index, t]) %>%
+ rename(value = r_site_t) %>%
+ mutate(
+ draw = `.draw`,
+ name = "r_site"
+ ) %>%
+ select(name, t, value, site_index, draw) %>%
+ group_by(t, site_index, name) %>%
+ summarise(
+ median_Rt = quantile(value, 0.5, na.rm = TRUE),
+ lb = quantile(value, 0.025, na.rm = TRUE),
+ ub = quantile(value, 0.975, na.rm = TRUE),
+ ub_50th = quantile(value, 0.75, na.rm = TRUE),
+ lb_50th = quantile(value, 0.25, na.rm = TRUE)
+ ) %>%
+ left_join(subpop_map) %>%
+ left_join(
+ train_data %>%
+ select(forecast_date, t, date) %>%
+ distinct()
+ ) %>%
+ mutate(
+ site_name = paste0("Site: ", site)
+ )
+
+ggplot(r_state) +
+ geom_line(aes(x = date, y = median_Rt)) +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub), alpha = 0.2) +
+ geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), alpha = 0.2) +
+ geom_hline(aes(yintercept = 1), linetype = "dashed") +
+ ylab("State R(t)") +
+ xlab("") + # scale_y_continuous(trans = "log") +
+ # coord_cartesian(ylim = c(0.5, 2.2))+
+ theme_bw()
+
+ggplot(r_site) +
+ geom_line(aes(x = date, y = median_Rt, color = as.factor(site_name))) +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub, fill = as.factor(site_name)), alpha = 0.2) +
+ geom_ribbon(
+ aes(
+ x = date,
+ ymin = lb_50th, ymax = ub_50th,
+ fill = as.factor(site_name)
+ ),
+ alpha = 0.2
+ ) +
+ geom_hline(aes(yintercept = 1), linetype = "dashed") +
+ facet_wrap(~site_name) +
+ ylab("Site R(t)") +
+ xlab("") +
+ theme_bw()
+
+sites_to_display <- train_data |>
+ dplyr::select(site) |>
+ dplyr::filter(!is.na(site)) |>
+ unique() |>
+ pull() |>
+ head(2)
+
+ggplot(r_site %>% filter(site %in% sites_to_display)) +
+ geom_line(aes(x = date, y = median_Rt, color = as.factor(site_name)),
+ show.legend = FALSE
+ ) +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub, fill = as.factor(site_name)),
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th, fill = as.factor(site_name)),
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_hline(aes(yintercept = 1), linetype = "dashed") +
+ facet_wrap(~site_name) +
+ ylab("Site R(t)") +
+ xlab("") +
+ theme_bw()
+
+ggplot() +
+ geom_line(
+ data = r_site %>% filter(site %in% sites_to_display[1]),
+ aes(x = date, y = median_Rt), color = "darkorange2",
+ show.legend = FALSE
+ ) +
+ geom_ribbon(
+ data = r_site %>% filter(site %in% sites_to_display[1]),
+ aes(x = date, ymin = lb, ymax = ub), fill = "darkorange2",
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_ribbon(
+ data = r_site %>% filter(site %in% sites_to_display[1]),
+ aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkorange2",
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_line(
+ data = r_site %>% filter(site %in% sites_to_display[2]),
+ aes(x = date, y = median_Rt), color = "darkgreen",
+ show.legend = FALSE
+ ) +
+ geom_ribbon(
+ data = r_site %>% filter(site %in% sites_to_display[2]),
+ aes(x = date, ymin = lb, ymax = ub), fill = "darkgreen",
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_ribbon(
+ data = r_site %>% filter(site %in% sites_to_display[2]),
+ aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkgreen",
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_hline(aes(yintercept = 1), linetype = "dashed") +
+ ylab("Site R(t)") +
+ xlab("") +
+ scale_y_continuous(trans = "log") +
+ coord_cartesian(ylim = c(0.5, 2.2)) +
+ theme_bw()
+```
+
+Time-varying IHR
+```{r}
+p_hosp <- all_draws %>%
+ spread_draws(p_hosp[t]) %>%
+ sample_draws(ndraws = 20) %>%
+ rename(value = p_hosp) %>%
+ mutate(
+ draw = `.draw`,
+ name = "p_hosp"
+ ) %>%
+ select(name, t, value, draw)
+
+ggplot(p_hosp) +
+ geom_line(aes(x = t, y = value, group = draw), size = 0.1) +
+ xlab("Time (days)") +
+ ylab("IHR(t)")
+
+p_hosp_mean <- all_draws %>%
+ spread_draws(p_hosp_mean) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "p_hosp_mean",
+ ) %>%
+ mutate(value = plogis(p_hosp_mean)) %>%
+ select(name, value, draw)
+ggplot(p_hosp_mean) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Mean of IHR")
+
+p_hosp_w_sd <- all_draws %>%
+ spread_draws(p_hosp_w_sd) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "p_hosp_w_sd",
+ ) %>%
+ rename(value = p_hosp_w_sd) %>%
+ select(name, value, draw)
+ggplot(p_hosp_w_sd) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Stdev of step size of IHR(t)")
+```
+Hospital admissions
+```{r}
+hosp_state <- all_draws %>%
+ spread_draws(pred_hosp[t]) %>%
+ rename(value = pred_hosp) %>%
+ mutate(
+ draw = `.draw`,
+ name = "pred_hosp"
+ ) %>%
+ select(name, t, value, draw) %>%
+ group_by(t, name) %>%
+ summarise(
+ median_hosp = quantile(value, 0.5, na.rm = TRUE),
+ lb = quantile(value, 0.025, na.rm = TRUE),
+ ub = quantile(value, 0.975, na.rm = TRUE),
+ ub_50th = quantile(value, 0.75, na.rm = TRUE),
+ lb_50th = quantile(value, 0.25, na.rm = TRUE)
+ ) %>%
+ left_join(train_data %>% select(forecast_date, daily_hosp_admits, t, date) %>% distinct())
+
+
+ggplot(hosp_state) +
+ geom_line(aes(x = date, y = median_hosp)) +
+ geom_point(aes(x = date, y = daily_hosp_admits)) +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub), alpha = 0.2) +
+ geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), alpha = 0.2) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ ylab("State-level hospital admissions") +
+ xlab("") +
+ theme_bw()
+
+hosp_state_draws <- all_draws %>%
+ spread_draws(pred_hosp[t]) %>%
+ sample_draws(ndraws = 100) %>%
+ rename(value = pred_hosp) %>%
+ mutate(
+ draw = `.draw`,
+ name = "pred_hosp"
+ ) %>%
+ select(name, t, value, draw) %>%
+ left_join(train_data %>% select(forecast_date, daily_hosp_admits, t, date) %>% distinct())
+
+ggplot(hosp_state_draws) +
+ geom_line(aes(x = date, y = value, group = draw), size = 0.1, alpha = 0.1) +
+ geom_point(aes(x = date, y = daily_hosp_admits)) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ ylab("State-level hospital admissions") +
+ xlab("") +
+ theme_bw()
+```
+Site-lab concentrations
+
+```{r}
+lab_site_map <- train_data %>%
+ dplyr::select(lab_wwtp_unique_id, lab_site_index, site, lab) %>%
+ dplyr::mutate(
+ lab_site = glue::glue("Site: {site} Lab: {lab}")
+ ) %>%
+ dplyr::distinct()
+
+conc <- all_draws %>%
+ spread_draws(pred_ww[lab_site_index, t]) %>%
+ rename(value = pred_ww) %>%
+ mutate(
+ draw = `.draw`,
+ name = "log_conc"
+ ) %>%
+ select(name, t, value, lab_site_index, draw) %>%
+ group_by(t, lab_site_index, name) %>%
+ summarise(
+ median_conc = quantile(value, 0.5, na.rm = TRUE),
+ lb = quantile(value, 0.025, na.rm = TRUE),
+ ub = quantile(value, 0.975, na.rm = TRUE),
+ ub_50th = quantile(value, 0.75, na.rm = TRUE),
+ lb_50th = quantile(value, 0.25, na.rm = TRUE)
+ ) %>%
+ left_join(lab_site_map) %>%
+ left_join(
+ train_data %>%
+ select(forecast_date, t, date) %>%
+ distinct()
+ ) %>%
+ left_join(
+ train_data %>% select(
+ ww, t, lab_site_index, ww_pop, below_LOD, lod_sewage,
+ flag_as_ww_outlier
+ ),
+ by = c("t", "lab_site_index")
+ )
+
+ggplot(conc) +
+ geom_line(aes(x = date, y = median_conc, color = lab_site), show.legend = FALSE) +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub, fill = lab_site),
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_point(aes(x = date, y = log(ww)), show.legend = FALSE) +
+ geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th, fill = lab_site),
+ alpha = 0.2, show.legend = FALSE
+ ) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ facet_wrap(~lab_site) +
+ ylab("Site Concentration(t)") +
+ xlab("") +
+ theme_bw()
+
+ggplot(data = conc %>% filter(site == sites_to_display[1])) +
+ geom_line(aes(x = date, y = median_conc), color = "darkorange2") +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub), fill = "darkorange2", alpha = 0.2) +
+ geom_point(aes(x = date, y = log(ww))) +
+ geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkorange2", alpha = 0.2) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ facet_wrap(~lab_site) +
+ ylab("Site-level genome copies per mL") +
+ xlab("") +
+ theme_bw()
+
+ggplot(data = conc %>% filter(site == sites_to_display[2])) +
+ geom_line(aes(x = date, y = median_conc), color = "darkgreen") +
+ geom_ribbon(aes(x = date, ymin = lb, ymax = ub), fill = "darkgreen", alpha = 0.2) +
+ geom_point(aes(x = date, y = log(ww))) +
+ geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkgreen", alpha = 0.2) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ facet_wrap(~lab_site) +
+ ylab("Site-level genome copies per mL") +
+ xlab("") +
+ theme_bw()
+```
+
+
+
+```{r}
+log_r_state <- all_draws %>%
+ spread_draws(log_r_mu_t_in_weeks[t]) %>%
+ rename(value = log_r_mu_t_in_weeks) %>%
+ mutate(
+ draw = `.draw`,
+ name = "log_r_state"
+ ) %>%
+ select(name, t, value, draw)
+
+log_r_site <- all_draws %>%
+ spread_draws(log_r_site_t_in_weeks[t]) %>%
+ rename(value = log_r_site_t_in_weeks) %>%
+ mutate(
+ draw = `.draw`,
+ name = "log_r_state"
+ ) %>%
+ select(name, t, value, draw)
+
+sigma_rt <- all_draws %>%
+ spread_draws(sigma_rt) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "sigma_rt",
+ ) %>%
+ rename(value = sigma_rt) %>%
+ select(name, value, draw)
+ggplot(sigma_rt) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Standard deviation between site and state level R(t)s")
+
+infection_feedback <- all_draws %>%
+ spread_draws(infection_feedback) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "infection_feedback",
+ ) %>%
+ rename(value = infection_feedback) %>%
+ select(name, value, draw)
+ggplot(infection_feedback) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Infection feedback")
+
+autoreg_rt_site <- all_draws %>%
+ spread_draws(autoreg_rt_site) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "autoreg_rt_site",
+ ) %>%
+ rename(value = autoreg_rt_site) %>%
+ select(name, value, draw)
+ggplot(autoreg_rt_site) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("AR term on Rt between sites and state mean")
+
+autoreg_rt <- all_draws %>%
+ spread_draws(autoreg_rt) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = row_number()) %>%
+ mutate(
+ name = "autoreg_rt",
+ ) %>%
+ rename(value = autoreg_rt) %>%
+ select(name, value, draw)
+ggplot(autoreg_rt) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("AR term on Rt between previous time step R(t)")
+```
+
+
+# Look at the generated quantities
+
+```{r}
+all_draws <- fit_dynamic_rt$draws()
+
+# Dataframes with ndraws (long format)
+hosp_draws <- all_draws %>%
+ spread_draws(pred_hosp[t]) %>%
+ rename(value = pred_hosp) %>%
+ mutate(
+ draw = row_number(),
+ name = "pred_hosp"
+ ) %>%
+ select(name, t, value, draw)
+
+
+site_map <- train_data %>%
+ select(site, site_index, ww_pop) %>%
+ distinct()
+lab_site_map <- train_data %>%
+ select(lab_wwtp_unique_id, lab_site_index, site) %>%
+ distinct()
+
+ww_draws <- all_draws %>%
+ spread_draws(pred_ww[lab_site_index, t]) %>%
+ rename(value = pred_ww) %>%
+ group_by(lab_site_index, t) %>%
+ mutate(
+ draw = row_number(),
+ name = "pred_ww",
+ value = exp(value)
+ ) %>%
+ select(name, lab_site_index, t, value, draw)
+
+sampled_draws <- sample(length(unique(ww_draws$draw)), 100)
+
+# Gather R(t)
+site_level_rt <- all_draws %>%
+ spread_draws(r_site_t[site_index, t]) %>%
+ rename(value = r_site_t) %>%
+ group_by(site_index, t) %>%
+ mutate(
+ draw = row_number(),
+ name = "R_site_t"
+ ) %>%
+ ungroup() %>%
+ select(name, site_index, t, value, draw) %>%
+ filter(draw %in% sampled_draws) %>%
+ left_join(site_map, by = "site_index") %>%
+ left_join(
+ train_data %>%
+ select(-ww, -site, -site_index, -ww_pop) %>%
+ distinct()
+ ) %>%
+ left_join(
+ train_data %>%
+ select(ww, site_index, t),
+ by = c("t", "site_index")
+ )
+
+site_level_rt_summary <- site_level_rt %>%
+ group_by(t, site_index) %>%
+ summarise(
+ site_level_rt_median = quantile(value, 0.5),
+ site_level_rt_lb = quantile(value, 0.025),
+ site_level_rt_ub = quantile(value, 0.975)
+ ) %>%
+ left_join(site_map, by = "site_index") %>%
+ left_join(
+ train_data %>%
+ select(-ww, -site, -site_index, -ww_pop) %>%
+ distinct()
+ ) %>%
+ left_join(train_data %>% select(ww, site_index, t), by = c("t", "site_index"))
+
+
+state_rt <- all_draws %>%
+ spread_draws(rt[t]) %>%
+ rename(value = rt) %>%
+ group_by(t) %>%
+ mutate(
+ draw = row_number(),
+ name = "rt"
+ ) %>%
+ select(name, t, value, draw) %>%
+ left_join(train_data %>% select(forecast_date, t, date, period) %>% distinct())
+
+state_rt_summary <- state_rt %>%
+ group_by(t) %>%
+ summarise(
+ exp_rt_median = quantile(value, 0.5),
+ exp_rt_lb = quantile(value, 0.025),
+ exp_rt_ub = quantile(value, 0.975)
+ ) %>%
+ left_join(
+ train_data %>%
+ select(forecast_date, t, date) %>%
+ distinct()
+ )
+
+state_rt <- state_rt %>%
+ filter(draw %in% sampled_draws)
+
+
+ggplot(site_level_rt_summary %>% filter(date <= forecast_date)) +
+ geom_line(aes(x = date, y = site_level_rt_median, color = as.factor(site))) +
+ geom_ribbon(
+ aes(
+ x = date,
+ ymin = site_level_rt_lb, ymax = site_level_rt_ub,
+ fill = as.factor(site)
+ ),
+ alpha = 0.2
+ ) +
+ geom_line(
+ data = state_rt_summary %>% filter(date <= forecast_date),
+ aes(x = date, y = exp_rt_median), color = "black"
+ ) +
+ geom_ribbon(
+ data = state_rt_summary %>% filter(date <= forecast_date),
+ aes(
+ x = date, ymin = exp_rt_lb,
+ ymax = exp_rt_ub
+ ), fill = "black", alpha = 0.1
+ ) +
+ facet_wrap(~site) +
+ ylab("Estimated R(t)") +
+ xlab("") +
+ ggtitle("Estimated site level R(t)") +
+ theme_bw()
+
+
+# Cross-sectional R(t) at d days before the forecast date
+d <- 3
+site_level_rt_end <- site_level_rt %>%
+ filter(date == forecast_date - days(d)) %>%
+ left_join(
+ site_level_rt %>%
+ filter(date == forecast_date - days(d)) %>%
+ group_by(site) %>%
+ summarise(median_rt = quantile(value, 0.5)),
+ by = "site"
+ )
+state_rt_end <- state_rt %>%
+ filter(date == forecast_date - days(d)) %>%
+ left_join(
+ state_rt %>%
+ filter(date == forecast_date - days(d)) %>%
+ summarise(median_rt = quantile(value, 0.5)),
+ by = "t"
+ )
+
+
+ggplot() +
+ geom_density(
+ data = site_level_rt_end,
+ aes(x = value, fill = as.factor(site), group = as.factor(site)), alpha = 0.3
+ ) +
+ geom_vline(data = site_level_rt_end, aes(xintercept = median_rt, color = as.factor(site))) +
+ geom_density(
+ data = state_rt_end,
+ aes(x = value), fill = "black", alpha = 0.3
+ ) +
+ facet_wrap(~site) +
+ geom_vline(data = state_rt_end, aes(xintercept = median_rt), color = "black") +
+ ggtitle("R(t) estimates at forecast date by site and overall")
+```
+
+
+
+
+Site-level WW genome copies per mL vs. observations
+```{r}
+ww_draws_w_data <- ww_draws %>%
+ filter(draw %in% sampled_draws) %>%
+ left_join(lab_site_map, by = "lab_site_index") %>%
+ left_join(train_data %>% select(
+ date, location, pop, daily_hosp_admits_for_eval,
+ daily_hosp_admits, t, hosp_reporting_delay,
+ period,
+ forecast_date, day_of_week, include_ww
+ ) %>%
+ distinct()) %>%
+ left_join(
+ train_data %>% select(
+ ww, t, lab_site_index, ww_pop, below_LOD, lod_sewage,
+ flag_as_ww_outlier
+ ),
+ by = c("t", "lab_site_index")
+ )
+
+sites_to_plot <- unique(lab_site_map$lab_wwtp_unique_id)
+
+ggplot(ww_draws_w_data %>% filter(
+ date <= forecast_date + days(7),
+ lab_wwtp_unique_id %in% sites_to_plot
+)) +
+ geom_line(
+ aes(
+ x = date, y = value, group = draw,
+ color = as.factor(lab_wwtp_unique_id)
+ ),
+ linewidth = 0.1, alpha = 0.1, show.legend = FALSE
+ ) +
+ geom_line(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ lab_wwtp_unique_id %in% sites_to_plot
+ ),
+ aes(
+ x = date, y = value, group = draw,
+ color = as.factor(lab_wwtp_unique_id)
+ ),
+ linewidth = 0.1, alpha = 0.3, show.legend = FALSE
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ lab_wwtp_unique_id %in% sites_to_plot
+ ),
+ aes(x = date, y = lod_sewage, group = draw), color = "navy"
+ ) +
+ geom_point(aes(x = date, y = ww), size = 1, shape = 21) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(flag_as_ww_outlier == 1),
+ aes(x = date, y = ww), size = 1.5, shape = 21, color = "purple"
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(below_LOD == 1),
+ aes(x = date, y = ww), size = 1.5, shape = 21, color = "red"
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ lab_wwtp_unique_id %in% sites_to_plot
+ ),
+ aes(x = date, y = ww),
+ fill = "black", size = 1, shape = 21
+ ) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ facet_wrap(~lab_wwtp_unique_id, scales = "free") +
+ theme_bw() +
+ xlab("") +
+ ylab("Genome copies/mL")
+
+ggplot(ww_draws_w_data %>% filter(
+ date <= forecast_date + days(7),
+ lab_wwtp_unique_id %in% sites_to_plot
+)) +
+ geom_line(
+ aes(
+ x = date, y = log(value), group = draw,
+ color = as.factor(lab_wwtp_unique_id)
+ ),
+ linewidth = 0.1, alpha = 0.1, show.legend = FALSE
+ ) +
+ geom_line(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ lab_wwtp_unique_id %in% sites_to_plot
+ ),
+ aes(
+ x = date, y = log(value), group = draw,
+ color = as.factor(lab_wwtp_unique_id)
+ ),
+ linewidth = 0.1, alpha = 0.3, show.legend = FALSE
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ lab_wwtp_unique_id %in% sites_to_plot
+ ),
+ aes(x = date, y = log(lod_sewage), group = draw), color = "navy"
+ ) +
+ geom_point(aes(x = date, y = log(ww)), size = 1, shape = 21) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(flag_as_ww_outlier == 1),
+ aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "purple"
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(below_LOD == 1),
+ aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "red"
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ lab_wwtp_unique_id %in% sites_to_plot
+ ),
+ aes(x = date, y = log(ww)),
+ fill = "black", size = 1, shape = 21
+ ) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ facet_wrap(~lab_wwtp_unique_id, scales = "free") +
+ theme_bw() +
+ xlab("") +
+ ylab("log(genome copies/mL)")
+
+sites_to_plot <- ww_draws_w_data %>%
+ ungroup() %>%
+ select(site) %>%
+ filter(!is.na(site)) %>%
+ unique() %>%
+ pull()
+ggplot(ww_draws_w_data %>% filter(
+ date <= forecast_date + days(7),
+ site %in% sites_to_plot
+)) +
+ geom_line(
+ aes(
+ x = date, y = log(value), group = draw,
+ color = as.factor(site)
+ ),
+ linewidth = 0.1, alpha = 0.1, show.legend = FALSE
+ ) +
+ geom_line(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ site %in% sites_to_plot
+ ),
+ aes(
+ x = date, y = log(value), group = draw,
+ color = as.factor(site)
+ ),
+ linewidth = 0.1, alpha = 0.3, show.legend = FALSE
+ ) +
+ geom_point(aes(x = date, y = log(ww)), size = 1, shape = 21) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(flag_as_ww_outlier == 1),
+ aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "purple"
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(below_LOD == 1),
+ aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "red"
+ ) +
+ geom_point(
+ data = ww_draws_w_data %>% filter(
+ period != "forecast",
+ site %in% sites_to_plot
+ ),
+ aes(x = date, y = log(ww)),
+ fill = "black", size = 1, shape = 21
+ ) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ facet_wrap(~site, scales = "free") +
+ theme_bw() +
+ xlab("") +
+ ylab("log(genome copies/mL)") +
+ ggtitle("Observed and estimated site-level wastewater concentrations")
+```
+
+```{r}
+ww_site_modifier_draws <- all_draws %>%
+ spread_draws(ww_site_mod[site_index]) %>%
+ # sample_draws(ndraws = n_draws) %>%
+ rename(value = ww_site_mod) %>%
+ mutate(
+ draw = `.draw`,
+ name = "ww site multiplier"
+ ) %>%
+ select(name, site_index, value, draw) %>%
+ left_join(site_map)
+
+i0_ww_site <- all_draws %>%
+ spread_draws(i0_site_over_n[site_index]) %>%
+ rename(value = i0_site_over_n) %>%
+ mutate(
+ draw = `.draw`,
+ name = "i0/N ww site"
+ ) %>%
+ select(name, site_index, value, draw) %>%
+ left_join(site_map)
+
+growth_site <- all_draws %>%
+ spread_draws(growth_site[site_index]) %>%
+ rename(value = growth_site) %>%
+ mutate(
+ draw = `.draw`,
+ name = "initial growth rate"
+ ) %>%
+ select(name, site_index, value, draw) %>%
+ left_join(site_map)
+
+i_at_site <- all_draws %>%
+ spread_draws(site_i0_over_n_start[site_index]) %>%
+ rename(value = site_i0_over_n_start) %>%
+ mutate(
+ draw = `.draw`,
+ name = "site i0 over N at start"
+ ) %>%
+ select(name, site_index, value, draw) %>%
+ left_join(site_map)
+
+ggplot(ww_site_modifier_draws) +
+ geom_density(aes(
+ x = exp(value), fill = as.factor(site),
+ group = site
+ ), alpha = 0.5) +
+ theme_bw() +
+ xlab("Site level multiplier")
+
+ggplot(i0_ww_site) +
+ geom_density(aes(
+ x = value, fill = as.factor(site),
+ group = site
+ ), alpha = 0.5) +
+ theme_bw() +
+ xlab("log(i0/N) in each site")
+
+ggplot(i_at_site) +
+ geom_density(aes(
+ x = log(value), fill = as.factor(site),
+ group = site
+ ), alpha = 0.5) +
+ theme_bw() +
+ xlab("log(i0/N) in each site at start of R(t)")
+
+ggplot(growth_site) +
+ geom_density(aes(
+ x = value, fill = as.factor(site),
+ group = site
+ ), alpha = 0.5) +
+ theme_bw() +
+ xlab("Initial growth rate in each site")
+
+ww_site_sigma_draws <- all_draws %>%
+ spread_draws(sigma_ww_site[site_index]) %>%
+ # sample_draws(ndraws = n_draws) %>%
+ rename(value = sigma_ww_site) %>%
+ mutate(
+    draw = `.draw`, # use the posterior draw index (row_number() is not grouped here)
+ name = "ww site sigma"
+ ) %>%
+ select(name, site_index, value, draw) %>%
+ left_join(site_map)
+
+ww_site_sigma_summary <- ww_site_sigma_draws %>%
+ group_by(site) %>%
+ summarise(
+ sigma_median = quantile(value, 0.5),
+ sigma_lb = quantile(value, 0.025),
+ sigma_ub = quantile(value, 0.975)
+ ) %>%
+ left_join(ww_site_sigma_draws %>% select(name, site, ww_pop) %>% distinct(),
+ by = "site"
+ )
+
+
+ggplot(ww_site_sigma_draws) +
+ geom_density(aes(
+ x = value, fill = as.factor(site),
+ group = site
+ ), alpha = 0.5) +
+ scale_x_log10() +
+ theme_bw() +
+ xlab("Site level standard deviation")
+
+theme_set(theme_ggdist())
+ggplot(data = ww_site_sigma_draws, aes(y = as.factor(ww_pop), x = value)) +
+ stat_halfeye() +
+ ylab("Catchment area population") +
+ xlab("Phi") +
+ ggtitle("WW site-level stdev vs population size")
+
+
+ggplot(ww_site_sigma_summary) +
+ geom_linerange(aes(
+ x = ww_pop, ymin = sigma_lb,
+ ymax = sigma_ub
+ )) +
+ geom_point(aes(x = ww_pop, y = sigma_median)) +
+ scale_x_log10() +
+ theme_bw() +
+ xlab("Population served by WW catchment") +
+ ylab("phi of WW site (~1/error)")
+
+hosp_draws_sampled <- hosp_draws %>%
+ filter(draw %in% sampled_draws) %>%
+ left_join(train_data, by = c("t"))
+
+ggplot(hosp_draws_sampled %>% filter(
+ date <= forecast_date + days(10)
+)) +
+ geom_line(aes(x = date, y = value, group = draw),
+ linewidth = 0.1, alpha = 0.1,
+ color = "darkred", show.legend = FALSE
+ ) +
+ geom_line(
+ data = hosp_draws_sampled %>% filter(
+ period != "forecast"
+ ),
+ aes(x = date, y = value, group = draw),
+ linewidth = 0.1, alpha = 0.3,
+ color = "darkred", show.legend = FALSE
+ ) +
+ geom_point(aes(x = date, y = daily_hosp_admits_for_eval),
+ size = 1, shape = 21
+ ) +
+ geom_point(
+ data = hosp_draws_sampled %>% filter(
+ period != "forecast"
+ ),
+ aes(x = date, y = daily_hosp_admits_for_eval), fill = "black",
+ size = 1, shape = 21
+ ) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ theme_bw() +
+ xlab("") +
+ ylab("Hospital admissions")
+```
+
+We no longer expect the sigma inherent to the WW site to be correlated with
+population size: in theory, population-size effects are captured by the sum of
+negative binomials in the site-level expected true genomes (the dispersion of a
+sum over the individuals infected at that time point). The site-level phi is
+therefore just additional variability introduced by the site, independent of
+population shedding.
diff --git a/scratch/sites_sum_to_states_test.Rmd b/scratch/sites_sum_to_states_test.Rmd
new file mode 100644
index 00000000..75fe9b9e
--- /dev/null
+++ b/scratch/sites_sum_to_states_test.Rmd
@@ -0,0 +1,324 @@
+---
+title: "Sites sum to states recover infections"
+author: "Kaitlyn Johnson"
+date: "2024-03-15"
+output: html_document
+---
+
+```{r setup, include=FALSE}
+library(cfaforecastrenewalww)
+library(cmdstanr)
+library(lubridate)
+library(ggplot2)
+library(dplyr)
+library(here)
+library(tidybayes)
+library(scoringutils)
+source(here::here("src/write_config.R"))
+knitr::opts_chunk$set(echo = TRUE)
+cfaforecastrenewalww::setup_secrets(here::here("secrets.yaml"))
+```
+# Motivation
+The purpose of this Rmd is to generate simulated data with known per capita
+infections and hospital admissions using the `generate_simulated_data()`
+function. We want to compare the ability to accurately recover the known
+input values of per capita infections and hospital admissions for two models:
+1. The proposed model formulation (subpopulations sum to state infections,
+which accounts properly for the sites' subpopulation sizes)
+2. The original model formulation (sites are random draws from the state, so
+the model doesn't know the site population sizes)
+
+# Generate simulated data if needed
+```{r}
+set.seed(1)
+toy_data_and_params <- generate_simulated_data(
+ r_in_weeks = c(
+ rep(1.1, 5), rep(0.9, 5),
+ 1 + 0.007 * 1:5, rep(0.95, 5),
+ rep(1.01, 6)
+ ),
+ n_sites = 4,
+ # Set up sites to cover a relatively
+ # small subset of the state population
+ ww_pop_sites = c(1e6, 5e5, 1e5, 5e4),
+ pop_size = 10e6
+)
+
+
+sim_data_df <- toy_data_and_params$example_df
+param_df <- toy_data_and_params$param_df
+
+saveRDS(sim_data_df, here::here("input", "sim_data_df"))
+```
+
+# Set the output path based on the model you're testing
+```{r}
+output_file_path <- file.path(here::here("output", "tests", "sites_sum_to_states")) # OR
+# output_file_path <- file.path(here::here("output", "tests", "original")) #nolint
+```
+# If running the old model, read in the saved simulated data
+The data-generation function on the old branch is outdated, so load the
+dataset saved above rather than regenerating it.
+
+```{r}
+# Read in the simulated data
+sim_data_df <- readRDS(here::here("input", "sim_data_df"))
+```
+
+# Visualize the data
+```{r}
+ggplot(sim_data_df) +
+ geom_point(
+ aes(
+ x = date, y = exp(log_conc),
+ color = as.factor(lab_wwtp_unique_id)
+ ),
+ show.legend = FALSE
+ ) +
+ geom_point(
+ data = sim_data_df %>% filter(below_LOD == 1),
+ aes(x = date, y = exp(log_conc), color = "red"),
+ show.legend = FALSE
+ ) +
+ geom_hline(aes(yintercept = exp(lod_sewage)), linetype = "dashed") +
+ facet_wrap(~lab_wwtp_unique_id, scales = "free") +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ xlab("") +
+ ylab("Genome copies/mL") +
+ ggtitle("Lab-site level wastewater concentration") +
+ theme_bw()
+
+ggplot(sim_data_df) +
+ geom_point(aes(x = date, y = daily_hosp_admits_for_eval),
+ shape = 21, color = "black", fill = "white"
+ ) +
+ geom_point(aes(x = date, y = daily_hosp_admits)) +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ xlab("") +
+ ylab("Daily hospital admissions") +
+ ggtitle("State level hospital admissions") +
+ theme_bw()
+
+ggplot(sim_data_df) +
+ geom_point(aes(x = date, y = inf_per_capita), color = "black") +
+ geom_vline(aes(xintercept = forecast_date), linetype = "dashed") +
+ xlab("") +
+ ylab("State infections per capita") +
+ ggtitle("True state infections per capita") +
+ theme_bw()
+```
+# Pre-process fake data to fit the model
+```{r}
+params <- get_params(
+ system.file("extdata", "example_params.toml",
+ package = "cfaforecastrenewalww"
+ )
+)
+params
+
+forecast_date <- sim_data_df %>%
+ pull(forecast_date) %>%
+ unique()
+forecast_time <- as.integer(max(sim_data_df$date) - forecast_date)
+
+# Assign the model that we want to fit to so that we grab the correct
+# model and initialization list
+model_type <- "site-level infection dynamics"
+model_file_path <- get_model_file_path(model_type)
+# Compile the model
+model <- compile_model(file.path(here(
+ model_file_path
+)))
+
+
+# Function calls for linear scale ww data
+train_data_raw <- sim_data_df %>%
+ mutate(
+ ww = exp(log_conc),
+ period = case_when(
+ !is.na(daily_hosp_admits) ~ "calibration",
+ is.na(daily_hosp_admits) & date <= forecast_date ~
+ "nowcast",
+ TRUE ~ "forecast"
+ ),
+ include_ww = 1,
+ site_index = site,
+ lab_site_index = lab_wwtp_unique_id
+ )
+
+# Flag wastewater outliers in the data
+train_data <- flag_ww_outliers(train_data_raw)
+
+# Get the generation interval and time from infection to hospital admission
+# delay distribution to pass to stan.
+# Use the same values as we do in the
+# generation of these vectors in the generate simulated data function....
+# See Song Woo Park et al
+# https://www.medrxiv.org/content/10.1101/2024.01.12.24301247v1 for why
+# we use a double-censored pmf here
+generation_interval <- simulate_double_censored_pmf(
+ max = params$gt_max, meanlog = params$mu_gi,
+ sdlog = params$sigma_gi, fun_dist = rlnorm, n = 5e6
+) %>% drop_first_and_renormalize()
+
+inc <- make_incubation_period_pmf(
+ params$backward_scale, params$backward_shape, params$r
+)
+sym_to_hosp <- make_hospital_onset_delay_pmf(
+ params$neg_binom_mu,
+ params$neg_binom_size
+)
+inf_to_hosp <- make_reporting_delay_pmf(inc, sym_to_hosp)
+
+# Format as a list for stan
+stan_data <- get_stan_data_site_level_model(
+ train_data,
+ params,
+ forecast_date,
+ forecast_time,
+ model_type = model_type,
+ generation_interval = generation_interval,
+ inf_to_hosp = inf_to_hosp,
+ infection_feedback_pmf = generation_interval
+)
+
+init_fun <- function() {
+ site_level_inf_inits(train_data, params, stan_data)
+}
+```
+# Model fit
+```{r}
+fit_dynamic_rt <- model$sample(
+ data = stan_data,
+ seed = 123,
+ init = init_fun,
+ iter_sampling = 500,
+ iter_warmup = 250,
+ chains = 4,
+ parallel_chains = 4
+)
+```
+
+```{r}
+all_draws <- fit_dynamic_rt$draws()
+
+# Predicted observed hospital admissions
+exp_obs_hosp <- all_draws %>%
+ spread_draws(pred_hosp[t]) %>%
+ select(pred_hosp, `.draw`, t) %>%
+ rename(draw = `.draw`)
+
+# Estimate of latent incident inf per capita
+est_inf <- all_draws %>%
+ spread_draws(state_inf_per_capita[t]) %>%
+ filter(t > params$uot) %>%
+ mutate(t = t - params$uot) %>%
+ select(state_inf_per_capita, `.draw`, t) %>%
+ rename(draw = `.draw`)
+
+
+
+input_df <- sim_data_df %>%
+ select(
+ t, date, daily_hosp_admits_for_eval,
+ inf_per_capita, pop
+ ) %>%
+ distinct()
+
+inf <- est_inf %>% left_join(
+ input_df %>%
+ select(t, date, inf_per_capita),
+ by = "t"
+)
+hosp <- exp_obs_hosp %>% left_join(
+ input_df %>%
+ select(
+ t, date,
+ daily_hosp_admits_for_eval
+ ),
+ by = "t"
+)
+```
+# Score and plot estimates vs known truth data
+```{r}
+# Scoring
+inf_scores <- inf %>%
+ rename(
+ true_value = inf_per_capita,
+ prediction = state_inf_per_capita,
+ sample = draw
+ ) %>%
+ mutate(model = "ww_model") %>%
+ scoringutils::score()
+summarized_inf_scores <- inf_scores %>% scoringutils::summarise_scores(by = "model")
+print(summarized_inf_scores)
+
+hosp_scores <- hosp %>%
+ rename(
+ true_value = daily_hosp_admits_for_eval,
+ prediction = pred_hosp,
+ sample = draw
+ ) %>%
+ mutate(model = "ww_model") %>%
+ scoringutils::score()
+summarized_hosp_scores <- hosp_scores %>% scoringutils::summarise_scores(by = "model")
+print(summarized_hosp_scores)
+
+samples <- sample(1:max(hosp$draw), 100)
+inf_subset <- inf %>% filter(draw %in% samples) # sample the draws for plotting
+
+inf_plot <- ggplot(inf_subset) +
+ geom_line(aes(x = date, y = state_inf_per_capita, group = draw),
+ color = "gray", alpha = 0.2, size = 0.5
+ ) +
+ geom_point(aes(x = date, y = inf_per_capita)) +
+ xlab("") +
+ ylab("Infections per capita") +
+ theme_bw()
+inf_plot
+
+samples <- sample(1:max(hosp$draw), 100)
+hosp_subset <- hosp %>% filter(draw %in% samples) # sample the draws for plotting
+
+hosp_plot <- ggplot(hosp_subset) +
+ geom_line(aes(x = date, y = pred_hosp, group = draw),
+ color = "darkred", alpha = 0.2, size = 0.5
+ ) +
+ geom_point(aes(x = date, y = daily_hosp_admits_for_eval)) +
+ xlab("") +
+ ylab("Daily hospital admissions") +
+ theme_bw()
+hosp_plot
+```
+
+# Save the scores and figures for model
+The output file path will be "sites_sum_to_states" unless you loaded in the
+saved simulated data, which should be run from a branch with the old model.
+```{r}
+create_dir(output_file_path)
+write.csv(summarized_inf_scores, here::here(output_file_path, "inf_scores.csv"),
+ row.names = FALSE
+)
+write.csv(summarized_hosp_scores, here::here(output_file_path, "hosp_scores.csv"),
+ row.names = FALSE
+)
+ggsave(
+ file.path(here::here(
+ output_file_path,
+ "inf_per_capita.png"
+ )),
+ plot = inf_plot,
+ width = 8,
+ height = 5,
+ bg = "white"
+)
+ggsave(
+ file.path(here::here(
+ output_file_path,
+ "hosp.png"
+ )),
+ plot = hosp_plot,
+ width = 8,
+ height = 5,
+ bg = "white"
+)
+```
diff --git a/scratch/state_level_model.Rmd b/scratch/state_level_model.Rmd
new file mode 100644
index 00000000..f19aa7cc
--- /dev/null
+++ b/scratch/state_level_model.Rmd
@@ -0,0 +1,433 @@
+---
+title: "State level model"
+author: "Kaitlyn Johnson"
+date: "2023-08-22"
+output: html_document
+---
+
+```{r setup, include = FALSE}
+library(cfaforecastrenewalww)
+library(lubridate)
+library(ggplot2)
+library(dplyr)
+library(tidybayes)
+library(httr)
+source(here::here("src", "write_config.R"))
+knitr::opts_chunk$set(echo = FALSE, message = FALSE, warning = FALSE)
+cfaforecastrenewalww::setup_secrets(here::here("secrets.yaml"))
+```
+# Motivation
+- This is a first attempt at implementing a renewal approach for inference of hospital admissions and wastewater viral concentrations by estimating a time-varying effective reproductive number $R(t)$.
+
+# Approach
+- As a first pass, we aggregate the WW viral concentration (`pcr_target_avg_conc`
+in NWSS data) to a single weekly value for each site (assigned to Wednesday)
+and calculate a population-weighted average, thresholding at population sizes
+of 300,000 so as not to overweight large catchment areas
+- For the hospital admissions data, we use the state-level total hospital
+admissions data that were available as of the forecast date, using the covidcast
+package and the `as_of` input. We impose a 9-day reporting delay to mirror the
+current data scenario.
+
+
+Note: This Rmd mirrors the single model, single location, single forecast date pipeline
+set up in the `_targets.R` file. We will use the Rmd format to continue to refine the model.
+
+
+# Get the state level data from NWSS
+```{r}
+# Get the config file
+config_written <- write_config(
+ save_config = TRUE,
+ config_path = here::here("input", "config"),
+ location = "AK",
+ prod_run = FALSE,
+ run_id = "test",
+ date_run = ymd("2024-03-06"),
+ model_type = "state-level aggregated wastewater",
+ ww_geo_type = "state",
+ forecast_date = ymd("2024-03-06"),
+ hosp_data_source = "NHSN",
+ population_data_path = here::here("input", "locations.csv"),
+ pull_from_local = FALSE,
+ include_ww = 1,
+ hosp_reporting_delay = 4,
+ ww_data_path = here::here("input", "ww_data", "nwss_data", "2024-03-05.csv")
+)
+
+
+config_vars_ss <- get_config_vals(config_written)
+# WW data from NWSS (using methods described in slide deck, weekly values aggregated
+# to the state level)
+ww_data_raw <- do.call(get_ww_data, config_vars_ss)
+
+# Combined WW data with hospitalizations from vintage dataset to train the model
+train_data <- do.call(get_all_training_data, c(list(ww_data_raw = ww_data_raw), config_vars_ss))
+
+# Get data spine for joining to stan outputs
+min_date <- min(train_data$date)
+max_date <- ymd(config_vars_ss$forecast_date + days(config_vars_ss$forecast_time))
+dates <- seq(
+ from = min_date,
+ to = max_date,
+ by = "days"
+)
+t <- seq(from = 1, to = length(dates), by = 1)
+
+date_df <- data.frame(date = dates, t = t)
+
+plot_data <- plot_combined_data(
+ comb = train_data, figure_file_path = NA,
+ write_files = FALSE
+)
+plot_data
+```
+
+# Model description
+
+We have observed data from 90 days of hospital admissions and 14 weeks of wastewater measurements for a single state. We make the following simplifying assumptions:
+- the population shedding into the WW is the same as those captured by the hospital admissions
+- the IHR is constant during the period of inference
+- the shedding kinetics and shedding amount are constant during the period of inference
+- for now, we will assume that the number of genomes shed per individual is constant
+across infected individuals, since the population is large enough that we don't expect
+inter-individual variability to add too much noise to the signal
+
+This model builds heavily off of (and in some instances borrows functions from) the
+[EpiNow2 R package](https://github.com/epiforecasts/EpiNow2)
+
+## Renewal Equation Stan Model
+In brief, this model assumes that incident infections $I(t)$ are generated from a vector of $R(t)$ values and an initial number of infections $I(0)$ seeded on day 0:
+$$I(t) = R(t) \sum_{s=1}^{t}I(t-s)g(s)$$
+Where $g(s)$ is the generation interval, which describes the changes over time in infectiousness of an index case over the course of their infection, and $R(t)$ describes the expected number of secondary infections of an index case at time $t$. Because the data do not start at the initial seeding event, we assume that prior to the first observation, initial infections grow exponentially to give rise to the early observations, such that $I(t') = I(0)\exp(rt')$ where $t'$ is the "unobserved time" before the first observed data point.
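+
+As an illustration (separate from the stan model itself), here is a minimal R
+sketch of this discrete renewal process, with made-up values for `g`, `rt`,
+and the exponential seeding:
+
+```{r}
+# Toy renewal process: I(t) = R(t) * sum_s I(t - s) g(s)
+g <- c(0.2, 0.4, 0.3, 0.1) # toy generation interval pmf (sums to 1)
+rt <- rep(c(1.2, 0.9), each = 15) # toy R(t) vector
+i0 <- 10
+r_init <- 0.05
+seed <- i0 * exp(r_init * seq_along(g)) # exponential growth in unobserved time
+infections <- c(seed, rep(0, length(rt)))
+for (t in seq_along(rt)) {
+  s <- length(seed) + t
+  infections[s] <- rt[t] * sum(infections[s - seq_along(g)] * g)
+}
+plot(infections, type = "l", xlab = "t", ylab = "I(t)")
+```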
+
+The model estimates $R(t)$ via a weekly random walk such that $R(k) = R(k-1)\eta$, where $\eta \sim N(0, \eta_{sd})$, $k$ is the week, and the magnitude of the step size, $\eta_{sd}$, is estimated during model calibration. For $R(t)$ during the forecast period, we either assume $R(t)$ is dampened in proportion to the number of cumulative infections following an SIR approximation, which is implemented when `damp_type = 1` (see the EpiNow2 documentation for details), or we assume that $R(t)$ is dampened by the current number of infections with a drift term, per Jason's Ebola model. We will likely modify both of these significantly, particularly if we observe consistent over- or underprediction in the forecasts.
+
+
+We model the process of generating two types of observations. First, the hospital admissions: each infectee is assumed to have some (independent) chance $p_{hosp}$ of being hospitalized, and *if* they are eventually hospitalized, the probability distribution for the number of days between infection and observation/reporting is $d(t)$. Therefore, conditional on some time series of daily infections generated by one of the models, $I(t)$, the *expected* number of reported hospitalizations on each day $t$ is,
+
+$$\overline{H}[t] = p_{hosp}\sum_{\tau\geq 0}d[\tau] I[t-\tau]$$
+We estimate $p_{hosp}$ and, for now, fix the hospital admissions delay distribution $d(t)$.
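+
+Continuing the toy sketch above, with purely illustrative values for
+$p_{hosp}$ and $d(t)$, the expected admissions are a delayed, scaled
+convolution of the infections:
+
+```{r}
+p_hosp <- 0.02 # toy infection-hospitalization ratio
+d <- dnbinom(0:14, mu = 7, size = 5) # toy infection-to-admission delay pmf
+d <- d / sum(d) # renormalize the truncated pmf
+# One-sided convolution: H[t] = p_hosp * sum_tau d[tau] I[t - tau]
+exp_hosp <- p_hosp * stats::filter(infections, d, sides = 1)
+plot(as.numeric(exp_hosp), type = "l", xlab = "t", ylab = "Expected admissions")
+```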
+
+In a similar fashion, to generate the expected WW observations we assume that
+each individual follows the same shedding kinetics distribution $S(t)$, which
+we model as a hinge function in log10 space.
+$$
+S'[t] =
+  \begin{cases}
+    \frac{V_{peak}}{t_{peak}}t & t \leq t_{peak} \\
+    V_{peak} + wt_{peak} - wt & t \geq t_{peak}
+  \end{cases}
+$$
+
+Where $V_{peak}$ is the log10 peak viral load that occurs at $t_{peak}$ days since infection onset, and $w$ is the rate at which viral load wanes after $t_{peak}$. We convert to natural scale with $S[t] = 10^{S'[t]}$. The parameters of this hinge function are estimated (with relatively strong priors taken from fecal shedding literature and literature on viral loads in nasal passages). The individual level viral kinetics $S[t]$ of each infected individual are normalized to sum to 1 over the course of the infection, and are multiplied by $G$, the number of genome copies shed by each infected individual throughout their infection. Therefore, conditional on some time series of daily infections generated by one of the models, $I(t)$, the *expected* number of viral genomes shed in WW on each day $t$ is,
+
+$$\overline{V}[t] = \ G\sum_{\tau\geq 0} S(\tau)I(t-\tau)$$
+While this initial formulation describes $G$, the amount of virus shed per infected individual, as a constant, we know the individual viral shedding amount is highly dispersed across infected individuals (with perhaps 1% of infectees producing 100x more viral genomes than the remaining infected individuals). However, in this first pass, we just assume that every individual sheds at the same rate $G$. In practice, we will likely need this to be site/location specific.
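+
+As a rough sketch of these shedding kinetics in R, using placeholder values
+for $V_{peak}$, $t_{peak}$, and $w$ (not the priors used in the model):
+
+```{r}
+t_peak <- 5 # toy days to peak shedding
+v_peak <- 5 # toy log10 peak viral load
+w <- 0.3 # toy log10 waning rate per day
+tau <- 0:30
+log10_s <- ifelse(tau <= t_peak,
+  (v_peak / t_peak) * tau,
+  v_peak + w * t_peak - w * tau
+)
+s <- 10^log10_s # convert to natural scale
+s <- s / sum(s) # normalize to sum to 1 over the infection
+plot(tau, s, type = "l", xlab = "Days since infection", ylab = "Normalized shedding")
+```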
+
+
+Lastly, we expect to observe the data in terms of a concentration, so we assume:
+
+$$\overline{C}[t] = \frac{\overline{V}[t]}{NW}$$
+where $N$ is the population size and $W$ is the daily volume of wastewater
+produced per person. In practice, we fit $\frac{G}{W}$, the number of genomes
+shed per mL of WW produced per person per day.
+
+# Assign parameters
+```{r}
+# This contains all the model priors and parameter settings. Informative priors
+# include the shedding viral kinetic parameters (timing of peak viral shedding,
+# magnitude of peak viral shedding, and duration of viral shedding). These
+# were informed in combination from literature on fecal shedding and viral loads
+# in the nasal passages.
+params <- get_params(here::here("input", "params.toml")) #
+print(params)
+
+stan_data <- do.call(get_stan_data, c(
+ list(train_data = train_data),
+ list(params = params), config_vars_ss
+))
+
+# Pull the prior hyperparameters into the global environment so you can create
+# the init_fun to be passed to stan
+par_names <- colnames(params)
+for (i in seq_along(par_names)) {
+ assign(par_names[i], as.double(params[i]))
+}
+
+
+# Get other variables needed from data
+pop <- train_data %>%
+ select(pop) %>%
+ unique() %>%
+ pull(pop)
+stopifnot("More than one population size in training data" = length(pop) == 1)
+
+n_weeks <- as.numeric(stan_data$n_weeks)
+tot_weeks <- as.numeric(stan_data$tot_weeks)
+
+# Estimate of number of initial infections
+i0 <- mean(train_data$daily_hosp_admits[1:7]) / p_hosp_mean
+```
+
+# Initialize the parameter search using center of the priors + a bit of noise
+```{r}
+init_fun <- function() {
+ state_agg_inits(train_data, params, stan_data)
+}
+```
+
+# Compile the model
+```{r, include = FALSE}
+model_file_path <- here::here("cfaforecastrenewalww", "inst", "stan", "renewal_ww_hosp.stan")
+model <- compile_model(model_file_path)
+```
+# Fit the model
+```{r, include = FALSE}
+fit_dynamic_rt <- model$sample(
+ data = stan_data,
+ seed = 123,
+ init = init_fun,
+ iter_sampling = 500,
+ iter_warmup = 250,
+ chains = 4,
+ parallel_chains = 4
+)
+```
+
+# Look at the parameter draws (for static parameters)
+```{r}
+all_draws <- fit_dynamic_rt$draws()
+phi_h <- all_draws %>%
+ spread_draws(phi_h) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = `.draw`) %>%
+ mutate(
+ name = "phi_h",
+ ) %>%
+ rename(value = phi_h) %>%
+ select(name, value, draw)
+ggplot(phi_h) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Dispersion in observed error in hospitalizations")
+
+autoreg_p_hosp <- all_draws %>%
+ spread_draws(autoreg_p_hosp) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = `.draw`) %>%
+ mutate(
+ name = "autoreg_p_hosp",
+ ) %>%
+ rename(value = autoreg_p_hosp) %>%
+ select(name, value, draw)
+ggplot(autoreg_p_hosp) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Time-varying IHR autoregulation coefficient")
+
+infection_feedback <- all_draws %>%
+ spread_draws(infection_feedback) %>%
+ mutate(draw = `.draw`) %>%
+ mutate(
+ name = "infection_feedback",
+ ) %>%
+ rename(value = infection_feedback) %>%
+ select(name, value, draw)
+ggplot(infection_feedback) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("Estimated infection feedback")
+
+autoreg_rt <- all_draws %>%
+ spread_draws(autoreg_rt) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = `.draw`) %>%
+ mutate(
+ name = "autoreg_rt",
+ ) %>%
+ rename(value = autoreg_rt) %>%
+ select(name, value, draw)
+ggplot(autoreg_rt) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("autoregressive term on R(t)")
+
+eta_sd <- all_draws %>%
+ spread_draws(eta_sd) %>%
+ sample_draws(ndraws = 100) %>%
+ mutate(draw = `.draw`) %>%
+ mutate(
+ name = "eta_sd",
+ ) %>%
+ rename(value = eta_sd) %>%
+ select(name, value, draw)
+ggplot(eta_sd) +
+ aes(x = value) +
+ stat_halfeye() +
+ xlab("RW step size (eta)")
+```
+# Draws from generated quantities
+```{r}
+gen_quantities_draws <- get_generated_quantities_draws(all_draws, config_vars_ss$model_type)
+gen_quantities_draws_metadata <- gen_quantities_draws %>%
+ mutate(
+ forecast_date = config_vars_ss$forecast_date,
+ location = config_vars_ss$location,
+ include_ww = config_vars_ss$include_ww,
+ hosp_reporting_delay = config_vars_ss$hosp_reporting_delay
+ ) %>%
+ filter(name %in% c("pred_hosp", "pred_ww", "R(t)", "p_hosp")) %>%
+ left_join(date_df, by = "t") %>%
+ left_join(train_data %>% select(
+ t, date, ww, daily_hosp_admits,
+ daily_hosp_admits_for_eval, pop
+ ), by = c("date", "t"))
+hosp_per_100k <- gen_quantities_draws_metadata %>%
+ filter(name == "pred_hosp") %>% # grab the generated quantities
+ mutate(
+ name = "pred_hosp_per_100k",
+ value = 1e5 * value / pop
+ )
+gen_quants_draws <- rbind(gen_quantities_draws_metadata, hosp_per_100k)
+
+metadata_df <- get_metadata(train_data)
+last_hosp_data_date <- metadata_df$last_hosp_data_date
+
+# Make a column for matched observed data
+model_draws <- gen_quants_draws %>%
+ mutate(
+ obs_data = case_when(
+ name == "pred_hosp" ~ daily_hosp_admits_for_eval,
+ name == "pred_hosp_per_100k" ~ 1e5 * daily_hosp_admits_for_eval / pop,
+ name == "pred_ww" ~ ww,
+ TRUE ~ NA
+ ),
+ # Column for period
+ period = case_when(
+ date <= last_hosp_data_date ~ "calibration",
+ (date > last_hosp_data_date & date <= forecast_date) ~ "nowcast",
+ date > forecast_date ~ "forecast"
+ ),
+ model_type = config_vars_ss$model_type
+ )
+
+subsetted_model_draws <- model_draws %>%
+ ungroup() %>%
+ filter(draw %in% sample(1:max(draw), 100))
+ww_draws <- subsetted_model_draws %>% filter(name == "pred_ww")
+
+
+plot_hosp_draws <- get_plot_draws(subsetted_model_draws,
+ "pred_hosp",
+ figure_file_path = config_vars_ss$output_dir,
+ from_full_df = TRUE,
+ days_pre_forecast_date_plot = 365,
+ show_calibration_data = FALSE,
+ write_files = FALSE,
+ show_median = FALSE
+)
+plot_hosp_draws
+
+
+plot_ww_draws <- get_plot_draws(subsetted_model_draws,
+ "pred_ww",
+ figure_file_path = config_vars_ss$output_dir,
+ from_full_df = TRUE,
+ config_vars_ss$output_file_path,
+ days_pre_forecast_date_plot = 150,
+ write_files = FALSE
+)
+plot_ww_draws
+
+p_hosp_t <- gen_quantities_draws_metadata %>% filter(name == "p_hosp")
+ggplot(p_hosp_t) +
+ geom_line(aes(x = date, y = value, group = draw), size = 0.1, alpha = 0.1)
+```
+Plot R(t)
+```{r}
+rt_draws <- model_draws %>% filter(name == "R(t)")
+rt_quantiles <- rt_draws %>%
+ group_by(date, period) %>%
+ summarise(
+ R_t_median = quantile(value, 0.5),
+ R_t_25th = quantile(value, 0.25),
+ R_t_75th = quantile(value, 0.75),
+ R_t_975th = quantile(value, 0.975),
+ R_t_025th = quantile(value, 0.025)
+ )
+plot_color <- "blue3"
+last_ww_data_date <- max(
+ model_draws$date[!is.na(model_draws$ww) & model_draws$period != "forecast"]
+)
+last_hosp_data_date <- max(model_draws$date[!is.na(model_draws$daily_hosp_admits)])
+last_obs_data_date <- max(last_ww_data_date, last_hosp_data_date)
+
+plot <- ggplot(rt_quantiles) +
+ geom_line(aes(x = date, y = R_t_median, color = period)) +
+ geom_ribbon(aes(x = date, ymin = R_t_25th, ymax = R_t_75th, fill = period), alpha = 0.2) +
+ geom_ribbon(aes(x = date, ymin = R_t_025th, ymax = R_t_975th, fill = period), alpha = 0.4) +
+ geom_vline(aes(xintercept = last_obs_data_date), linetype = "dashed") +
+ geom_hline(aes(yintercept = 1), linetype = "dotted") +
+ xlab("") +
+ ylab("R(t)") +
+ theme_bw() +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ ) +
+ theme(
+ axis.text.x = element_text(
+ size = 10, vjust = 0.5,
+ hjust = 1, angle = 90
+ ),
+ axis.title.x = element_text(size = 10),
+ axis.title.y = element_text(size = 10),
+ plot.title = element_text(
+ size = 9,
+ vjust = 0.5, hjust = 0.5
+ )
+ ) +
+ # coord_cartesian(ylim = c(0.8, 1.3))+
+ ggtitle("R(t) estimates")
+plot
+sampled_draws <- sample(1:max(rt_draws$draw), 100)
+plot <- ggplot(rt_draws %>% filter(draw %in% sampled_draws)) +
+ geom_line(aes(x = date, y = value, group = draw, color = period), alpha = 0.3) +
+ geom_vline(aes(xintercept = last_obs_data_date), linetype = "dashed") +
+ geom_hline(aes(yintercept = 1), linetype = "dotted") +
+ xlab("") +
+ ylab("R(t)") +
+ theme_bw() +
+ scale_x_date(
+ date_breaks = "2 weeks",
+ labels = scales::date_format("%Y-%m-%d")
+ ) +
+ theme(
+ axis.text.x = element_text(
+ size = 10, vjust = 0.5,
+ hjust = 1, angle = 90
+ ),
+ axis.title.x = element_text(size = 10),
+ axis.title.y = element_text(size = 10),
+ plot.title = element_text(
+ size = 9,
+ vjust = 0.5, hjust = 0.5
+ )
+ ) +
+ scale_y_continuous(trans = "log") +
+ # coord_cartesian(ylim = c(0.8, 1.3))+
+ ggtitle("R(t) estimates")
+plot
+
+
+# Estimate of R(t) on forecast date
+rt <- rt_quantiles %>% filter(date == config_vars_ss$forecast_date)
+print(paste0(
+ "R(t) = ", round(rt$R_t_median, 3),
+ " [", round(rt$R_t_025th, 3), " , ", round(rt$R_t_975th, 3), "]"
+))
+```
diff --git a/scratch/test_epi_class.R b/scratch/test_epi_class.R
new file mode 100644
index 00000000..dc7e558c
--- /dev/null
+++ b/scratch/test_epi_class.R
@@ -0,0 +1,50 @@
+library(tidyverse)
+
+
+test <- tibble(
+ trend = c(
+ rep("increasing", 5), rep("uncertain", 5),
+ rep("decreasing", 5), "increasing", "uncertain", "decreasing",
+ "decreasing", "uncertain", "uncertain", "increasing", "increasing"
+ ),
+ date = c(
+ seq(
+ from = lubridate::ymd("2024-01-01"),
+ to = lubridate::ymd("2024-01-15"),
+ by = "days"
+ ),
+ seq(
+ from = lubridate::ymd("2024-01-01"),
+ to = lubridate::ymd("2024-01-08"),
+ by = "days"
+ )
+ ),
+ state_abbr = c(rep("MA", 15), rep("NJ", 8))
+) |>
+ dplyr::group_by(state_abbr) |>
+ dplyr::mutate(
+ groups_phase = data.table::rleid(trend)
+ ) |>
+ dplyr::ungroup()
+
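+# Collapse runs of consecutive days with the same trend into "phases"
+# (via rleid above), then reclassify each "uncertain" phase from its
+# neighbors: same trend on both sides -> that trend; decreasing followed by
+# increasing -> "nadir"; increasing followed by decreasing -> "peak".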
+summarized_by_trend <- test |>
+ dplyr::distinct(groups_phase, state_abbr, trend) |>
+ dplyr::arrange(state_abbr, groups_phase) |>
+ dplyr::group_by(state_abbr) |>
+ dplyr::mutate(
+ lag_phase = dplyr::lag(trend),
+ lead_phase = dplyr::lead(trend),
+ phase_reclass = dplyr::case_when(
+ trend == "uncertain" & lag_phase == lead_phase ~ lead_phase,
+ trend == "uncertain" & lag_phase == "decreasing" & lead_phase == "increasing" ~ "nadir",
+ trend == "uncertain" & lag_phase == "increasing" & lead_phase == "decreasing" ~ "peak",
+ TRUE ~ trend
+ )
+ ) |>
+ dplyr::ungroup() |>
+ dplyr::select(state_abbr, groups_phase, phase_reclass)
+
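+# Join the reclassified phase labels back onto the daily trend data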
+rt_cat <- test |> dplyr::left_join(
+ summarized_by_trend,
+ by = c("state_abbr", "groups_phase")
+)
diff --git a/scratch/toy_example_ind_variation.Rmd b/scratch/toy_example_ind_variation.Rmd
new file mode 100644
index 00000000..051fc3a1
--- /dev/null
+++ b/scratch/toy_example_ind_variation.Rmd
@@ -0,0 +1,683 @@
+---
+title: "Individual variation toy examples"
+author: "Kaitlyn Johnson"
+date: "2024-05-13"
+output: html_document
+---
+
+The purpose of this document is to set up a toy example of a model that tries
+to estimate the mean and coefficient of variation of an individual
+gamma-distributed random variable, using observations that are each the sum of
+a known (and varying) number of such individual components. This relies on the
+following:
+$$
+Y \sim \sum_{i=1}^{N} Gamma\left(\alpha = \frac{1}{cv^2}, \beta = \frac{1}{\mu cv^2}\right) = Gamma\left(\alpha = \frac{N}{cv^2}, \beta = \frac{1}{\mu cv^2}\right)
+$$
+We will first implement this using a linear scale implementation, then will try
+using log scale, and then a non-centered normal approximation adopted from
+[EpiSewer](https://github.com/adrian-lison/EpiSewer/blob/main/README.md)
+
+```{r setup, include=FALSE}
+knitr::opts_chunk$set(echo = TRUE)
+library(rstan)
+library(ggplot2)
+library(tidybayes)
+library(bayesplot)
+```
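+
+As a quick numerical check of the sum-of-gammas identity above (simulation
+only; the values are arbitrary):
+
+```{r}
+mu <- 1
+cv <- 2
+n <- 20
+# Sum of n iid Gamma(1/cv^2, 1/(mu*cv^2)) vs a single Gamma(n/cv^2, 1/(mu*cv^2))
+sum_of_gammas <- replicate(1e4, sum(rgamma(n, 1 / cv^2, 1 / (mu * cv^2))))
+single_gamma <- rgamma(1e4, n / cv^2, 1 / (mu * cv^2))
+c(mean(sum_of_gammas), mean(single_gamma)) # both ~ n * mu
+c(var(sum_of_gammas), var(single_gamma)) # both ~ n * mu^2 * cv^2
+```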
+
+
+Write a function to generate the data with a specified individual mean,
+individual coefficient of variation, population sizes of the individual
+components we're summing over, and the number of draws for each population
+size. We'll start with a simple example where the individual components have a
+mean of 1 and a coefficient of variation of 2, and we observe sums of those
+components over population sizes of 5, 20, and 100.
+```{r}
+get_data <- function(mu = 1,
+ cv = 2,
+ pop = c(5, 20, 100),
+ n_draws = 10,
+ prior_mu = c(0, 1),
+ prior_cv = c(0, 1),
+ sigma_obs = 0,
+ prior_sigma_obs = c(0, 1)) {
+ pop_vector <- rep(pop, each = n_draws)
+ y <- rep(0, n_draws * length(pop))
+ for (i in seq_along(pop_vector)) {
+ alpha <- pop_vector[i] / cv^2
+ beta <- 1 / (mu * (cv^2))
+ y[i] <- rgamma(1, alpha, beta)
+ # same as y[i] <- sum(rgamma(pop_vector[i], 1 / cv^2, 1 / (mu * (cv^2))))
+ if (sigma_obs > 0) {
+ y[i] <- rlnorm(1, log(y[i]), sigma_obs)
+ }
+ }
+
+ data <- list(
+ N = length(pop_vector),
+ y = y,
+ pop_vector = pop_vector,
+ prior_mu = prior_mu,
+ prior_cv = prior_cv,
+ prior_sigma_obs = prior_sigma_obs
+ )
+
+ return(data)
+}
+
+standata <- get_data()
+```
+# Linear scale implementation
+Next we will write a linear scale stan model to fit this data. We will use
+`rstan` just for this example because it interfaces well with Rmd and will
+give us the diagnostic outputs in the format we want.
+```{stan, output.var = "model_linear"}
+
+data {
+ int N;
+ vector[N] y;
+ vector[N] pop_vector;
+ real prior_mu[2];
+ real prior_cv[2];
+}
+
+// The parameters accepted by the model. Our model
+// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation
+// in the individual R.V.s
+parameters {
+ real mu; // mean of individual component
+ real cv; // coefficent of variation in individual data
+}
+
+
+// The model to be estimated.
+model {
+ // Assume we have a truncated half N prior with a standard deviation of 2 times the mean
+  // not sure if that's reasonable...
+ mu ~ normal(prior_mu[1], prior_mu[2]);
+ cv ~ normal(prior_cv[1], prior_cv[2]);
+
+ //Formula for the sum of N gammas: Y ~ gamma(N*alpha, beta)
+ y ~ gamma(pop_vector./cv^2, 1 / (mu * (cv^2)));
+
+}
+
+```
+Fit the model using `rstan`
+```{r}
+fit <- sampling(model_linear,
+ standata,
+ warmup = 500,
+ iter = 2000,
+ chains = 4,
+ cores = 4,
+ seed = 42,
+ init = 0,
+ control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+Analyze the outputs using `bayesplot` and `tidybayes`.
+In this example we will make pairs plots, trace plots, and also plots
+that compare the posterior estimate to the known true parameter.
+```{r}
+# Extract posterior draws for later use
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+ np = np, pars = c("mu", "cv"),
+ off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv"), np = np) +
+ xlab("Post-warmup iteration")
+
+
+params <- fit |>
+ tidybayes::spread_draws(mu, cv) |>
+ dplyr::mutate(
+ mu_true = 1,
+ cv_true = 2
+ )
+
+ggplot(params, aes(x = mu)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = mu_true)) +
+ theme_bw() +
+ xlab("Mean") +
+ ylab("")
+
+ggplot(params, aes(x = cv)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = cv_true)) +
+ theme_bw() +
+ xlab("Coefficient of Variation") +
+ ylab("")
+```
+
+
+Our estimate of the `cv` is a little off, but overall the convergence
+diagnostics look pretty good. Let's try to get closer to the wastewater
+problem by making the orders of magnitude more realistic: `mu` becomes the
+number of genomes shed per individual and `pop` the number of infected
+individuals in the population.
+```{r}
+cv_set <- 2
+mu_set <- exp(7)
+standata <- get_data(
+ mu = mu_set,
+ cv = cv_set,
+ prior_mu = c(1.2 * mu_set, 2 * mu_set), # make priors imperfect but on the same scale
+ prior_cv = c(0.9 * cv_set, 1),
+ n_draws = 10,
+ pop = c(10, 1000, 1e5)
+)
+df <- data.frame(y = standata$y, pop = rep(standata$pop_vector))
+# Quick plot of the data
+ggplot(df) +
+ geom_point(aes(x = pop, y = y)) +
+ scale_x_continuous(trans = "log") +
+ scale_y_continuous(trans = "log")
+
+
+fit <- sampling(model_linear,
+ standata,
+ warmup = 500,
+ iter = 2000,
+ chains = 4,
+ cores = 4,
+ seed = 42,
+ init = 0,
+ control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+
+Looking at the outputs from the fit of the model
+```{r}
+# Extract posterior draws for later use
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+ np = np, pars = c("mu", "cv"),
+ off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv"), np = np) +
+ xlab("Post-warmup iteration")
+
+
+params <- fit |>
+ tidybayes::spread_draws(mu, cv) |>
+ dplyr::mutate(
+ mu_true = mu_set,
+ cv_true = cv_set
+ )
+
+ggplot(params, aes(x = mu)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = mu_true)) +
+ theme_bw() +
+ xlab("Mean") +
+ ylab("")
+
+ggplot(params, aes(x = cv)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = cv_true)) +
+ theme_bw() +
+ xlab("Coefficient of Variation") +
+ ylab("")
+```
+
+These look pretty good! Because of the risk of overflow, though, we probably
+still want to estimate things on the log scale: in practice this quantity will
+be the total number of genomes shed in the sewershed, which could be quite
+large.
+
+# Log scale implementation using `expgamma` function
+Let's see whether we can improve the estimates by log-transforming the data
+and rescaling to estimate the mean and the relative variation in shedding.
+
+```{r}
+# First write a function to get the data in log scale
+get_log_scale_data <- function(log_mu = 1,
+ cv = 2,
+ pop = c(5, 20, 100),
+ n_draws = 10,
+ log_prior_mu = c(1, 2),
+ prior_cv = c(0, 1)) {
+ pop_vector <- rep(pop, each = n_draws)
+ log_y <- rep(0, n_draws * length(pop))
+ mu <- exp(log_mu)
+ for (i in seq_along(pop_vector)) {
+ alpha <- pop_vector[i] / cv^2
+ beta <- 1 / (mu * (cv^2))
+ log_y[i] <- log(rgamma(1, alpha, beta))
+ }
+
+ data <- list(
+ N = length(pop_vector),
+ log_y = log_y,
+ pop_vector = pop_vector,
+ log_prior_mu = log_prior_mu,
+ prior_cv = prior_cv
+ )
+
+ return(data)
+}
+```
+Now generate the log scaled data
+```{r}
+mu_set <- exp(7)
+cv_set <- 2
+standata <- get_log_scale_data(
+ log_mu = log(mu_set),
+ cv = cv_set,
+ log_prior_mu = c(0.5 * log(mu_set), 2 * log(mu_set)),
+ prior_cv = c(0, cv_set),
+ pop = c(10, 1000, 1e5)
+)
+df <- data.frame(log_y = standata$log_y, pop = rep(standata$pop_vector))
+# Quick plot of the data
+ggplot(df) +
+ geom_point(aes(x = pop, y = log_y)) +
+ scale_x_continuous(trans = "log")
+```
+
+Modify the stan model to fit the log-scale data and to separately estimate a
+relative shedding variation and the mean on the log scale.
+```{stan, output.var = "log_model"}
+functions{
+real expgamma_lpdf(vector xs,
+ vector shapes_k,
+ vector scales_theta){
+
+ return(sum(
+ -shapes_k .* log(scales_theta) -
+ lgamma(shapes_k) +
+ shapes_k .* xs - (exp(xs) ./ scales_theta)));
+}
+real expgamma_lpdf(vector x, real shape_k, real scale_theta){
+
+ return(sum(
+ -shape_k * log(scale_theta) -
+ lgamma(shape_k) +
+ shape_k * x - (exp(x) / scale_theta)));
+}
+
+real expgamma_lpdf(real x, real shape_k, real scale_theta){
+
+ return(
+ -shape_k * log(scale_theta) -
+ lgamma(shape_k) +
+ shape_k * x - (exp(x) / scale_theta));
+}
+
+}
+
+
+data {
+ int N;
+ vector[N] log_y;
+ vector[N] pop_vector;
+ real log_prior_mu[2];
+ real prior_cv[2];
+}
+
+
+// The parameters accepted by the model. Our model
+// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation
+// in the individual R.V.s
+parameters {
+ real log_mu; // mean of individual component
+ real cv; // coefficent of variation in individual data
+}
+
+transformed parameters{
+ real mu = exp(log_mu ); // We can transform the mean before passing to the gamma
+}
+
+
+// The model to be estimated.
+model {
+ log_mu ~ normal(log_prior_mu[1], log_prior_mu[2]); // lognormal prior
+ cv ~ normal(prior_cv[1], prior_cv[2]);
+
+ //Formula for the sum of N gammas: Y ~ gamma(N*alpha, beta)
+  for (i in 1:N) {
+    log_y[i] ~ expgamma(pop_vector[i] / cv^2, mu * (cv^2));
+  }
+
+
+}
+
+```
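+
+As a sanity check on the exp-gamma density used in `expgamma_lpdf` above: if
+$Y \sim Gamma(k, \theta)$ (shape-scale), then $X = \log Y$ has density
+$f(x) = \exp(kx - e^x/\theta)/(\Gamma(k)\theta^k)$. A quick R comparison
+against simulated draws (arbitrary $k$ and $\theta$):
+
+```{r}
+k <- 2.5
+theta <- 3
+x <- log(rgamma(1e5, shape = k, scale = theta))
+# Density of log(Y) via change of variables, matching expgamma_lpdf
+dens <- function(x) exp(k * x - exp(x) / theta) / (gamma(k) * theta^k)
+hist(x, breaks = 60, freq = FALSE, main = "", xlab = "log(Y)")
+curve(dens(x), add = TRUE, col = "red")
+```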
+
+```{r}
+fit <- sampling(log_model,
+ standata,
+ warmup = 500,
+ iter = 2000,
+ chains = 4,
+ cores = 4,
+ seed = 42,
+ init = 0,
+ control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+Look at the parameter estimates
+```{r}
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+ np = np, pars = c("mu", "cv"),
+ off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv"), np = np) +
+ xlab("Post-warmup iteration")
+
+
+params <- fit |>
+ tidybayes::spread_draws(mu, cv) |>
+ dplyr::mutate(
+ mu_true = mu_set,
+ cv_true = cv_set
+ )
+
+ggplot(params, aes(x = mu)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = mu_true)) +
+ theme_bw() +
+ xlab("Mean") +
+ ylab("")
+
+ggplot(params, aes(x = cv)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = cv_true)) +
+ theme_bw() +
+ xlab("Coefficient of Variation") +
+ ylab("")
+```
+
+These look ok but not great.
+
+# Non-centered parameterization
+This relies on an approximation to the sum of iid gammas as implemented in
+EpiSewer.
+
+The non-centered parameterization re-expresses the model and likelihood as:
+$$
+\begin{aligned}
+\zeta &\sim Normal(0,1) \\
+Y &= N\mu + \zeta \sqrt{N}\mu \, cv
+\end{aligned}
+$$
+Where $y$ are the observations of the summed distributions over the population,
+$N$ is the population size (so analogous to number of infections in the
+wastewater model), $\mu$ is the individual mean and $cv$ is the individual
+coefficient of variation.
+
+With observation noise, we can write this as:
+$$
+\begin{aligned}
+\zeta &\sim Normal(0,1) \\
+\overline{y} &= N\mu + \zeta \sqrt{N}\mu \, cv \\
+\log Y &\sim Normal(\log \overline{y}, \sigma)
+\end{aligned}
+$$
+Where $\overline{y}$ is the expected value of the observations and $\sigma$ is the observation model noise.
+We'll start with the first implementation without observation noise. To compute the log likelihood,
+we have to account for a [change of variables](https://mc-stan.org/docs/stan-users-guide/reparameterization.html#changes-of-variables) to solve for $\zeta$
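+
+Before writing the stan model, a quick simulation check (with arbitrary
+values) that the normal approximation to the gamma sum is reasonable; the
+approximation improves as $N$ grows:
+
+```{r}
+n_pop <- 1000
+mu <- 5
+cv <- 0.5
+gamma_sums <- replicate(1e4, sum(rgamma(n_pop, 1 / cv^2, 1 / (mu * cv^2))))
+# Normal approximation: mean N*mu, sd sqrt(N)*mu*cv
+normal_approx <- rnorm(1e4, n_pop * mu, sqrt(n_pop) * mu * cv)
+qqplot(gamma_sums, normal_approx)
+abline(0, 1, col = "red")
+```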
+
+```{stan, output.var = "ncp_model_no_noise"}
+
+functions {
+
+// Non-centered parameterization of a normal approximation for the
+// sum of N i.i.d. Gamma distributed RVs with mean 1 and a specified cv
+vector gamma_sum_approx(real cv, vector N, vector noise_noncentered) {
+ // sqrt(N) * cv is the standard deviation of the sum of Gamma distributions
+ return N + noise_noncentered .* sqrt(N) * cv;
+}
+
+
+}
+
+data {
+ int N;
+ vector[N] y;
+ vector[N] pop_vector;
+ real prior_mu[2];
+ real prior_cv[2];
+}
+
+// The parameters accepted by the model. Our model
+// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation
+// in the individual R.V.s
+parameters {
+ real mu; // mean of individual component
+ real cv; // coefficent of variation in individual data
+}
+
+
+// The model to be estimated.
+model {
+  // Assume we have a truncated half N prior with a standard deviation of 2 times the mean
+  // not sure if that's reasonable...
+  mu ~ normal(prior_mu[1], prior_mu[2]);
+  cv ~ normal(prior_cv[1], prior_cv[2]);
+  // Non-centered likelihood: the standardized residual of y is std normal...
+  (y - mu * pop_vector) ./ (mu * sqrt(pop_vector) * cv) ~ std_normal();
+  // ...plus the Jacobian (change-of-variables) adjustment for zeta(y)
+  target += -sum(log(mu * sqrt(pop_vector) * cv));
+
+}
+
+```
+
+The diagnostic warning from the parser about a non-linear transform on the
+left-hand side of a sampling statement can be ignored: we did indeed adjust
+for the change of variables.
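+
+For reference, the Jacobian adjustment follows from
+$\zeta = (y - N\mu) / (\mu\sqrt{N}cv)$, so
+$$
+\left|\frac{d\zeta}{dy}\right| = \frac{1}{\mu\sqrt{N}cv}
+\quad \Rightarrow \quad
+\log|J| = -\log\left(\mu\sqrt{N}cv\right),
+$$
+which is exactly the `target += -sum(log(mu * sqrt(pop_vector) * cv))` term in
+the model block.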
+
+```{r}
+cv_set <- 0.1
+mu_set <- 5
+standata <- get_data(
+ mu = mu_set,
+ cv = cv_set,
+ prior_mu = c(0.5 * mu_set, 2 * mu_set), # make priors imperfect but on the same scale
+ prior_cv = c(0, 2),
+ n_draws = 10,
+ pop = c(10, 1000, 1e5)
+)
+df <- data.frame(y = standata$y, pop = rep(standata$pop_vector))
+# Quick plot of the data
+ggplot(df) +
+ geom_point(aes(x = pop, y = y)) +
+ scale_x_continuous(trans = "log") +
+ scale_y_continuous(trans = "log")
+
+
+fit <- sampling(ncp_model_no_noise,
+ standata,
+ warmup = 500,
+ iter = 2000,
+ chains = 4,
+ cores = 4,
+ seed = 42,
+ init = 0,
+ control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+
+```{r}
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+ np = np, pars = c("mu", "cv"),
+ off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv"), np = np) +
+ xlab("Post-warmup iteration")
+
+
+params <- fit |>
+ tidybayes::spread_draws(mu, cv) |>
+ dplyr::mutate(
+ mu_true = mu_set,
+ cv_true = cv_set
+ )
+
+ggplot(params, aes(x = mu)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = mu_true)) +
+ theme_bw() +
+ xlab("Mean") +
+ ylab("")
+
+ggplot(params, aes(x = cv)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = cv_true)) +
+ theme_bw() +
+ xlab("Coefficient of Variation") +
+ ylab("")
+```
+
+This looks good! We definitely needed more iterations to get sufficient ESS.
+
+
+## With observation noise
+Now let's add (lognormal) observation noise. Here we can again use the transformed-parameters approach for the non-centered parameterization.
+
+```{stan, output.var = "ncp_model_noise"}
+
+functions {
+
+// Non-centered parameterization of a normal approximation for the
+// sum of N i.i.d. Gamma distributed RVs with mean 1 and a specified cv
+vector gamma_sum_approx(real cv, vector N, vector noise_noncentered) {
+ // sqrt(N) * cv is the standard deviation of the sum of Gamma distributions
+ return N + noise_noncentered .* sqrt(N) * cv;
+}
+
+
+}
+
+data {
+ int N;
+ vector[N] y;
+ vector[N] pop_vector;
+ real prior_mu[2];
+ real prior_cv[2];
+ real prior_sigma_obs[2];
+}
+
+transformed data{
+ vector[N] log_y = log(y);
+}
+
+// The parameters accepted by the model. Our model
+// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation
+// in the individual R.V.s
+parameters {
+ real mu; // mean of individual component
+ real cv; // coefficent of variation in individual data
+ vector[N] zeta_raw;
+  real<lower=0> sigma_obs; // standard deviation of the lognormal observation noise
+}
+
+transformed parameters{
+ vector[N] exp_obs = mu * gamma_sum_approx(cv, pop_vector, zeta_raw);
+}
+
+// The model to be estimated.
+model {
+ mu ~ normal(prior_mu[1], prior_mu[2]);
+ cv ~ normal(prior_cv[1], prior_cv[2]);
+ zeta_raw ~ normal(0,1);
+ sigma_obs ~ normal(prior_sigma_obs[1], prior_sigma_obs[2]);
+
+  // Likelihood: log(y) ~ Normal(log(mu*N + mu*sqrt(N)*cv*zeta), sigma_obs)
+ //y ~ lognormal(log(exp_obs), sigma_obs); // alternate way of writing it
+ log_y ~ normal(log(exp_obs), sigma_obs);
+
+
+}
+```
+
+```{r}
+mu_set <- exp(7)
+cv_set <- 0.6
+sigma_obs_set <- 0.4
+standata <- get_data(
+ mu = mu_set,
+ pop = c(1e1, 1e3, 1e6),
+  n_draws = 30, # make sure we have enough draws per pop size
+ cv = cv_set,
+ sigma_obs = sigma_obs_set,
+ prior_mu = c(0.5 * mu_set, 2 * mu_set),
+ prior_cv = c(0.9 * cv_set, 1),
+ prior_sigma_obs = c(0, 2)
+)
+df <- data.frame(y = standata$y, pop = rep(standata$pop_vector))
+# Quick plot of the data
+ggplot(df) +
+ geom_point(aes(x = pop, y = y)) +
+ scale_x_continuous(trans = "log") +
+ scale_y_continuous(trans = "log")
+
+
+fit <- sampling(ncp_model_noise,
+ standata,
+ warmup = 500,
+ iter = 2000,
+ chains = 4,
+ cores = 4,
+ seed = 42,
+ init = 0,
+ control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+```{r}
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+ np = np, pars = c("mu", "cv", "sigma_obs"),
+ off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv", "sigma_obs"), np = np) +
+ xlab("Post-warmup iteration")
+
+
+params <- fit |>
+ tidybayes::spread_draws(mu, cv, sigma_obs) |>
+ dplyr::mutate(
+ mu_true = mu_set,
+ cv_true = cv_set,
+ sigma_obs_true = sigma_obs_set
+ )
+
+ggplot(params, aes(x = mu)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = mu_true)) +
+ theme_bw() +
+ xlab("Mean") +
+ ylab("")
+
+ggplot(params, aes(x = cv)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = cv_true)) +
+ theme_bw() +
+ xlab("Coefficient of Variation") +
+ ylab("")
+
+ggplot(params, aes(x = sigma_obs)) +
+ stat_halfeye() +
+ geom_vline(aes(xintercept = sigma_obs_true)) +
+ theme_bw() +
+ xlab("Observation noise") +
+ ylab("")
+```
+
+This is also looking good.
diff --git a/scratch/ww_data_quick_figs.R b/scratch/ww_data_quick_figs.R
new file mode 100644
index 00000000..ba652cc3
--- /dev/null
+++ b/scratch/ww_data_quick_figs.R
@@ -0,0 +1,59 @@
+library(ggplot2)
+
+nwss_ny <- readr::read_csv(file.path("input", "ww_data", "nwss_data", "2024-02-14.csv")) |>
+ dplyr::filter(wwtp_jurisdiction == "ny")
+ny_old <- readr::read_csv(file.path("input", "ww_data", "nwss_data", "2024-02-11.csv")) |>
+ dplyr::filter(wwtp_jurisdiction == "ny")
+
+ggplot(nwss_ny) +
+ geom_line(
+ aes(
+ x = sample_collect_date, y = pcr_target_avg_conc,
+ color = pcr_target_units
+ ),
+ size = 0.1
+ ) +
+ facet_wrap(~pcr_target_units) +
+  ggtitle("2024-02-14 dataset") +
+ scale_y_continuous(trans = "log")
+
+ggplot(ny_old) +
+ geom_line(
+ aes(
+ x = sample_collect_date, y = pcr_target_avg_conc,
+ color = pcr_target_units
+ ),
+ size = 0.1
+ ) +
+ facet_wrap(~pcr_target_units) +
+ ggtitle("2024-02-11 dataset") +
+ scale_y_continuous(trans = "log")
+
+
+# Explore GA data
+nwss_ga <- readr::read_csv(file.path("input", "ww_data", "nwss_data", "2024-02-24.csv")) |>
+ dplyr::filter(wwtp_jurisdiction == "ga")
+
+ggplot(nwss_ga) +
+ geom_line(
+ aes(
+ x = sample_collect_date, y = pcr_target_avg_conc,
+ color = pcr_target_units
+ ),
+ size = 0.1
+ ) +
+ facet_wrap(~pcr_target_units) +
+  ggtitle("2024-02-24 dataset") +
+ scale_y_continuous(trans = "log")
diff --git a/setup_container.R b/setup_container.R
new file mode 100644
index 00000000..b6b989e0
--- /dev/null
+++ b/setup_container.R
@@ -0,0 +1,37 @@
+options(
+ ## make HTTP requests
+ ## identify us correctly
+ ## to Posit package manager
+ ## so we get appropriate
+ ## precompiled binaries
+ ## see https://docs.posit.co/rspm/1.0.12/admin/binaries.html#binaries-r-configuration
+ HTTPUserAgent = sprintf(
+ "R/%s R (%s)",
+ getRversion(),
+ paste(
+ getRversion(),
+ R.version["platform"], R.version["arch"],
+ R.version["os"]
+ )
+ ),
+ ## use Posit package manager to get
+ ## precompiled binaries where possible
+ repos = c(
+ RSPM = "https://packagemanager.posit.co/cran/__linux__/jammy/latest"
+ )
+)
+
+additional_deps <- c(
+ "argparser"
+)
+
+install.packages("pak")
+pak::pkg_install("local::wweval")
+pak::pkg_install(additional_deps)
+cmdstanr::install_cmdstan()
+dir.create("stanmodels")
+cfaforecastrenewalww::compile_model(
+ "cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan",
+ "cfaforecastrenewalww/inst/stan",
+ "stanmodels"
+)
diff --git a/src/setup_eval.R b/src/setup_eval.R
index a37deb0d..f810631a 100644
--- a/src/setup_eval.R
+++ b/src/setup_eval.R
@@ -14,13 +14,14 @@ write_eval_config(
"PA", "PR", "RI", "SC", "SD", "TN", "TX", "UT", "VA",
"VT", "WA", "WI", "WV", "WY"
),
- forecast_dates = as.character(
- seq(
- from = lubridate::ymd("2023-10-16"),
- to = lubridate::ymd("2024-03-11"),
- by = "week"
- )
- ),
+ forecast_dates =
+ as.character(
+ seq(
+ from = lubridate::ymd("2023-10-16"),
+ to = lubridate::ymd("2024-03-11"),
+ by = "week"
+ )
+ ),
scenarios = c(
"status_quo"
),
diff --git a/src/write_eval_config.R b/src/write_eval_config.R
index 715506b2..d2235d03 100644
--- a/src/write_eval_config.R
+++ b/src/write_eval_config.R
@@ -69,6 +69,7 @@ write_eval_config <- function(locations, forecast_dates,
output_dir <- file.path("output", "eval")
figure_dir <- file.path("output", "eval", "plots")
hub_subdir <- file.path("output", "eval", "hub")
+ retro_rt_path <- file.path("input", "retro_Rt", "Rt_draws.parquet")
score_subdir <- file.path("output", "eval", "hub")
hub_model_names <- c(
"COVIDhub-4_week_ensemble", "UMass-trends_ensemble",
@@ -127,6 +128,7 @@ write_eval_config <- function(locations, forecast_dates,
baseline_score_table_dir = baseline_score_table_dir,
output_dir = output_dir,
hub_subdir = hub_subdir,
+    retro_rt_path = retro_rt_path,
score_subdir = score_subdir,
raw_output_dir = raw_output_dir,
figure_dir = figure_dir,
diff --git a/wweval/NAMESPACE b/wweval/NAMESPACE
index 6645aa3b..04762920 100644
--- a/wweval/NAMESPACE
+++ b/wweval/NAMESPACE
@@ -13,7 +13,9 @@ export(eval_post_process_ww)
export(exclude_hosp_outliers)
export(format_for_hub)
export(get_box_plot)
+export(get_convergence_df)
export(get_diagnostic_flags)
+export(get_epidemic_phases_from_rt)
export(get_filepath)
export(get_full_scores)
export(get_heatmap_relative_wis)
@@ -42,6 +44,7 @@ export(get_stan_data_list)
export(get_state_level_quantiles)
export(get_state_level_ww_quantiles)
export(get_subpop_data)
+export(get_table_sufficient_ww)
export(get_ww_data_indices)
export(get_ww_data_sizes)
export(get_ww_values)
diff --git a/wweval/R/combine_outputs.R b/wweval/R/combine_outputs.R
index f8462046..f899bc50 100644
--- a/wweval/R/combine_outputs.R
+++ b/wweval/R/combine_outputs.R
@@ -68,5 +68,9 @@ combine_outputs <- function(output_type =
)
}
}
+
+ if (nrow(combined_output) == 0) {
+ combined_output <- NULL
+ }
return(combined_output)
}
diff --git a/wweval/R/eval_post_process.R b/wweval/R/eval_post_process.R
index 21070f1a..b25938d7 100644
--- a/wweval/R/eval_post_process.R
+++ b/wweval/R/eval_post_process.R
@@ -56,8 +56,8 @@ eval_post_process_ww <- function(config_index,
save_object("ww_summary", output_file_suffix)
errors <- ww_fit_obj$error
save_object("errors", output_file_suffix)
- flags <- ww_fit_obj$flags
- save_object("flags", output_file_suffix)
+ raw_flags <- data.frame(ww_fit_obj$flags)
+ save_object("raw_flags", output_file_suffix)
# Save errors
save_table(
data_to_save = errors,
@@ -68,6 +68,14 @@ eval_post_process_ww <- function(config_index,
model_type = "ww",
location = location
)
+  # Add scenario, forecast date, model type, and location metadata
+  # to the flags
+ flags <- raw_flags |> dplyr::mutate(
+ scenario = scenario,
+ forecast_date = forecast_date,
+ model_type = "ww",
+ location = location
+ )
# Save flags
save_table(
data_to_save = flags,
@@ -78,8 +86,7 @@ eval_post_process_ww <- function(config_index,
model_type = "ww",
location = location
)
- # Get evaluation data from hospital admissions and wastewater
- # Join draws with data
+
hosp_draws <- {
if (is.null(ww_raw_draws)) {
NULL
@@ -291,8 +298,8 @@ eval_post_process_hosp <- function(config_index,
save_object("hosp_summary", output_file_suffix)
errors <- hosp_fit_obj$error
save_object("errors", output_file_suffix)
- flags <- hosp_fit_obj$flags
- save_object("flags", output_file_suffix)
+ raw_flags <- data.frame(hosp_fit_obj$flags)
+ save_object("raw_flags", output_file_suffix)
# Save errors
save_table(
data_to_save = errors,
@@ -303,6 +310,16 @@ eval_post_process_hosp <- function(config_index,
model_type = "hosp",
location = location
)
+
+
+  # Add scenario, forecast date, model type, and location metadata
+  # to the flags
+ flags <- raw_flags |> dplyr::mutate(
+ scenario = scenario,
+ forecast_date = forecast_date,
+ model_type = "hosp",
+ location = location
+ )
# Save flags
save_table(
data_to_save = flags,
@@ -313,10 +330,6 @@ eval_post_process_hosp <- function(config_index,
model_type = "hosp",
location = location
)
-
-
- # Get evaluation data from hospital admissions and wastewater
- # Join draws with data
hosp_model_hosp_draws <- get_model_draws_w_data(
model_output = "hosp",
model_type = "hosp",
diff --git a/wweval/R/get_epidemic_phases_from_rt.R b/wweval/R/get_epidemic_phases_from_rt.R
new file mode 100644
index 00000000..ea1fd3ef
--- /dev/null
+++ b/wweval/R/get_epidemic_phases_from_rt.R
@@ -0,0 +1,84 @@
+#' Get epidemic phases from R(t)
+#' @description
+#' This function loads in the posterior estimate of the retrospective R(t)
+#' from NNH. Then it categorizes each week into an epidemic phase based on the
+#' algorithm used in https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011200 #nolint
+#' in the S5 appendix. Code available here: https://github.com/cdcepi/Evaluation-of-case-forecasts-submitted-to-COVID19-Forecast-Hub/blob/b3c379dfd48e8c673f67996014f151ce44cbd8fa/Code/Supplement%205_Rt_Epi%20Phases.R #nolint
+#'
+#'
+#' @param locations A vector of the state abbreviations
+#' @param retro_rt_path A path to the parquet file of retrospective R(t) estimates
+#' @param prob_threshold A numeric between 0 and 1 giving the threshold
+#' probability that R(t) > 1 (or that R(t) < 1) above which a week is
+#' classified as increasing (or decreasing). Default is `0.9` from the
+#' above paper.
+#'
+#' @return df_epi_phase A dataframe containing the epidemic phase, expanded
+#' to daily resolution, for each location, based on the retrospective R(t)
+#' estimate.
+#' @export
+#'
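+#' @examples
+#' \dontrun{
+#' # Illustrative sketch only: the locations and parquet path are
+#' # placeholders and must match an existing retrospective R(t) file
+#' phases <- get_epidemic_phases_from_rt(
+#'   locations = c("NY", "GA"),
+#'   retro_rt_path = file.path("input", "retro_Rt", "Rt_draws.parquet"),
+#'   prob_threshold = 0.9
+#' )
+#' }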
+get_epidemic_phases_from_rt <- function(locations,
+ retro_rt_path,
+ prob_threshold = 0.9) {
+ retro_rt <- arrow::read_parquet(retro_rt_path) |>
+ dplyr::filter(state_abbr %in% locations) |>
+ dplyr::mutate(
+ week_start_date = cut(reference_date, "week")
+ )
+
+ rt_initial_cat <- retro_rt |>
+ dplyr::group_by(state_abbr, week_start_date) |>
+ dplyr::summarise(
+ prob_rt_greater_than_1 = mean(R > 1)
+ ) |>
+ dplyr::arrange(state_abbr, week_start_date) |>
+ dplyr::mutate(
+ trend = dplyr::case_when(
+ prob_rt_greater_than_1 > prob_threshold ~ "increasing",
+ prob_rt_greater_than_1 < 1 - prob_threshold ~ "decreasing",
+ TRUE ~ "uncertain"
+ )
+ ) |>
+ dplyr::group_by(state_abbr) |>
+ dplyr::mutate(
+ groups_phase = data.table::rleid(trend)
+ ) |>
+ dplyr::ungroup()
+
+ summarized_by_trend <- rt_initial_cat |>
+ dplyr::distinct(groups_phase, state_abbr, trend) |>
+ dplyr::arrange(state_abbr, groups_phase) |>
+ dplyr::group_by(state_abbr) |>
+ dplyr::mutate(
+ lag_phase = dplyr::lag(trend),
+ lead_phase = dplyr::lead(trend),
+ phase = dplyr::case_when(
+ trend == "uncertain" & lag_phase == lead_phase ~ lead_phase, # nolint
+ trend == "uncertain" & lag_phase == "decreasing" & lead_phase == "increasing" ~ "nadir", # nolint
+ trend == "uncertain" & lag_phase == "increasing" & lead_phase == "decreasing" ~ "peak", # nolint
+ TRUE ~ trend
+ )
+ ) |>
+ dplyr::ungroup() |>
+ dplyr::select(state_abbr, groups_phase, phase)
+
+ rt_cat <- rt_initial_cat |>
+ dplyr::left_join(
+ summarized_by_trend,
+ by = c("state_abbr", "groups_phase")
+ ) |>
+ dplyr::select(state_abbr, phase, week_start_date)
+
+
+
+ # Expand to daily and save only the necessary columns
+ df_epi_phase <- retro_rt |>
+ dplyr::distinct(reference_date, state_abbr, week_start_date) |>
+ dplyr::left_join(
+ rt_cat,
+ by = c("state_abbr", "week_start_date")
+ ) |>
+ dplyr::select(state_abbr, reference_date, phase)
+
+
+
+ return(df_epi_phase)
+}
diff --git a/wweval/R/get_table_sufficient_ww.R b/wweval/R/get_table_sufficient_ww.R
new file mode 100644
index 00000000..980113a4
--- /dev/null
+++ b/wweval/R/get_table_sufficient_ww.R
@@ -0,0 +1,80 @@
+#' Get table of location-forecast dates with sufficient wastewater
+#'
+#' @description
+#' This function takes in a large dataframe containing the quantiled estimates
+#' and forecasts of the wastewater concentrations for each site and lab in
+#' each location and forecast date, joined with the observed data on each day,
+#' in each site and lab, from each forecast date. We use these data to build
+#' a table of location-forecast-dates where the wastewater data were
+#' considered sufficient to inform a forecast, based on the criteria used
+#' for the Hub submissions.
+#'
+#' @param ww_quantiles The dataframe of the quantiled ww predictions with the
+#' real data alongside it (labeled as calib data)
+#' @param delay_thres The maximum number of days of delay between the last
+#' wastewater data point and the forecast date before we would flag a state
+#' as having insufficient wastewater data to inform a forecast. Default is 21
+#' @param n_dps_thres The threshold number of data points within a single site
+#' within a state before we would flag the state as having insufficient
+#' wastewater data to inform a forecast. Default is 5
+#' @param prop_below_lod_thres The threshold proportion of wastewater data
+#' points that can be below the LOD. If greater than this proportion of points
+#' are below the LOD, we flag the state as having insufficient wastewater data.
+#' Default is 0.5
+#' @param sd_thres The minimum standard deviation between wastewater data points
+#' within a site. This is intended to catch when a site reports all the same
+#' values. Default is 0.1
+#' @param mean_log_ww_value_thres The minimum value of the log of the ww
+#' concentration, default is -4
+#'
+#' @return table_of_loc_dates_w_ww A tibble containing the location,
+#' forecast_date, and a column indicating whether the wastewater data were
+#' deemed sufficient to inform a forecast for that location and forecast date
+#' @export
+#'
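+#' @examples
+#' # Toy input (illustrative values): one site-lab in one location, with
+#' # recent, above-LOD, non-constant data, so the result is sufficient
+#' toy <- data.frame(
+#'   location = "XX", site = 1, lab = 1, lab_site_index = 1,
+#'   forecast_date = as.Date("2024-02-05"),
+#'   date = as.Date("2024-01-20") + 0:9,
+#'   below_LOD = 0,
+#'   calib_data = exp(rnorm(10, mean = 2, sd = 0.5))
+#' )
+#' get_table_sufficient_ww(toy)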
+get_table_sufficient_ww <- function(ww_quantiles,
+ delay_thres = 21,
+ n_dps_thres = 5,
+ prop_below_lod_thres = 0.5,
+ sd_thres = 0.1,
+ mean_log_ww_value_thres = -4) {
+ calib_data <- ww_quantiles |>
+ dplyr::distinct(
+ location, site, lab, lab_site_index,
+ date, forecast_date, below_LOD, calib_data
+ ) |>
+ dplyr::filter(!is.na(calib_data))
+
+ diagnostic_table <- calib_data |>
+ dplyr::group_by(location, forecast_date) |>
+ dplyr::summarize(
+ last_date = max(date),
+ n_dps = dplyr::n(),
+ prop_below_lod = sum(below_LOD == 1) / dplyr::n(),
+ sd = sd(calib_data),
+ mean_log_ww = mean(log(calib_data))
+ ) |>
+ dplyr::mutate(
+ flag_delay = as.integer(forecast_date - last_date) > delay_thres,
+ flag_n_dps = n_dps < n_dps_thres,
+ flag_lod = prop_below_lod > prop_below_lod_thres,
+ flag_sd = sd < sd_thres,
+ flag_low_val = mean_log_ww < mean_log_ww_value_thres
+ )
+
+ flag_table_long <- diagnostic_table |>
+ dplyr::ungroup() |>
+ tidyr::pivot_longer(starts_with("flag"))
+
+ # Ensure all `values` are boolean
+ stopifnot(
+    "In diagnostic table checking for sufficient wastewater data flags, not all values are boolean" =
+ is.logical(flag_table_long$value)
+ )
+
+ table_of_loc_dates_w_ww <- flag_table_long |>
+ dplyr::group_by(location, forecast_date) |>
+ dplyr::summarise(ww_sufficient = !any(value))
+
+ return(table_of_loc_dates_w_ww)
+}
diff --git a/wweval/R/model_run_diagnostic_flags.R b/wweval/R/model_run_diagnostic_flags.R
index c05c09f3..d7cfa0fe 100644
--- a/wweval/R/model_run_diagnostic_flags.R
+++ b/wweval/R/model_run_diagnostic_flags.R
@@ -51,3 +51,36 @@ get_diagnostic_flags <- function(stan_fit_object,
)
return(flag_df)
}
+
+
+#' Get convergence dataframe
+#' @description This function takes the larger dataframe of convergence
+#' flags for each location, forecast date, and scenario, checks whether any
+#' of the flags are TRUE, and returns a dataframe with a single column
+#' indicating whether any flags are TRUE
+#'
+#'
+#' @param all_flags a dataframe containing the flags for each location,
+#' forecast_date, and scenario
+#' @param scenario The scenario to filter to, since some eval output will include multiple
+#' scenarios
+#'
+#' @return a dataframe with a column `any_flags` indicating whether any of the
+#' flags in the original full set of convergence flags are TRUE.
+#' @export
+#'
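+#' @examples
+#' # Toy flags table (the flag column names here are illustrative)
+#' all_flags <- data.frame(
+#'   location = "XX", forecast_date = as.Date("2024-01-01"),
+#'   scenario = "status_quo", model_type = "ww",
+#'   flag_high_rhat = FALSE, flag_low_ess = TRUE
+#' )
+#' get_convergence_df(all_flags, scenario = "status_quo")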
+get_convergence_df <- function(all_flags,
+ scenario) {
+ convergence_df <- all_flags |>
+ dplyr::filter(scenario == {{ scenario }}) |>
+ tidyr::gather(key, value, starts_with("flag")) |>
+ dplyr::group_by(location, forecast_date, scenario, model_type) |>
+ dplyr::mutate(any_flags = any(value == TRUE)) |>
+ tidyr::spread(key, value) |>
+ dplyr::ungroup() |>
+ dplyr::select(
+ location, forecast_date, any_flags
+ )
+
+ return(convergence_df)
+}
diff --git a/wweval/R/sample_model.R b/wweval/R/sample_model.R
index d038ac52..34c11afb 100644
--- a/wweval/R/sample_model.R
+++ b/wweval/R/sample_model.R
@@ -107,7 +107,6 @@ sample_model <- function(standata,
} else {
# Get the diagnostics using thresholds set in production pipeline
flag_df <- get_diagnostic_flags(fit$result, n_chains, iter_sampling)
- any_flags <- any(flag_df)
draws <- fit$result$draws()
diagnostics <- fit$result$sampler_diagnostics(format = "df")
@@ -118,14 +117,9 @@ sample_model <- function(standata,
draws = draws,
diagnostics = diagnostics,
summary_diagnostics = summary_diagnostics,
- summary = summary
+ summary = summary,
+ flags = list(flag_df)
)
-
- if (any_flags) { # If there are model convergence issues, pass
- # flags alongside the draws and summaries
- out <- c(out, flags = list(flag_df))
- message("Model convergence issues")
- }
}
return(out)
}
diff --git a/wweval/man/get_convergence_df.Rd b/wweval/man/get_convergence_df.Rd
new file mode 100644
index 00000000..4efe80b9
--- /dev/null
+++ b/wweval/man/get_convergence_df.Rd
@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/model_run_diagnostic_flags.R
+\name{get_convergence_df}
+\alias{get_convergence_df}
+\title{Get convergence dataframe}
+\usage{
+get_convergence_df(all_flags, scenario)
+}
+\arguments{
+\item{all_flags}{a dataframe containing the flags for each location,
+forecast_date, and scenario}
+
+\item{scenario}{The scenario to filter to, since some eval output will include multiple
+scenarios}
+}
+\value{
+a dataframe with a column \code{any_flags} indicating whether any of the
+flags in the original full set of convergence flags are TRUE.
+}
+\description{
+This function takes the larger dataframe of convergence
+flags for each location, forecast date, and scenario, checks whether any
+of the flags are TRUE, and returns a dataframe with a single column
+indicating whether any flags are TRUE
+}
diff --git a/wweval/man/get_epidemic_phases_from_rt.Rd b/wweval/man/get_epidemic_phases_from_rt.Rd
new file mode 100644
index 00000000..02f17f44
--- /dev/null
+++ b/wweval/man/get_epidemic_phases_from_rt.Rd
@@ -0,0 +1,26 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/get_epidemic_phases_from_rt.R
+\name{get_epidemic_phases_from_rt}
+\alias{get_epidemic_phases_from_rt}
+\title{Get epidemic phases from R(t)}
+\usage{
+get_epidemic_phases_from_rt(locations, retro_rt_path, prob_threshold = 0.9)
+}
+\arguments{
+\item{locations}{A vector of the state abbreviations}
+
+\item{retro_rt_path}{A path to the parquet file of retrospective R(t) estimates}
+
+\item{prob_threshold}{A numeric between 0 and 1 giving the threshold
+probability that R(t) > 1 (or that R(t) < 1) above which a week is
+classified as increasing (or decreasing). Default is \code{0.9} from the
+above paper.}
+}
+\value{
+df_epi_phase A dataframe containing the epidemic phase, expanded
+to daily resolution, for each location, based on the retrospective R(t)
+estimate.
+}
+\description{
+This function loads in the posterior estimate of the retrospective R(t)
+from NNH. Then it categorizes each week into an epidemic phase based on the
+algorithm used in https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011200 #nolint
+in the S5 appendix. Code available here: https://github.com/cdcepi/Evaluation-of-case-forecasts-submitted-to-COVID19-Forecast-Hub/blob/b3c379dfd48e8c673f67996014f151ce44cbd8fa/Code/Supplement\%205_Rt_Epi\%20Phases.R #nolint
+}
diff --git a/wweval/man/get_table_sufficient_ww.Rd b/wweval/man/get_table_sufficient_ww.Rd
new file mode 100644
index 00000000..59bbd57a
--- /dev/null
+++ b/wweval/man/get_table_sufficient_ww.Rd
@@ -0,0 +1,53 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/get_table_sufficient_ww.R
+\name{get_table_sufficient_ww}
+\alias{get_table_sufficient_ww}
+\title{Get table of location-forecast dates with sufficient wastewater}
+\usage{
+get_table_sufficient_ww(
+ ww_quantiles,
+ delay_thres = 21,
+ n_dps_thres = 5,
+ prop_below_lod_thres = 0.5,
+ sd_thres = 0.1,
+ mean_log_ww_value_thres = -4
+)
+}
+\arguments{
+\item{ww_quantiles}{The dataframe of the quantiled ww predictions with the
+real data alongside it (labeled as calib data)}
+
+\item{delay_thres}{The maximum number of days of delay between the last
+wastewater data point and the forecast date before we would flag a state
+as having insufficient wastewater data to inform a forecast. Default is 21}
+
+\item{n_dps_thres}{The threshold number of data points within a single site
+within a state before we would flag the state as having insufficient
+wastewater data to inform a forecast. Default is 5}
+
+\item{prop_below_lod_thres}{The threshold proportion of wastewater data
+points that can be below the LOD. If greater than this proportion of points
+are below the LOD, we flag the state as having insufficient wastewater data.
+Default is 0.5}
+
+\item{sd_thres}{The minimum standard deviation between wastewater data points
+within a site. This is intended to catch when a site reports all the same
+values. Default is 0.1}
+
+\item{mean_log_ww_value_thres}{The minimum value of the log of the ww
+concentration, default is -4}
+}
+\value{
+table_of_loc_dates_w_ww A tibble containing the location,
+forecast_date, and a column indicating whether the wastewater data were
+deemed sufficient to inform a forecast for that location and forecast date
+}
+\description{
+This function takes in a large dataframe containing the quantiled estimates
+and forecasts of the wastewater concentrations for each site and lab in
+each location and forecast date, joined with the observed data on each day,
+in each site and lab, from each forecast date. We use these data to build
+a table of location-forecast-dates where the wastewater data were
+considered sufficient to inform a forecast, based on the criteria used
+for the Hub submissions.
+}