From 081d4451729616e334402a23819afaf8b387486c Mon Sep 17 00:00:00 2001 From: kaitejohnson Date: Mon, 24 Jun 2024 12:17:34 -0400 Subject: [PATCH 1/2] WIP exp gamma implementation --- .pre-commit-config.yaml | 21 ++- _targets_eval.R | 6 +- _targets_eval_postprocessing.R | 165 ++++++++++++++---- cfaforecastrenewalww/R/get_data.R | 5 + ...newal_ww_hosp_site_level_inf_dynamics.stan | 117 +++++++++++-- model_definition.md | 4 +- setup_container.R | 37 ++++ src/setup_eval.R | 15 +- wweval/NAMESPACE | 2 + wweval/R/combine_outputs.R | 4 + wweval/R/eval_post_process.R | 33 ++-- wweval/R/get_table_sufficient_ww.R | 80 +++++++++ wweval/R/model_run_diagnostic_flags.R | 33 ++++ wweval/R/sample_model.R | 10 +- wweval/man/get_convergence_df.Rd | 25 +++ wweval/man/get_table_sufficient_ww.Rd | 53 ++++++ 16 files changed, 540 insertions(+), 70 deletions(-) create mode 100644 setup_container.R create mode 100644 wweval/R/get_table_sufficient_ww.R create mode 100644 wweval/man/get_convergence_df.Rd create mode 100644 wweval/man/get_table_sufficient_ww.Rd diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 073ffca6..d80f1e46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,26 @@ repos: hooks: - id: style-files - id: lintr -##### + +######## +# Python +- repo: https://github.com/psf/black + rev: 23.10.0 + hooks: + - id: black + args: ['--line-length', '79'] +- repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ['--profile', 'black', + '--line-length', '79'] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.0 + hooks: + - id: ruff + +######### # Secrets - repo: https://github.com/Yelp/detect-secrets rev: v1.4.0 diff --git a/_targets_eval.R b/_targets_eval.R index 9d2ce92f..79489a31 100644 --- a/_targets_eval.R +++ b/_targets_eval.R @@ -869,7 +869,8 @@ downstream_targets <- list( name = plot_summarized_scores_w_data, command = get_plot_scores_w_data( grouped_submission_scores, - eval_hosp_data + eval_hosp_data, + eval_config$figure_dir ), pattern = map(grouped_submission_scores), iteration = "list", @@ -900,7 +901,8 @@ downstream_targets <- list( name = plot_quantile_comparison, command = get_plot_quantile_comparison( all_hosp_quantiles, - eval_hosp_data + eval_hosp_data, + eval_config$figure_dir ), pattern = map(all_hosp_quantiles), iteration = "list" diff --git a/_targets_eval_postprocessing.R b/_targets_eval_postprocessing.R index c9e31f52..a23fda31 100644 --- a/_targets_eval_postprocessing.R +++ b/_targets_eval_postprocessing.R @@ -141,14 +141,15 @@ combined_targets <- list( ## Flags------------------------------------------------------------------ tar_target( name = all_flags_ww, - command = combine_outputs( - output_type = "flags", - scenarios = eval_config$scenario, - forecast_dates = eval_config$forecast_dates, - locations = eval_config$location_ww, - eval_output_subdir = eval_config$output_dir, - model_type = "ww" - ) + command = + combine_outputs( + output_type = "flags", + scenarios = eval_config$scenario, + forecast_dates = eval_config$forecast_date_ww, + locations = eval_config$location_ww, + eval_output_subdir = eval_config$output_dir, + model_type = "ww" + ) ), tar_target( name = all_flags_hosp, @@ -161,6 +162,25 @@ combined_targets <- list( model_type = "hosp" ) ), + tar_target( + name = all_flags, + command = dplyr::bind_rows(all_flags_ww, all_flags_hosp) + ), + tar_target( + name = convergence_df_ww, + command = get_convergence_df( + all_flags_ww, + default_scenario = "status_quo" + ) |> + 
dplyr::rename(any_flags_ww = any_flags) + ), + tar_target( + name = convergence_df_hosp, + command = get_convergence_df(all_flags_hosp, + default_scenario = "no_wastewater" + ) |> + dplyr::rename(any_flags_hosp = any_flags) + ), ### Scores from quantiles------------------------------------------------- tar_target( name = all_ww_scores_quantiles, @@ -169,7 +189,7 @@ combined_targets <- list( scenarios = eval_config$scenario, forecast_dates = eval_config$forecast_date_ww, locations = eval_config$location_ww, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "ww" ) ), @@ -180,7 +200,7 @@ combined_targets <- list( scenarios = "no_wastewater", forecast_dates = eval_config$forecast_date_hosp, locations = eval_config$location_hosp, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "hosp" ) ), @@ -192,7 +212,7 @@ combined_targets <- list( scenarios = eval_config$scenario, forecast_dates = eval_config$forecast_date_ww, locations = eval_config$location_ww, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "ww" ) ), @@ -203,7 +223,7 @@ combined_targets <- list( scenarios = "no_wastewater", forecast_dates = eval_config$forecast_date_hosp, locations = eval_config$location_hosp, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "hosp" ) ), @@ -214,7 +234,7 @@ combined_targets <- list( scenarios = eval_config$scenario, forecast_dates = eval_config$forecast_date_ww, locations = eval_config$location_ww, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "ww" ) ), @@ -226,7 +246,7 @@ combined_targets <- list( scenarios = eval_config$scenario, forecast_dates = eval_config$forecast_date_ww, locations = eval_config$location_ww, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "ww" ) ), @@ -237,15 +257,110 @@ combined_targets <- list( scenarios = "no_wastewater", forecast_dates = eval_config$forecast_date_hosp, locations = eval_config$location_hosp, - eval_output_subdir = file.path("output", "eval"), + eval_output_subdir = eval_config$output_dir, model_type = "hosp" ) ) ) +# Head-to-head comparison targets------------------------------------------- +# This set of targets will be conditioned on the presence of sufficient +# wastewater, whereas the below targets assume that for every location and +# forecast date we had to submit a forecast, and so we used the hospital +# admissions only model if wastewater was missing. +# These are only relevant for the status quo scenario +head_to_head_targets <- list( + tar_target( + name = all_ww_quantiles_sq, + command = all_ww_quantiles |> + dplyr::filter(scenario == "status_quo") + ), + # Get a table of locations and forecast dates with sufficient wastewater + tar_target( + name = table_of_loc_dates_w_ww, + command = get_table_sufficient_ww(all_ww_quantiles) + ), + # Get a table indicating whether there are locations and forecast dates with + # convergence issues + tar_target( + name = convergence_df, + command = convergence_df_hosp |> + dplyr::left_join(convergence_df_ww, + by = c("location", "forecast_date") + ) + ), + # Get the full set of quantiles, filtered down to only states and + # forecast dates with sufficient wastewater for both ww model and hosp only + # model. 
Then join the convergence df + tar_target( + name = hosp_quantiles_filtered, + command = dplyr::bind_rows( + all_ww_hosp_quantiles, + all_hosp_model_quantiles + ) |> + dplyr::left_join(table_of_loc_dates_w_ww, + by = c("location", "forecast_date") + ) |> + dplyr::filter( + ww_sufficient # filters to location forecast dates with sufficient ww + ) |> + dplyr::left_join( + convergence_df, + by = c( + "location", + "forecast_date" + ) + ) + ), + # Do the same thing for the sampled scores, combining ww and hosp under + # the status quo scenario, filtering to the locations and forecast dates + # with sufficient wastewater, and then joining the convergence flags + tar_target( + name = scores_filtered, + command = dplyr::bind_rows( + all_hosp_scores, + all_ww_scores |> + dplyr::filter(scenario == "status_quo") + ) |> + dplyr::left_join(table_of_loc_dates_w_ww, + by = c("location", "forecast_date") + ) |> + dplyr::filter(ww_sufficient) |> + dplyr::left_join( + convergence_df, + by = c( + "location", + "forecast_date" + ) + ) + ), + # Repeat for the quantile-based scores + tar_target( + name = scores_quantiles_filtered, + command = dplyr::bind_rows( + all_hosp_scores_quantiles, + all_ww_scores_quantiles |> + dplyr::filter(scenario == "status_quo") + ) |> + dplyr::left_join(table_of_loc_dates_w_ww, + by = c("location", "forecast_date") + ) |> + dplyr::filter( + isTRUE(ww_sufficient) + ) |> + dplyr::left_join( + convergence_df, + by = c( + "location", + "forecast_date" + ) + ) + ) +) + -# Head-to-head-scenario targets------------------------------------------------ -head_to_head_scenario_targets <- list( +# Scenario targets------------------------------------------------ +scenario_targets <- list( tar_target( name = all_raw_scores, command = data.table::as.data.table( @@ -262,18 +377,7 @@ head_to_head_scenario_targets <- list( name = all_errors, command = dplyr::bind_rows(all_hosp_errors, all_ww_errors) ), - tar_target( - name = all_flags, - command = dplyr::bind_rows(all_flags_ww, all_flags_hosp) - ), - tar_target( - name = nonconverge_df, - command = all_flags |> distinct( - location, forecast_date, model, scenario - ) |> - dplyr::mutate(convergence = FALSE) # Add a column that indicates did not pass - # convergence. 
We will use this in the head-to-head comparison to statify by convergence - ), + ## Raw scores----------------------------------------- # These are the scores from each scenario and location without buffering @@ -653,7 +757,8 @@ hub_comparison_plots <- list( list( upstream_targets, combined_targets, - head_to_head_scenario_targets, + head_to_head_targets, + scenario_targets, hub_targets, hub_comparison_plots ) diff --git a/cfaforecastrenewalww/R/get_data.R b/cfaforecastrenewalww/R/get_data.R index cd257a8b..29d64b86 100644 --- a/cfaforecastrenewalww/R/get_data.R +++ b/cfaforecastrenewalww/R/get_data.R @@ -751,6 +751,11 @@ site_level_inf_inits <- function(train_data, params, stan_data) { autoreg_rt = abs(rnorm(1, autoreg_rt_a / (autoreg_rt_a + autoreg_rt_b), 0.05)), log_r_mu_intercept = rnorm(1, convert_to_logmean(1, 0.1), convert_to_logsd(1, 0.1)), error_site = matrix(rnorm(n_subpops * n_weeks, mean = 0, sd = 0.1), n_subpops, n_weeks), + zeta_bar = abs(matrix( + rnorm(n_subpops * (uot + ot + ht), mean = 0, sd = 0.1), + n_subpops, (uot + ot + ht) + )), + cv = abs(rnorm(1, 0.025, 0.005)), autoreg_rt_site = abs(rnorm(1, 0.5, 0.05)), autoreg_p_hosp = abs(rnorm(1, 1 / 100, 0.001)), sigma_rt = abs(rnorm(1, 0, 0.01)), diff --git a/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan b/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan index e07ef35a..3cf42492 100644 --- a/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan +++ b/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan @@ -6,6 +6,50 @@ functions { #include functions/infections.stan #include functions/observation_model.stan #include functions/utils.stan +#include functions/expgamma_lpdf.stan + +real gamma3_lpdf(vector y, vector mean, real cv) { + int n = num_elements(y); + real alpha = 1 / (cv^2); + vector[n] beta = 1 / (mean * (cv^2)); + return gamma_lpdf(y | alpha, beta); +} + +real gamma3_sum_lpdf(row_vector y, real mean, real cv, vector N) { + int n = num_elements(y); + vector[n] alpha = N / (cv^2); // sum of gammas with same shape and scale + real beta = 1 / (mean * (cv^2)); + return gamma_lpdf(y | alpha, beta); +} + +/** + * Efficient dot product on log scale + */ + real log_dot_product(vector x, vector y) { + return(log_sum_exp(x + y)); + } + + /** +* Convolution of a time series for T time steps on log scale +** +* @param f The weight function, e.g. the incubation period distribution +* +* @param g The time series to be convolved, e.g. infections +* +* @return The convolved log time series. The first length(f)-1 elements are NA +* because the convolved values can only be computed starting from length(f). 
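+*
+* Worked example (hypothetical values): with f = log([0.5, 0.5]') and
+* g = log([1, 2, 4]'), the result has fg[2] = log(0.5 * 1 + 0.5 * 2) and
+* fg[3] = log(0.5 * 2 + 0.5 * 4); fg[1] is never assigned.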
+*/ +vector log_convolve(vector f, vector g) { + int f_length = num_elements(f); + int g_length = num_elements(g); + vector[g_length] fg; + for (t in f_length:g_length) { + fg[t] = log_dot_product(f, g[(t-f_length+1):t]); + } + return(fg); +} + + } @@ -118,6 +162,8 @@ parameters { real autoreg_rt_site; real autoreg_p_hosp; matrix[n_subpops, n_weeks] error_site; // matrix of subpopulations + matrix[n_subpops, uot + ot + ht] zeta_bar; // total number of genomes shed for all incident infections at time t + real cv; // coefficient of variation in individual dispersion real i0_over_n; // initial per capita // infection incidence vector[n_subpops] eta_i0; // z-score on logit scale of state @@ -156,23 +202,35 @@ transformed parameters { vector[owt] exp_obs_log_v_true = rep_vector(0, owt); // expected observations at each site in log scale vector[owt] exp_obs_log_v = rep_vector(0, owt); // expected observations at each site with modifier in log scale vector[n_ww_lab_sites] ww_site_mod; // site specific WW mod - row_vector [ot + uot + ht] model_net_i; // number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) + // row_vector [ot + uot + ht] model_net_i; // number of net infected individuals shedding on each day (sum of individuals in dift stages of infection) real phi_h = inv_square(inv_sqrt_phi_h); vector[n_ww_lab_sites] sigma_ww_site; vector[n_weeks] log_r_mu_t_in_weeks; // log of state level mean R(t) in weeks vector[n_weeks] log_r_site_t_in_weeks; // log of site level mean R(t) in weeks, used as a placeholder in loop vector[ot + ht] unadj_r; // state level R(t) before damping matrix[n_subpops, ot+ht] r_site_t; // site_level R(t) + matrix[n_subpops, uot+ot+ht] total_g; // total number of genomes shed in each site at each time + matrix[n_subpops, uot+ot+ht] log_total_g; row_vector[ot + ht] unadj_r_site_t; // site_level R(t) before damping row_vector[ot + uot + ht] new_i_site; // site level incident infections per capita + //matrix[n_subpops, uot + ot + ht] shape_g_bar; // the shape pararameter for the sum of gamma distributed ind genomes + //matrix[n_subpops, uot + ot + ht] scale_g_bar; // the theta pararameter for the sum of gamma distributed ind genomes real pop_fraction; // proportion of state population that the subpopulation represents vector[ot + uot + ht] state_inf_per_capita = rep_vector(0, uot + ot + ht); // state level incident infections per capita matrix[n_subpops, ot + ht] model_log_v_ot; // expected observed viral genomes/mL at all observed and forecasted times - real g = pow(log10_g, 10); // Estimated genomes shed per infected individual + matrix[n_subpops, uot + ot + ht] i_site_t; // number of new infections at each time point in each subpopulation + real mu_g = exp(log(10)*log10_g); // Estimated genomes shed per infected individual + //print("Mean number of genomes per infection:", mu_g); real i0 = i0_over_n * state_pop; // Initial absolute infection incidence vector[n_subpops] i0_site_over_n; // site-level initial // per capita infection incidence vector[n_subpops] growth_site; + // Start with basically no individual variation + //g_i ~ Gamma(kappa = exp(log_g), theta = 1) roughly + // eventually we will want to change this to add variance so maybe + // g_i ~ Gamma(kappa = exp(2), theta = exp(log_g - 2)) + //mean = exp(27), variance = exp(2+25*2) + // hard coding the coefficient of variation for now... 
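+
+  // Note on the mean/CV gamma parameterization used for shedding (a sketch,
+  // not new code): a Gamma variable with mean m and coefficient of variation
+  // cv has shape 1 / cv^2 and scale m * cv^2 (equivalently rate
+  // 1 / (m * cv^2)), and a sum of N iid copies keeps the same scale with
+  // shape N / cv^2. With per-infection relative shedding intensities of
+  // mean 1, that gives the shape i_site_t / cv^2 and second argument cv^2
+  // passed to expgamma in the model block (assuming expgamma's second
+  // argument is the scale).
+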
// State-leve R(t) AR + RW implementation: @@ -221,13 +279,38 @@ transformed parameters { pop_fraction = subpop_size[i] / norm_pop; state_inf_per_capita += pop_fraction * to_vector(new_i_site); - model_net_i = to_row_vector(convolve_dot_product(to_vector(new_i_site), - reverse(s), (uot + ot + ht))); - - - model_log_v_ot[i] = log(10) * log10_g + - log(model_net_i[(uot+1):(uot + ot + ht) ] + 1e-8) - - log(mwpd); + // Sum of iid gammas over all infections + // g_bar_i ~ Gamma(I(t)*kappa, theta) + //G_i = \sum_{t=0}^{t=tau} s(tau)g_bar_i(t-tau) Convolving shedding kinetics + i_site_t[i] = new_i_site.*subpop_size[i]; + //print("mu_g*i_site_t: ", mu_g*i_site_t[1, (uot+2)]); + //print("new infections in first site at 2nd time point: ", i_site_t[1, uot + 2]); + // with RV that is sum of gammas representing number of genomes at each time + // point. Doing this manually to start! + // This is actually sum of iids of relative shedding intensities + total_g[i] = to_row_vector(convolve_dot_product(to_vector(exp(zeta_bar[i])), + reverse(s), (uot + ot + ht))); + log_total_g[i] = log(total_g[i]); + //print("Log(total_g): ", log10_g*log(10)+ log(total_g[i, uot +2])); + //log_total_g[i] = to_row_vector( + // log_convolve( + // reverse(log(s)), + // log10_g*log(10) + to_vector(zeta_bar[i]) + // ) + //); + //print("log_total_g: ", log_total_g[i,uot+2]); + //print("Output after convolution of number of genomes from day of peak shedding:", total_g[1, uot +7]); + // log(C_i(t)) = log(G_i(t)/(alpha*N_i)) + model_log_v_ot[i] = log_total_g[i, (uot+1):(uot + ot + ht)] - log(mwpd) - log(subpop_size[i]); + + + // model_net_i = to_row_vector(convolve_dot_product(to_vector(new_i_site), + // reverse(s), (uot + ot + ht))); + // + // + // model_log_v_ot[i] = log(10) * log10_g + + // log(model_net_i[(uot+1):(uot + ot + ht) ] + 1e-8) - + // log(mwpd); } @@ -282,6 +365,7 @@ model { autoreg_p_hosp ~ beta(autoreg_p_hosp_a, autoreg_p_hosp_b); log_r_mu_intercept ~ normal(r_logmean, r_logsd); to_vector(error_site) ~ std_normal(); + cv ~ normal(0.1, 0.025); sigma_rt ~ normal(0, sigma_rt_prior); i0_over_n ~ beta(i0_over_n_prior_a, i0_over_n_prior_b); @@ -292,8 +376,8 @@ model { eta_growth ~ std_normal(); initial_growth ~ normal(initial_growth_prior_mean, initial_growth_prior_sd); inv_sqrt_phi_h ~ normal(inv_sqrt_phi_prior_mean, inv_sqrt_phi_prior_sd); - sigma_ww_site_mean ~ normal(sigma_ww_site_prior_mean_mean, sigma_ww_site_prior_mean_sd); - sigma_ww_site_sd ~ normal(sigma_ww_site_prior_sd_mean, sigma_ww_site_prior_sd_sd); + sigma_ww_site_mean ~ normal(0.1*sigma_ww_site_prior_mean_mean, 0.5*sigma_ww_site_prior_mean_sd); + sigma_ww_site_sd ~ normal(sigma_ww_site_prior_sd_mean, 0.1*sigma_ww_site_prior_sd_sd); sigma_ww_site_raw ~ std_normal(); log10_g ~ normal(log10_g_prior_mean, log10_g_prior_sd); hosp_wday_effect ~ normal(effect_mean, wday_effect_prior_sd); @@ -310,6 +394,16 @@ model { //Compute log likelihood if (compute_likelihood == 1) { if (include_ww == 1) { + // for (i in 1:n_subpops){ + // g_bar[i] ~ gamma(1/(0.025^2), 1/(mu_g*i_site_t[i]*(0.025^2))); // + // } + for (i in 1:n_subpops){ + for (j in 1:(uot + ot + ht)){ + zeta_bar[i,j] ~ expgamma(i_site_t[i,j]/(cv^2), cv^2); // Sum of iid gamma individual relative shedding intensities + } + + } + //print("Output of gamma of number of genomes from first day:", g_bar[1, (uot+2)]); // Both genomes/person/day and observation error are now vectors //log_conc ~ normal(exp_obs_log_v, sigma_ww_site[ww_sampled_lab_sites]); // if non-censored: P(log_conc | expected log_conc) @@ -335,6 
+429,7 @@ generated quantities { vector[uot + ot + ht] state_model_net_i; vector [n_subpops] site_i0_over_n_start; vector[ot + ht] rt; // state level R(t) + real log_genomes_shed_per_inf = log(gamma_rng( 1/(cv^2), 1/(mu_g*(cv^2)))); for(i in 1:n_subpops) { site_i0_over_n_start[i] = i0_site_over_n[i] * diff --git a/model_definition.md b/model_definition.md index e2f04e3a..f1e8b178 100644 --- a/model_definition.md +++ b/model_definition.md @@ -119,7 +119,9 @@ where $\gamma$ is the _infection feedback term_ controlling the strength of the Following other semi-mechanistic renewal frameworks, we model the _expected_ hospital admissions per capita $H(t)$ as a convolution of the _expected_ latent incident infections per capita $I(t)$, and a discrete infection to hospitalization distribution $d(\tau)$, scaled by the probability of being hospitalized $p_\mathrm{hosp}(t)$. -To account for day-of-week effects in hospital reporting, we use an estimated _weekday effect_ $\omega(t)$. If $t$ and $t'$ are the same day of the week, $\omega(t) = \omega(t')$. The seven values that $\omega(t)$ takes on are constrained to have mean 1. +To account for day-of-week effects in hospital reporting, we use an estimated _weekday effect_ $\omega(t)$. If $t$ and $t'$ are the same day of the week, $\omega(t) = \omega(t')$. +The seven values that $\omega(t)$ takes on are constrained to be non-negative and have a mean of 1. +This allows us to model the possibility that certain days of the week could have systematically high or low admissions reporting while holding the predicted weekly total reported admissions constant (i.e. the predicted weekly total is the same with and without these day-of-week reporting effects). $$H(t) = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau)$$ diff --git a/setup_container.R b/setup_container.R new file mode 100644 index 00000000..b6b989e0 --- /dev/null +++ b/setup_container.R @@ -0,0 +1,37 @@ +options( + ## make HTTP requests + ## identify us correctly + ## to Posit package manager + ## so we get appropriate + ## precompiled binaries + ## see https://docs.posit.co/rspm/1.0.12/admin/binaries.html#binaries-r-configuration + HTTPUserAgent = sprintf( + "R/%s R (%s)", + getRversion(), + paste( + getRversion(), + R.version["platform"], R.version["arch"], + R.version["os"] + ) + ), + ## use Posit package manager to get + ## precompiled binaries where possible + repos = c( + RSPM = "https://packagemanager.posit.co/cran/__linux__/jammy/latest" + ) +) + +additional_deps <- c( + "argparser" +) + +install.packages("pak") +pak::pkg_install("local::wweval") +pak::pkg_install(additional_deps) +cmdstanr::install_cmdstan() +dir.create("stanmodels") +cfaforecastrenewalww::compile_model( + "cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan", + "cfaforecastrenewalww/inst/stan", + "stanmodels" +) diff --git a/src/setup_eval.R b/src/setup_eval.R index a37deb0d..f810631a 100644 --- a/src/setup_eval.R +++ b/src/setup_eval.R @@ -14,13 +14,14 @@ write_eval_config( "PA", "PR", "RI", "SC", "SD", "TN", "TX", "UT", "VA", "VT", "WA", "WI", "WV", "WY" ), - forecast_dates = as.character( - seq( - from = lubridate::ymd("2023-10-16"), - to = lubridate::ymd("2024-03-11"), - by = "week" - ) - ), + forecast_dates = + as.character( + seq( + from = lubridate::ymd("2023-10-16"), + to = lubridate::ymd("2024-03-11"), + by = "week" + ) + ), scenarios = c( "status_quo" ), diff --git a/wweval/NAMESPACE b/wweval/NAMESPACE index 6645aa3b..4362c268 100644 --- a/wweval/NAMESPACE +++ 
b/wweval/NAMESPACE @@ -13,6 +13,7 @@ export(eval_post_process_ww) export(exclude_hosp_outliers) export(format_for_hub) export(get_box_plot) +export(get_convergence_df) export(get_diagnostic_flags) export(get_filepath) export(get_full_scores) @@ -42,6 +43,7 @@ export(get_stan_data_list) export(get_state_level_quantiles) export(get_state_level_ww_quantiles) export(get_subpop_data) +export(get_table_sufficient_ww) export(get_ww_data_indices) export(get_ww_data_sizes) export(get_ww_values) diff --git a/wweval/R/combine_outputs.R b/wweval/R/combine_outputs.R index f8462046..f899bc50 100644 --- a/wweval/R/combine_outputs.R +++ b/wweval/R/combine_outputs.R @@ -68,5 +68,9 @@ combine_outputs <- function(output_type = ) } } + + if (nrow(combined_output) == 0) { + combined_output <- NULL + } return(combined_output) } diff --git a/wweval/R/eval_post_process.R b/wweval/R/eval_post_process.R index 21070f1a..b25938d7 100644 --- a/wweval/R/eval_post_process.R +++ b/wweval/R/eval_post_process.R @@ -56,8 +56,8 @@ eval_post_process_ww <- function(config_index, save_object("ww_summary", output_file_suffix) errors <- ww_fit_obj$error save_object("errors", output_file_suffix) - flags <- ww_fit_obj$flags - save_object("flags", output_file_suffix) + raw_flags <- data.frame(ww_fit_obj$flags) + save_object("raw_flags", output_file_suffix) # Save errors save_table( data_to_save = errors, @@ -68,6 +68,14 @@ eval_post_process_ww <- function(config_index, model_type = "ww", location = location ) + # Get evaluation data from hospital admissions and wastewater + # Join draws and flags with data and metadata + flags <- raw_flags |> dplyr::mutate( + scenario = scenario, + forecast_date = forecast_date, + model_type = "ww", + location = location + ) # Save flags save_table( data_to_save = flags, @@ -78,8 +86,7 @@ eval_post_process_ww <- function(config_index, model_type = "ww", location = location ) - # Get evaluation data from hospital admissions and wastewater - # Join draws with data + hosp_draws <- { if (is.null(ww_raw_draws)) { NULL @@ -291,8 +298,8 @@ eval_post_process_hosp <- function(config_index, save_object("hosp_summary", output_file_suffix) errors <- hosp_fit_obj$error save_object("errors", output_file_suffix) - flags <- hosp_fit_obj$flags - save_object("flags", output_file_suffix) + raw_flags <- data.frame(hosp_fit_obj$flags) + save_object("raw_flags", output_file_suffix) # Save errors save_table( data_to_save = errors, @@ -303,6 +310,16 @@ eval_post_process_hosp <- function(config_index, model_type = "hosp", location = location ) + + + # Get evaluation data from hospital admissions and wastewater + # Join draws with flags + data and metadata + flags <- raw_flags |> dplyr::mutate( + scenario = scenario, + forecast_date = forecast_date, + model_type = "hosp", + location = location + ) # Save flags save_table( data_to_save = flags, @@ -313,10 +330,6 @@ eval_post_process_hosp <- function(config_index, model_type = "hosp", location = location ) - - - # Get evaluation data from hospital admissions and wastewater - # Join draws with data hosp_model_hosp_draws <- get_model_draws_w_data( model_output = "hosp", model_type = "hosp", diff --git a/wweval/R/get_table_sufficient_ww.R b/wweval/R/get_table_sufficient_ww.R new file mode 100644 index 00000000..980113a4 --- /dev/null +++ b/wweval/R/get_table_sufficient_ww.R @@ -0,0 +1,80 @@ +#' Get table of location-forecast dates with sufficient wastewater +#' +#' @description +#' This function takes in a large dataframe containing the quantiled estimated +#' and forecasted 
wastewater concentrations for each site and lab in each +#' location and forecast date, joined with the observed data on each day, in +#' each site and lab, from each forecast date. We will use the data to +#' get a table of location-forecast-dates where wastewater data was considered +#' sufficient to inform a forecast, based on the critera that was used for +#' the Hub submissions. +#' +#' @param ww_quantiles The dataframe of the quantiled ww predictions with the +#' real data alongside it (labeled as calib data) +#' @param delay_thres The maximum number of days of delay between the last +#' wastewater data point and the forecat date, before we would flag a state as +#' having insufficient wastewater data to inform a forecast. Default is 21 +#' @param n_dps_thres The threshold number of data points within a single site +#' within a state before we would flag the state as having insufficient +#' wastewater data to inform a forecast. Default is 5 +#' @param prop_below_lod_thres The threshold proportion of wastewater data +#' points that can be below the LOD. If greater than this proportion of points +#' are below the LOD, we flag the state as having insufficient wastewater data. +#' Default is 0.5 +#' @param sd_thres The minimum standard deviation between wastewater data points +#' within a site. This is intended to catch when a site reports all the same +#' values. Default is 0.1 +#' @param mean_log_ww_value_thres The minimum value of the log of the ww +#' concentration, default is -4 +#' +#' @return table_of_loc_dates_w_ww a tibble containing the location, +#' forecast_date, and column stating the wastewater data was sufficient for all +#' locations and forecast dates where ww data was deemed sufficient +#' @export +#' +get_table_sufficient_ww <- function(ww_quantiles, + delay_thres = 21, + n_dps_thres = 5, + prop_below_lod_thres = 0.5, + sd_thres = 0.1, + mean_log_ww_value_thres = -4) { + calib_data <- ww_quantiles |> + dplyr::distinct( + location, site, lab, lab_site_index, + date, forecast_date, below_LOD, calib_data + ) |> + dplyr::filter(!is.na(calib_data)) + + diagnostic_table <- calib_data |> + dplyr::group_by(location, forecast_date) |> + dplyr::summarize( + last_date = max(date), + n_dps = dplyr::n(), + prop_below_lod = sum(below_LOD == 1) / dplyr::n(), + sd = sd(calib_data), + mean_log_ww = mean(log(calib_data)) + ) |> + dplyr::mutate( + flag_delay = as.integer(forecast_date - last_date) > delay_thres, + flag_n_dps = n_dps < n_dps_thres, + flag_lod = prop_below_lod > prop_below_lod_thres, + flag_sd = sd < sd_thres, + flag_low_val = mean_log_ww < mean_log_ww_value_thres + ) + + flag_table_long <- diagnostic_table |> + dplyr::ungroup() |> + tidyr::pivot_longer(starts_with("flag")) + + # Ensure all `values` are boolean + stopifnot( + "In diagnostic table checking for sufficent wastewater data flags, not all values are boolean" = + is.logical(flag_table_long$value) + ) + + table_of_loc_dates_w_ww <- flag_table_long |> + dplyr::group_by(location, forecast_date) |> + dplyr::summarise(ww_sufficient = !any(value)) + + return(table_of_loc_dates_w_ww) +} diff --git a/wweval/R/model_run_diagnostic_flags.R b/wweval/R/model_run_diagnostic_flags.R index c05c09f3..d7cfa0fe 100644 --- a/wweval/R/model_run_diagnostic_flags.R +++ b/wweval/R/model_run_diagnostic_flags.R @@ -51,3 +51,36 @@ get_diagnostic_flags <- function(stan_fit_object, ) return(flag_df) } + + +#' Get convergence dataframe +#' @description This function takes the larger dataframe of convergence +#' flags for each location, 
forecast date, and scenario and checks if any of +#' the flags are TRUE, and returns a dataframe with just a column indicating +#' whether any flags are true +#' +#' +#' @param all_flags a dataframe containing the flags for each location, +#' forecast_date, and scenario +#' @param scenario The scenario to filter to, since some eval output will include multiple +#' scenarios +#' +#' @return a dataframe with a column `any_flags` indicating whether any of the +#' flags in the original full descriptive set of congerence flags are TRUE. +#' @export +#' +get_convergence_df <- function(all_flags, + scenario) { + convergence_df <- all_flags |> + dplyr::filter(scenario == {{ scenario }}) |> + tidyr::gather(key, value, starts_with("flag")) |> + dplyr::group_by(location, forecast_date, scenario, model_type) |> + dplyr::mutate(any_flags = any(value == TRUE)) |> + tidyr::spread(key, value) |> + dplyr::ungroup() |> + dplyr::select( + location, forecast_date, any_flags + ) + + return(convergence_df) +} diff --git a/wweval/R/sample_model.R b/wweval/R/sample_model.R index d038ac52..34c11afb 100644 --- a/wweval/R/sample_model.R +++ b/wweval/R/sample_model.R @@ -107,7 +107,6 @@ sample_model <- function(standata, } else { # Get the diagnostics using thresholds set in production pipeline flag_df <- get_diagnostic_flags(fit$result, n_chains, iter_sampling) - any_flags <- any(flag_df) draws <- fit$result$draws() diagnostics <- fit$result$sampler_diagnostics(format = "df") @@ -118,14 +117,9 @@ sample_model <- function(standata, draws = draws, diagnostics = diagnostics, summary_diagnostics = summary_diagnostics, - summary = summary + summary = summary, + flags = list(flag_df) ) - - if (any_flags) { # If there are model convergence issues, pass - # flags alongside the draws and summaries - out <- c(out, flags = list(flag_df)) - message("Model convergence issues") - } } return(out) } diff --git a/wweval/man/get_convergence_df.Rd b/wweval/man/get_convergence_df.Rd new file mode 100644 index 00000000..4efe80b9 --- /dev/null +++ b/wweval/man/get_convergence_df.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/model_run_diagnostic_flags.R +\name{get_convergence_df} +\alias{get_convergence_df} +\title{Get convergence dataframe} +\usage{ +get_convergence_df(all_flags, scenario) +} +\arguments{ +\item{all_flags}{a dataframe containing the flags for each location, +forecast_date, and scenario} + +\item{scenario}{The scenario to filter to, since some eval output will include multiple +scenarios} +} +\value{ +a dataframe with a column \code{any_flags} indicating whether any of the +flags in the original full descriptive set of congerence flags are TRUE. 
+} +\description{ +This function takes the larger dataframe of convergence +flags for each location, forecast date, and scenario and checks if any of +the flags are TRUE, and returns a dataframe with just a column indicating +whether any flags are true +} diff --git a/wweval/man/get_table_sufficient_ww.Rd b/wweval/man/get_table_sufficient_ww.Rd new file mode 100644 index 00000000..59bbd57a --- /dev/null +++ b/wweval/man/get_table_sufficient_ww.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_table_sufficient_ww.R +\name{get_table_sufficient_ww} +\alias{get_table_sufficient_ww} +\title{Get table of location-forecast dates with sufficient wastewater} +\usage{ +get_table_sufficient_ww( + ww_quantiles, + delay_thres = 21, + n_dps_thres = 5, + prop_below_lod_thres = 0.5, + sd_thres = 0.1, + mean_log_ww_value_thres = -4 +) +} +\arguments{ +\item{ww_quantiles}{The dataframe of the quantiled ww predictions with the +real data alongside it (labeled as calib data)} + +\item{delay_thres}{The maximum number of days of delay between the last +wastewater data point and the forecat date, before we would flag a state as +having insufficient wastewater data to inform a forecast. Default is 21} + +\item{n_dps_thres}{The threshold number of data points within a single site +within a state before we would flag the state as having insufficient +wastewater data to inform a forecast. Default is 5} + +\item{prop_below_lod_thres}{The threshold proportion of wastewater data +points that can be below the LOD. If greater than this proportion of points +are below the LOD, we flag the state as having insufficient wastewater data. +Default is 0.5} + +\item{sd_thres}{The minimum standard deviation between wastewater data points +within a site. This is intended to catch when a site reports all the same +values. Default is 0.1} + +\item{mean_log_ww_value_thres}{The minimum value of the log of the ww +concentration, default is -4} +} +\value{ +table_of_loc_dates_w_ww a tibble containing the location, +forecast_date, and column stating the wastewater data was sufficient for all +locations and forecast dates where ww data was deemed sufficient +} +\description{ +This function takes in a large dataframe containing the quantiled estimated +and forecasted wastewater concentrations for each site and lab in each +location and forecast date, joined with the observed data on each day, in +each site and lab, from each forecast date. We will use the data to +get a table of location-forecast-dates where wastewater data was considered +sufficient to inform a forecast, based on the critera that was used for +the Hub submissions. 
+} From 844ed246cd6e01d1efb93d96990883ed55647de7 Mon Sep 17 00:00:00 2001 From: kaitejohnson Date: Mon, 24 Jun 2024 15:55:50 -0400 Subject: [PATCH 2/2] add scratch files --- scratch/site_level_inf_dynamics.Rmd | 1354 +++++++++++++++++++++++++ scratch/toy_example_ind_variation.Rmd | 683 +++++++++++++ 2 files changed, 2037 insertions(+) create mode 100644 scratch/site_level_inf_dynamics.Rmd create mode 100644 scratch/toy_example_ind_variation.Rmd diff --git a/scratch/site_level_inf_dynamics.Rmd b/scratch/site_level_inf_dynamics.Rmd new file mode 100644 index 00000000..f71654d5 --- /dev/null +++ b/scratch/site_level_inf_dynamics.Rmd @@ -0,0 +1,1354 @@ +--- +title: "Site level hierarchical model" +author: "Kaitlyn Johnson" +date: "2023-09-01" +output: html_document +--- + +```{r setup, include=FALSE} +library(cfaforecastrenewalww) +library(cmdstanr) +library(lubridate) +library(ggplot2) +library(dplyr) +library(here) +library(tidybayes) +library(bayesplot) +source(here::here("src/write_config.R")) +knitr::opts_chunk$set(echo = TRUE) +cfaforecastrenewalww::setup_secrets(here::here("secrets.yaml")) +``` + +# Motivation +- This is an implementation of a renewal approach for inference of hospitalization and wastewater viral concentrations by estimating a time-varying effective reproductive number $R(t)$. Specifically +this Rmd is focused on implementing a hierarchical site-level infectious disease +dynamics model. This assumes that at each time step, the $R(t)$ at each site is +drawn from a shared lognormal distribution with mean $\mu_{R(t)state}$. +This means functionally that each site is allowed to have its own disease dyanmics +that are informed by information from the other sites and constrained by the +overall state-level hospital admissions. We will +continue to allow for site level observation error and a site level multiplier +on the observed wastewater concentration. + +# Approach +- For each state, we will treat the combination of sites/measurements as observations +from the underlying wastewater concentration in the whole date +- For the hospital admissions data, we use state-level total hospital admission data +that was available as of the forecast date, using the covidcast package and the +`as_of` input. We impose a 9 day reporting delay to mirror the current data scenario. + + +Note: This Rmd mirrors the single model, single location, single forecast date pipeline +set up in the `_targets.R` file. We will use the Rmd format to continue to refine the model. 
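+
+As a quick illustration of the hierarchical assumption described in the
+motivation (site-level $R(t)$ drawn around a shared state-level mean), here is
+a minimal simulation sketch. The state-level trajectory, number of sites, and
+`sigma_rt_toy` are toy inputs chosen for illustration, not model outputs.
+```{r}
+# Toy sketch: site-level R(t) drawn lognormally around a state-level mean R(t)
+set.seed(42)
+n_sites_toy <- 4
+state_rt_toy <- c(1.3, 1.2, 1.1, 1.0, 0.9, 0.85) # toy weekly state-level R(t)
+sigma_rt_toy <- 0.1 # assumed site-to-site spread on the log scale
+site_rt_toy <- sapply(
+  state_rt_toy,
+  function(r) exp(rnorm(n_sites_toy, mean = log(r), sd = sigma_rt_toy))
+)
+matplot(t(site_rt_toy),
+  type = "l", lty = 1,
+  xlab = "Week", ylab = "Site-level R(t)"
+)
+lines(state_rt_toy, lwd = 3)
+```
+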
+ + +# Get the state level data from NWSS +```{r} +ww_data_path <- save_timestamped_nwss_data( + ww_path_to_save = + here::here( + "input", "ww_data", + "nwss_data" + ) +) + +config_written <- write_config( + save_config = TRUE, + config_path = here::here("input", "config"), + location = "MA", + prod_run = FALSE, + run_id = "test", + date_run = lubridate::today(), + model_type = "site-level infection dynamics", + forecast_date = "2024-04-22", + hosp_data_source = "NHSN", + pull_from_local = FALSE, + hosp_data_dir = here::here("input", "hosp_data", "vintage_datasets"), + population_data_path = here::here("input", "locations.csv"), + param_file_path = here::here("input", "params.toml"), + include_ww = 1, + hosp_reporting_delay = 5, + ww_data_path = here::here( + "input", "ww_data", + "nwss_data", "2024-04-21.csv" + ) +) + + +config_vars_ss <- get_config_vals(config_written) +# Site level WW data from NWSS + +ww_data_raw <- do.call(get_ww_data, config_vars_ss) + +ww_data <- ww_data_raw %>% filter(location == config_vars_ss$location) +``` +# Subsample sites +We want to be able to compare the model forecasts (R(t) and hospital admissions forecasts) for a state +with a large number of sites fit to all of them vs fit to a subset of them, and see if this meaningfully impacts results. +```{r} +ww_data_subsampled <- subsample_sites(ww_data, + prop_sites = 0.2 +) +``` + +```{r} +# Get training data +train_data_orig <- do.call( + get_all_training_data, + c(list(ww_data_raw = ww_data), config_vars_ss, + subsample_sites = 1, + prop_sites = 0.2 + ) +) +train_data <- flag_ww_outliers(train_data_orig) +train_data <- train_data %>% mutate( + site_lab = stringr::str_glue("Site: {site}, Lab: {lab}") +) + +ggplot(train_data %>% filter(period != "forecast", !is.na(site))) + + geom_line(aes(x = date, y = ww, color = as.factor(site), group = site), + show.legend = FALSE + ) + + geom_point(aes(x = date, y = ww, color = as.factor(site), group = site), + show.legend = FALSE + ) + + theme_bw() + + facet_wrap(~site) + + xlab("") + + ylab("Genome copies per mL") + + scale_y_continuous(trans = "log") + + guides(color = guide_legend(title = "Site")) + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") + ) +ggplot(train_data %>% filter(period != "forecast", !is.na(site))) + + geom_line(aes(x = date, y = ww, color = as.factor(site_lab), group = site_lab)) + + geom_point(aes(x = date, y = ww, color = as.factor(site_lab), group = site), + show.legend = FALSE + ) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + theme_bw() + + # facet_wrap(~site_lab, scales = "free")+ + xlab("") + + ylab("Genome copies per mL") + + # scale_y_continuous(trans = "log") + + guides(color = guide_legend(title = "")) + + scale_x_date( + date_breaks = "2 weeks", + labels = scales::date_format("%Y-%m-%d") + ) + + theme( + axis.text.x = element_text( + size = 10, vjust = 1, + hjust = 1, angle = 45 + ), + axis.title.x = element_text(size = 10), + axis.title.y = element_text(size = 10), + plot.title = element_text( + size = 9, + vjust = 0.5, hjust = 0.5 + ) + ) + + +ggplot(train_data) + + geom_density(aes( + x = log(ww), + fill = as.factor(flag_as_ww_outlier) + ), alpha = 0.3) + + theme_bw() + + guides(fill = guide_legend(title = "Outlier?")) + +ggplot(train_data) + + geom_density(aes(x = log(ww), ), alpha = 0.3, fill = "gray") + + theme_bw() +``` +
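+
+Before moving on to the model description, a quick sketch of how much of the
+state's population the retained sites cover after subsampling; `ww_pop` and
+`pop` are columns already present in the training data, and because catchments
+need not be mutually exclusive this is only a rough indication of coverage.
+```{r}
+train_data %>%
+  filter(!is.na(site)) %>%
+  distinct(site, ww_pop, pop) %>%
+  summarise(
+    n_sites_retained = dplyr::n(),
+    prop_pop_in_catchments = sum(ww_pop) / unique(pop)
+  )
+```
+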
+# Model description + +We have observed data from 90 days of hospital admissions and 14 weeks of wastewater measurements for a single state. We are going to make all of the simplifying assumptions: +- the population shedding into the WW is the same as those captured by the hospital admissions +- the IHR is constant during the period of inference +- the shedding kinetics and shedding amount are constant during the period of inference +- for now, we will assume that the number of genomes shed per individual is constant +across infected individuals, since the population is large enough that we don't expect +inter-individual variability to add too much noise to the signal + +This model builds heavily off of (and in some instances borrows functions from) the +[EpiNow2 R package](https://github.com/epiforecasts/EpiNow2) + +## Renewal Equation Stan Model +In brief, this model assumes that incident infections $I(t)$ are generated from a vector of $R(t)$ values and an initial number of infections $I(0)$ seeded on day 0: +$$I(t) = R(t) \sum_{s=1}^{t}I(t-s)g(s)$$ +Where $g(s)$ is the generation interval, which describes the distribution of times +from incident infection to secondary infection (i.e. infectiousness profile) and $R(t)$ describes the number of expected secondary infections of an index case at time $t$. Because the data does not start at the initial seeding event, we assume that prior to the first observation initial infections grow exponentially to give rise to the early observations such that $I(t') = I(0)exp(rt')$ where $t'$ is the "unobserved time" before the first observed data point.
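+
+To make the renewal equation concrete, here is a minimal R sketch of
+generating infections from an $R(t)$ vector and a generation interval. The
+generation interval, $R(t)$ values, and seed below are toy values for
+illustration only, not the values estimated or assumed by the Stan model.
+```{r}
+# Toy renewal process: I(t) = R(t) * sum_s I(t - s) g(s)
+gen_int_toy <- c(0.2, 0.4, 0.3, 0.1) # toy generation interval (sums to 1)
+rt_toy <- c(rep(1.4, 20), rep(0.8, 20)) # toy R(t)
+infections_toy <- numeric(length(rt_toy))
+infections_toy[1] <- 10 # toy seed; the Stan model instead seeds the
+# unobserved period with exponential growth
+for (t in 2:length(rt_toy)) {
+  s_max <- min(t - 1, length(gen_int_toy))
+  infections_toy[t] <- rt_toy[t] *
+    sum(infections_toy[t - (1:s_max)] * gen_int_toy[1:s_max])
+}
+plot(infections_toy, type = "l", xlab = "Day", ylab = "Incident infections")
+```
+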
+ +The model estimates $R(t)$ by estimating a weekly random walk such that $R(k) = R(k-1)\eta$, where $\eta ~N(0, \eta_{sd}$, $k$ is the week, and the magnitude of the step size, $eta_{sd}$ is estimated in the model calibration. For the $R(t)$ during the forecast period, we assume $R(t)$ is dampened in proportion to the number of cumulative infections following an SIR approximation, which is implemented when damp_type =1 (see EpiNow2 documentation for details) or we assume that the R(t) is dampened by the current number of infections with a drift term, per Jason's Ebola model. We will likely modify both of these significantly, particularly if we observe consisteny over or underprediction in the forecasts.
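+
+Below is a minimal sketch of a weekly random walk for $R(t)$, written here on
+the log scale so that $R(t)$ stays positive; the step size `eta_sd_toy` is an
+assumed illustrative value, not the estimated $\eta_{sd}$.
+```{r}
+# Toy weekly random walk for R(t), held constant within each week
+set.seed(123)
+n_weeks_toy <- 15
+eta_sd_toy <- 0.1 # assumed weekly step-size sd, for illustration only
+log_rt_weekly <- cumsum(
+  c(log(1.2), rnorm(n_weeks_toy - 1, mean = 0, sd = eta_sd_toy))
+)
+rt_daily <- rep(exp(log_rt_weekly), each = 7)
+plot(rt_daily, type = "s", xlab = "Day", ylab = "R(t)")
+```
+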
+ + +We model the process of generating two type of observations. First the hospital admissions: Each infectee is assumed to have some (independent) chance of getting hospitalizated $p_{hosp}$, and *if* they will eventually be hospitalized then the probability distribution for days after infection before observation and reporting is $d(t)$. Therefore, conditional on some time series of state level daily infections generated by one of the models, $I_{state\mu}(t)$, the *expected* number of state-level reported hospitalizations on each day $t$ is, + +$$H[t] = p_{hosp}\sum_{\tau\geq 0}d[\tau] I_{state\mu}[t-\tau]$$ +We estimate $p_{hosp}$ and for now are setting the hospital admissions delay distribution $d(t)$. We assume that the observation process is negative binomial, such that: +$$\overline{H}[t] \sim NegBinom(H[t], \phi_h)$$ +
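+
+Here is a small self-contained sketch of this admissions observation process:
+convolve toy latent infections with a delay distribution, scale by
+$p_{hosp}$, and draw negative binomial observations. The delay pmf,
+$p_{hosp}$, and dispersion below are toy values, not the model's priors or
+estimates.
+```{r}
+# Toy expected admissions H(t) and negative binomial observations
+set.seed(123)
+latent_inf_toy <- 100 * exp(0.03 * (1:60)) # toy latent infection curve
+delay_pmf_toy <- dgamma(1:15, shape = 4, rate = 0.5) # toy infection-to-admission delay
+delay_pmf_toy <- delay_pmf_toy / sum(delay_pmf_toy)
+p_hosp_toy <- 0.01 # toy probability of hospitalization
+phi_h_toy <- 20 # toy negative binomial dispersion
+exp_hosp_toy <- p_hosp_toy * as.numeric(
+  stats::filter(latent_inf_toy, delay_pmf_toy,
+    method = "convolution", sides = 1
+  )
+)
+keep <- !is.na(exp_hosp_toy) # early days have an incomplete convolution
+obs_hosp_toy <- rnbinom(sum(keep), mu = exp_hosp_toy[keep], size = phi_h_toy)
+plot(exp_hosp_toy[keep],
+  type = "l",
+  xlab = "Day", ylab = "Hospital admissions"
+)
+points(obs_hosp_toy)
+```
+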
+ +In a similar fashion for generated the expected WW observations, we assume that each individual follows the same shaped shedding kinetics distribution $S(t)$. We assume the shedding kinetics follows a hinge function +in log10 space. +$$\begin{align*} +log10(S[t]) = + + \begin{cases} + \frac{V_{peak}}{t_{peak}}t & t\leq t_{peak} \\ + V_{peak} + wt_{peak} - wt & t \geq t_{peak} \\ + \end{cases} + \end{align*} +$$ + +Where $V_{peak}$ is the log10 peak viral load that occurs at $t_{peak}$ days since infection onset, and $w$ is the rate at which viral load wanes after $t_{peak}$. We convert to natural scale with $S[t] = 10^{log10(S[t])}$. The parameters of this hinge function are estimated (with relatively strong priors taken from fecal shedding literature and literature on viral loads in nasal passages). The individual level viral kinetics $S[t]$ of each infected individual are normalized to sum to 1 over the course of the infection and are multiplied by $G$, the number of genome copies shed by each infected individual throughout their infection. Therefore, conditional on some time series of daily infections in each site $I_j(t)$, the *expected* number of viral genomes shed in WW on each day $t$ is, + +$$V_{j}[t] = \ G\sum_{\tau\geq 0} S(\tau)I_j(t-\tau)$$ +Where $\sum_{\tau\geq 0} S(\tau)I_j(t-\tau)$ is the net number of infected +individuals in site $j$ on each day $t$, which we will call $\iota_j(t)$. +While $S[t]$ describes the kinetics of how an infected individual sheds over the course of +their infection, there is also individual variabilty in the total amount of virus shed per infected individual over the course of their infection, $G$.
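+
+The hinge-shaped shedding kinetics and the resulting expected genomes shed can
+be sketched in a few lines of R; the peak timing, peak value, waning rate, and
+$G$ below are illustrative assumptions rather than the priors used in the Stan
+model.
+```{r}
+# Toy log10 hinge shedding kinetics S(t), normalized, convolved with infections
+tpeak_toy <- 5 # toy days from infection to peak shedding
+vpeak_toy <- 5 # toy log10 peak shedding value
+waning_toy <- 0.35 # toy log10 decline per day after the peak
+tau <- 0:30
+log10_shed_toy <- ifelse(
+  tau <= tpeak_toy,
+  (vpeak_toy / tpeak_toy) * tau,
+  vpeak_toy + waning_toy * tpeak_toy - waning_toy * tau
+)
+shed_kinetics_toy <- 10^log10_shed_toy
+shed_kinetics_toy <- shed_kinetics_toy / sum(shed_kinetics_toy) # sums to 1
+g_toy <- 1e9 # toy total genomes shed per infection
+site_inf_toy <- 50 * exp(0.04 * (1:60)) # toy site-level incident infections
+exp_genomes_toy <- g_toy * as.numeric(
+  stats::filter(site_inf_toy, shed_kinetics_toy,
+    method = "convolution", sides = 1
+  )
+)
+plot(log10(exp_genomes_toy),
+  type = "l",
+  xlab = "Day", ylab = "log10 expected genomes shed"
+)
+```
+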
+ +## Dispersion in viral genomes shed per infected individual +We will assume that an individual sheds $G_i$ genomes per infection and $G_i$ is negative binomially distributed with mean $\mu_G$ and dispersion $phi$, +$$ G_i \sim NegBinom(\mu_G, \phi) $$ +Then we can write the expected sum of the genomes shed from the infected individuals in site $j$ +as, $G_j(t) = \sum_{i=1}^\iota_j(t) G_i$: +$$ G_j(t) \sim NegBinom(\iota_j(t)\mu_G, \iota_j(t)\phi) $$ +where $N_j$ is the number of infected individuals in each site on each day. This is is a result of a fun trick that the sum of Negative Binomials are Negative Binomials. +What it means is rather intuitive, +that as you have a larger population of infected individuals shedding into the WW, +you observe lower overall dispersion ($\phi$ increases). $G_j(t)$ is the expected +number of viral genomes shed in each site on each day. To implement this, we will +use the Gaussian approximation described [here](https://academic.oup.com/jrsssa/article/185/Supplement_1/S65/7069481?login=true#rssa12971-sec-0020). + + +In practice, each site +processes samples differently, which adds both a site-level scaling factor $M_j$ +and site level variability, which we will call $\phi_j$. If we assume observation +are also Negative Binomial, then the observed genomes in each site on each day +$\overline{G}_j(t)$ can be defined as: +$$ \overline{G}_j(t) \sim NegBinom(M_jG_j(t), \phi_j) $$ +This defines a model with a site-level scaling factor and a site-level measurement error. +Our observations will be in terms of genomes per person per day, so $\overline{G*}_j(t) = \frac{\overline{G}_j(t)}{N_j}$ where $N_j$ is the population in the WW catchment area. Note the observation errors don't have to be Negative Binomially distributed, could also use +something like a student t. +
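+
+A quick simulation check of the "sum of negative binomials" property used
+above; $\mu_G$, $\phi$, and the number of shedding individuals are toy values
+chosen only to make the check fast.
+```{r}
+# Toy check: the sum of n_shed iid NegBinom(mu_G, phi) individuals matches a
+# single NegBinom(n_shed * mu_G, n_shed * phi) draw in mean and spread
+set.seed(123)
+mu_g_toy <- 1e6 # toy mean genomes shed per infection
+phi_toy <- 2 # toy individual-level dispersion
+n_shed_toy <- 50 # toy number of infected individuals shedding
+n_sims_toy <- 5000
+sum_of_individuals <- replicate(
+  n_sims_toy,
+  sum(rnbinom(n_shed_toy, mu = mu_g_toy, size = phi_toy))
+)
+single_draw <- rnbinom(n_sims_toy,
+  mu = n_shed_toy * mu_g_toy,
+  size = n_shed_toy * phi_toy
+)
+c(mean(sum_of_individuals), mean(single_draw))
+c(sd(sum_of_individuals), sd(single_draw))
+```
+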
+ + + + +## Hierarchical infectious disease dynamics +We will assume that at each time step, the site level infections $I_j(t) = R_j(t) \sum_{s=1}^{t}I_j(t-s)g(s)$ are drawn from a shared distribution of state-level $R(t)$ such that: +$$ R_j(t) \sim logNormal(R_{state\mu}(t), \sigma_{R(t)})$$ + + + +# Assign parameters +```{r} +# This contains all the model priors and parameter settings. Informative priors +# include the shedding viral kinetic parameters (timing of peak viral shedding, +# magnitude of peak viral shedding, and duration of viral shedding). These +# were informed in combination from literature on fecal shedding and viral loads +# in the nasal package. +params <- get_params(config_vars_ss$param_file_path) # +print(params) + +params$p_hosp_w_sd_sd <- 0.05 +# Pull the prior hyperparameters into the global environment so you can create +# the init_fun to be passed to stan +par_names <- colnames(params) +for (i in seq_along(par_names)) { + assign(par_names[i], as.double(params[i])) +} + + +stan_data <- do.call(get_stan_data_site_level_model, c( + config_vars_ss, + list(params = params), + list(train_data = train_data) +)) + +pop <- train_data %>% + select(pop) %>% + unique() %>% + pull(pop) +stopifnot("More than one population size in training data" = length(pop) == 1) + +n_weeks <- as.numeric(stan_data$n_weeks) +tot_weeks <- as.numeric(stan_data$tot_weeks) +n_ww_sites <- as.numeric(stan_data$n_ww_sites) +n_ww_lab_sites <- as.numeric(stan_data$n_ww_lab_sites) +ot <- stan_data$ot +ht <- stan_data$ht + +# Estimate of number of initial infections +i0 <- mean(train_data$daily_hosp_admits[1:7]) / p_hosp_mean +``` + + + + +# Initialize the parameter search using center of the priors + a bit of noise +```{r} +init_fun <- function() { + site_level_inf_inits(train_data, params, stan_data) +} +``` + +# Compile the model +```{r, echo = FALSE} +model_file_path <- here::here( + "cfaforecastrenewalww", "inst", + "stan", "renewal_ww_hosp_site_level_inf_dynamics.stan" +) + +model <- compile_model(model_filepath = model_file_path) +``` + +# Fit the model +```{r, echo = FALSE} +fit_dynamic_rt <- model$sample( + data = stan_data, + seed = 123, + init = init_fun, + iter_sampling = 100, + iter_warmup = 50, + chains = 4, + parallel_chains = 4 +) +``` +```{r} +stanfit <- rstan::read_stan_csv(fit_dynamic_rt$output_files()) +posterior_cp <- as.array(stanfit) +available_mcmc(pattern = "_nuts_") +np_cp <- nuts_params(stanfit) +mcmc_pairs(posterior_cp, + np = np_cp, pars = c("log10_g", "cv", "sigma_ww_site_mean", "sigma_ww_site_sd"), + off_diag_args = list(size = 0.75) +) +``` + + +Quick test +```{r} +all_draws <- fit_dynamic_rt$draws() + + +log_total_g <- all_draws %>% + spread_draws(log_total_g[site_index, t]) %>% + rename(value = log_total_g) %>% + mutate( + draw = `.draw`, + name = "log_total_g" + ) %>% + select(name, t, value, site_index, draw) + + +autoreg_rt_site <- all_draws %>% + spread_draws(autoreg_rt_site) %>% + mutate(draw = `.draw`) %>% + mutate( + name = "autoreg_rt_site", + t = NA + ) %>% + rename(value = autoreg_rt_site) %>% + select(name, t, value, draw) + +prior_autoreg_rt_site <- data.frame(value = rbeta(1000, autoreg_rt_a, autoreg_rt_b)) +prior_autoreg_rt_site <- data.frame(value = rbeta(1000, 1, 4)) + +sigma_rt <- all_draws %>% + spread_draws(sigma_rt) %>% + sample_draws(ndraws = 100) %>% + mutate(draw = row_number()) %>% + mutate( + name = "sigma_rt", + ) %>% + rename(value = sigma_rt) %>% + select(name, value, draw) + +sigma_rt_prior <- data.frame(value = rnorm(1000, 0, 
params$sigma_rt_prior)) +ggplot() + + geom_density(data = sigma_rt_prior, aes(x = value), fill = "blue", alpha = 0.1) + + geom_density(data = sigma_rt, aes(x = value), fill = "red", alpha = 0.1) + + theme_bw() + + ggtitle("Average deviation of site and state R(t)") + + coord_cartesian(xlim = c(0, 1)) + +ggplot() + + geom_density(data = prior_autoreg_rt_site, aes(x = value), fill = "blue", alpha = 0.1) + + geom_density(data = autoreg_rt_site, aes(x = value), fill = "red", alpha = 0.1) + + theme_bw() + + ggtitle("Autoreg coefficient in site level R(t) deviation") + + coord_cartesian(xlim = c(0, 0.8)) +``` +```{r} +r_state <- all_draws %>% + spread_draws(rt[t]) %>% + rename(value = rt) %>% + mutate( + draw = `.draw`, + name = "r_state" + ) %>% + select(name, t, value, draw) %>% + group_by(t, name) %>% + summarise( + median_Rt = quantile(value, 0.5, na.rm = TRUE), + lb = quantile(value, 0.025, na.rm = TRUE), + ub = quantile(value, 0.975, na.rm = TRUE), + ub_50th = quantile(value, 0.75, na.rm = TRUE), + lb_50th = quantile(value, 0.25, na.rm = TRUE) + ) %>% + left_join(train_data %>% select(forecast_date, t, date) %>% distinct()) + + +site_map <- train_data %>% + select(site, site_index, ww_pop) %>% + distinct() %>% + filter(!is.na(site)) + +if (sum(stan_data$subpop_size) < stan_data$state_pop) { + subpop_map <- site_map %>% + rbind(c(1, stan_data$subpop_size[stan_data$n_subpops])) %>% + rename(subpop_size = ww_pop) +} else { + subpop_map <- site_map %>% + rename(subpop_size = ww_pop) +} + +r_site <- all_draws %>% + spread_draws(r_site_t[site_index, t]) %>% + rename(value = r_site_t) %>% + mutate( + draw = `.draw`, + name = "r_site" + ) %>% + select(name, t, value, site_index, draw) %>% + group_by(t, site_index, name) %>% + summarise( + median_Rt = quantile(value, 0.5, na.rm = TRUE), + lb = quantile(value, 0.025, na.rm = TRUE), + ub = quantile(value, 0.975, na.rm = TRUE), + ub_50th = quantile(value, 0.75, na.rm = TRUE), + lb_50th = quantile(value, 0.25, na.rm = TRUE) + ) %>% + left_join(subpop_map) %>% + left_join( + train_data %>% + select(forecast_date, t, date) %>% + distinct() + ) %>% + mutate( + site_name = paste0("Site: ", site) + ) + +ggplot(r_state) + + geom_line(aes(x = date, y = median_Rt)) + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub), alpha = 0.2) + + geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), alpha = 0.2) + + geom_hline(aes(yintercept = 1), linetype = "dashed") + + ylab("State R(t)") + + xlab("") + # scale_y_continuous(trans = "log") + + # coord_cartesian(ylim = c(0.5, 2.2))+ + theme_bw() + +ggplot(r_site) + + geom_line(aes(x = date, y = median_Rt, color = as.factor(site_name))) + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub, fill = as.factor(site_name)), alpha = 0.2) + + geom_ribbon( + aes( + x = date, + ymin = lb_50th, ymax = ub_50th, + fill = as.factor(site_name) + ), + alpha = 0.2 + ) + + geom_hline(aes(yintercept = 1), linetype = "dashed") + + facet_wrap(~site_name) + + ylab("Site R(t)") + + xlab("") + + theme_bw() + +sites_to_display <- train_data |> + dplyr::select(site) |> + dplyr::filter(!is.na(site)) |> + unique() |> + pull() |> + head(2) + +ggplot(r_site %>% filter(site %in% sites_to_display)) + + geom_line(aes(x = date, y = median_Rt, color = as.factor(site_name)), + show.legend = FALSE + ) + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub, fill = as.factor(site_name)), + alpha = 0.2, show.legend = FALSE + ) + + geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th, fill = as.factor(site_name)), + alpha = 0.2, show.legend = FALSE + ) + + 
geom_hline(aes(yintercept = 1), linetype = "dashed") + + facet_wrap(~site_name) + + ylab("Site R(t)") + + xlab("") + + theme_bw() + +ggplot() + + geom_line( + data = r_site %>% filter(site %in% sites_to_display[1]), + aes(x = date, y = median_Rt), color = "darkorange2", + show.legend = FALSE + ) + + geom_ribbon( + data = r_site %>% filter(site %in% sites_to_display[1]), + aes(x = date, ymin = lb, ymax = ub), fill = "darkorange2", + alpha = 0.2, show.legend = FALSE + ) + + geom_ribbon( + data = r_site %>% filter(site %in% sites_to_display[1]), + aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkorange2", + alpha = 0.2, show.legend = FALSE + ) + + geom_line( + data = r_site %>% filter(site %in% sites_to_display[2]), + aes(x = date, y = median_Rt), color = "darkgreen", + show.legend = FALSE + ) + + geom_ribbon( + data = r_site %>% filter(site %in% sites_to_display[2]), + aes(x = date, ymin = lb, ymax = ub), fill = "darkgreen", + alpha = 0.2, show.legend = FALSE + ) + + geom_ribbon( + data = r_site %>% filter(site %in% sites_to_display[2]), + aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkgreen", + alpha = 0.2, show.legend = FALSE + ) + + geom_hline(aes(yintercept = 1), linetype = "dashed") + + ylab("Site R(t)") + + xlab("") + + scale_y_continuous(trans = "log") + + coord_cartesian(ylim = c(0.5, 2.2)) + + theme_bw() +``` + +Time-varying IHR +```{r} +p_hosp <- all_draws %>% + spread_draws(p_hosp[t]) %>% + sample_draws(ndraws = 20) %>% + rename(value = p_hosp) %>% + mutate( + draw = `.draw`, + name = "p_hosp" + ) %>% + select(name, t, value, draw) + +ggplot(p_hosp) + + geom_line(aes(x = t, y = value, group = draw), size = 0.1) + + xlab("Time (days)") + + ylab("IHR(t)") + +p_hosp_mean <- all_draws %>% + spread_draws(p_hosp_mean) %>% + mutate(draw = row_number()) %>% + mutate( + name = "p_hosp_mean", + ) %>% + mutate(value = plogis(p_hosp_mean)) %>% + select(name, value, draw) +ggplot(p_hosp_mean) + + aes(x = value) + + stat_halfeye() + + xlab("Mean of IHR") + +p_hosp_w_sd <- all_draws %>% + spread_draws(p_hosp_w_sd) %>% + mutate(draw = row_number()) %>% + mutate( + name = "p_hosp_w_sd", + ) %>% + rename(value = p_hosp_w_sd) %>% + select(name, value, draw) +ggplot(p_hosp_w_sd) + + aes(x = value) + + stat_halfeye() + + xlab("Stdev of step size of IHR(t)") + +cv <- all_draws %>% + spread_draws(cv) %>% + mutate(draw = row_number()) %>% + mutate( + name = "cv", + ) %>% + rename(value = cv) %>% + select(name, value, draw) +ggplot(cv) + + aes(x = value) + + stat_halfeye() + + xlab("Coefficient of variation in individual dispersion in number of genomes shed") +mu_g <- all_draws %>% + spread_draws(mu_g) %>% + mutate(draw = row_number()) %>% + mutate( + name = "mu_g", + ) %>% + rename(value = mu_g) %>% + select(name, value, draw) +ggplot(mu_g) + + aes(x = value) + + stat_halfeye() + + xlab("Mean number of genomes shed per infection") + +log_genomes_shed_per_inf <- all_draws %>% + spread_draws(log_genomes_shed_per_inf) %>% + mutate(draw = row_number()) %>% + mutate( + name = "log_genomes_shed_per_inf", + ) %>% + rename(value = log_genomes_shed_per_inf) %>% + select(name, value, draw) +ggplot(log_genomes_shed_per_inf) + + aes(x = value) + + stat_halfeye() + + xlab("Log of individual's number of genomes shed per infection ") + +genomes_shed_per_inf <- all_draws %>% + spread_draws(log_genomes_shed_per_inf) %>% + mutate(draw = row_number()) %>% + mutate( + name = "genomes_shed_per_inf", + ) %>% + mutate(value = exp(log_genomes_shed_per_inf)) %>% + select(name, value, draw) 
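+
+# Quick numeric check (a sketch): the draws of an individual's genomes shed
+# per infection should have an empirical mean of roughly the same order as
+# the draws of mu_g, since the gamma in generated quantities is parameterized
+# to have mean mu_g.
+genomes_shed_per_inf %>%
+  summarise(
+    mean_genomes = mean(value),
+    sd_genomes = sd(value)
+  )
+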
+ggplot(genomes_shed_per_inf) + + aes(x = value) + + stat_halfeye() + + xlab("Individual's number of genomes shed per infection ") +``` +Hospital admissions +```{r} +hosp_state <- all_draws %>% + spread_draws(pred_hosp[t]) %>% + rename(value = pred_hosp) %>% + mutate( + draw = `.draw`, + name = "pred_hosp" + ) %>% + select(name, t, value, draw) %>% + group_by(t, name) %>% + summarise( + median_hosp = quantile(value, 0.5, na.rm = TRUE), + lb = quantile(value, 0.025, na.rm = TRUE), + ub = quantile(value, 0.975, na.rm = TRUE), + ub_50th = quantile(value, 0.75, na.rm = TRUE), + lb_50th = quantile(value, 0.25, na.rm = TRUE) + ) %>% + left_join(train_data %>% select(forecast_date, daily_hosp_admits, t, date) %>% distinct()) + + +ggplot(hosp_state) + + geom_line(aes(x = date, y = median_hosp)) + + geom_point(aes(x = date, y = daily_hosp_admits)) + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub), alpha = 0.2) + + geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), alpha = 0.2) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + ylab("State-level hospital admissions") + + xlab("") + + theme_bw() + +hosp_state_draws <- all_draws %>% + spread_draws(pred_hosp[t]) %>% + sample_draws(ndraws = 100) %>% + rename(value = pred_hosp) %>% + mutate( + draw = `.draw`, + name = "pred_hosp" + ) %>% + select(name, t, value, draw) %>% + left_join(train_data %>% select(forecast_date, daily_hosp_admits, t, date) %>% distinct()) + +ggplot(hosp_state_draws) + + geom_line(aes(x = date, y = value, group = draw), size = 0.1, alpha = 0.1) + + geom_point(aes(x = date, y = daily_hosp_admits)) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + ylab("State-level hospital admissions") + + xlab("") + + theme_bw() +``` +Site-lab concentrations + +```{r} +lab_site_map <- train_data %>% + dplyr::select(lab_wwtp_unique_id, lab_site_index, site, lab) %>% + dplyr::mutate( + lab_site = glue::glue("Site: {site} Lab: {lab}") + ) %>% + dplyr::distinct() + +conc <- all_draws %>% + spread_draws(pred_ww[lab_site_index, t]) %>% + rename(value = pred_ww) %>% + mutate( + draw = `.draw`, + name = "log_conc" + ) %>% + select(name, t, value, lab_site_index, draw) %>% + group_by(t, lab_site_index, name) %>% + summarise( + median_conc = quantile(value, 0.5, na.rm = TRUE), + lb = quantile(value, 0.025, na.rm = TRUE), + ub = quantile(value, 0.975, na.rm = TRUE), + ub_50th = quantile(value, 0.75, na.rm = TRUE), + lb_50th = quantile(value, 0.25, na.rm = TRUE) + ) %>% + left_join(lab_site_map) %>% + left_join( + train_data %>% + select(forecast_date, t, date) %>% + distinct() + ) %>% + left_join( + train_data %>% select( + ww, t, lab_site_index, ww_pop, below_LOD, lod_sewage, + flag_as_ww_outlier + ), + by = c("t", "lab_site_index") + ) + +ggplot(conc) + + geom_line(aes(x = date, y = median_conc, color = lab_site), show.legend = FALSE) + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub, fill = lab_site), + alpha = 0.2, show.legend = FALSE + ) + + geom_point(aes(x = date, y = log(ww)), show.legend = FALSE) + + geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th, fill = lab_site), + alpha = 0.2, show.legend = FALSE + ) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + facet_wrap(~lab_site) + + ylab("Site Concentration(t)") + + xlab("") + + theme_bw() + +ggplot(data = conc %>% filter(site == sites_to_display[1])) + + geom_line(aes(x = date, y = median_conc), color = "darkorange2") + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub), fill = "darkorange2", alpha = 0.2) + + 
geom_point(aes(x = date, y = log(ww))) + + geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkorange2", alpha = 0.2) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + facet_wrap(~lab_site) + + ylab("Site-level genome copies per mL") + + xlab("") + + theme_bw() + +ggplot(data = conc %>% filter(site == sites_to_display[2])) + + geom_line(aes(x = date, y = median_conc), color = "darkgreen") + + geom_ribbon(aes(x = date, ymin = lb, ymax = ub), fill = "darkgreen", alpha = 0.2) + + geom_point(aes(x = date, y = log(ww))) + + geom_ribbon(aes(x = date, ymin = lb_50th, ymax = ub_50th), fill = "darkgreen", alpha = 0.2) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + facet_wrap(~lab_site) + + ylab("Site-level genome copies per mL") + + xlab("") + + theme_bw() +``` + + + +```{r} +log_r_state <- all_draws %>% + spread_draws(log_r_mu_t_in_weeks[t]) %>% + rename(value = log_r_mu_t_in_weeks) %>% + mutate( + draw = `.draw`, + name = "log_r_state" + ) %>% + select(name, t, value, draw) + +log_r_site <- all_draws %>% + spread_draws(log_r_site_t_in_weeks[t]) %>% + rename(value = log_r_site_t_in_weeks) %>% + mutate( + draw = `.draw`, + name = "log_r_state" + ) %>% + select(name, t, value, draw) + +sigma_rt <- all_draws %>% + spread_draws(sigma_rt) %>% + sample_draws(ndraws = 100) %>% + mutate(draw = row_number()) %>% + mutate( + name = "sigma_rt", + ) %>% + rename(value = sigma_rt) %>% + select(name, value, draw) +ggplot(sigma_rt) + + aes(x = value) + + stat_halfeye() + + xlab("Standard deviation between site and state level R(t)s") + +infection_feedback <- all_draws %>% + spread_draws(infection_feedback) %>% + sample_draws(ndraws = 100) %>% + mutate(draw = row_number()) %>% + mutate( + name = "infection_feedback", + ) %>% + rename(value = infection_feedback) %>% + select(name, value, draw) +ggplot(infection_feedback) + + aes(x = value) + + stat_halfeye() + + xlab("Infection feedback") + +autoreg_rt_site <- all_draws %>% + spread_draws(autoreg_rt_site) %>% + sample_draws(ndraws = 100) %>% + mutate(draw = row_number()) %>% + mutate( + name = "autoreg_rt_site", + ) %>% + rename(value = autoreg_rt_site) %>% + select(name, value, draw) +ggplot(autoreg_rt_site) + + aes(x = value) + + stat_halfeye() + + xlab("AR term on Rt between sites and state mean") + +autoreg_rt <- all_draws %>% + spread_draws(autoreg_rt) %>% + sample_draws(ndraws = 100) %>% + mutate(draw = row_number()) %>% + mutate( + name = "autoreg_rt", + ) %>% + rename(value = autoreg_rt) %>% + select(name, value, draw) +ggplot(autoreg_rt) + + aes(x = value) + + stat_halfeye() + + xlab("AR term on Rt between previous time step R(t)") +``` + + +# Look at the generated quantitities + +```{r} +all_draws <- fit_dynamic_rt$draws() + +# Dataframes with ndraws (long format) +hosp_draws <- all_draws %>% + spread_draws(pred_hosp[t]) %>% + rename(value = pred_hosp) %>% + mutate( + draw = row_number(), + name = "pred_hosp" + ) %>% + select(name, t, value, draw) + + +site_map <- train_data %>% + select(site, site_index, ww_pop) %>% + distinct() +lab_site_map <- train_data %>% + select(lab_wwtp_unique_id, lab_site_index, site) %>% + distinct() + +ww_draws <- all_draws %>% + spread_draws(pred_ww[lab_site_index, t]) %>% + rename(value = pred_ww) %>% + group_by(lab_site_index, t) %>% + mutate( + draw = row_number(), + name = "pred_ww", + value = exp(value) + ) %>% + select(name, lab_site_index, t, value, draw) + +sampled_draws <- sample(length(unique(ww_draws$draw)), 100) + +# Gather R(t) +site_level_rt <- 
all_draws %>% + spread_draws(r_site_t[site_index, t]) %>% + rename(value = r_site_t) %>% + group_by(site_index, t) %>% + mutate( + draw = row_number(), + name = "R_site_t" + ) %>% + ungroup() %>% + select(name, site_index, t, value, draw) %>% + filter(draw %in% sampled_draws) %>% + left_join(site_map, by = "site_index") %>% + left_join( + train_data %>% + select(-ww, -site, -site_index, -ww_pop) %>% + distinct() + ) %>% + left_join( + train_data %>% + select(ww, site_index, t), + by = c("t", "site_index") + ) + +site_level_rt_summary <- site_level_rt %>% + group_by(t, site_index) %>% + summarise( + site_level_rt_median = quantile(value, 0.5), + site_level_rt_lb = quantile(value, 0.025), + site_level_rt_ub = quantile(value, 0.975) + ) %>% + left_join(site_map, by = "site_index") %>% + left_join( + train_data %>% + select(-ww, -site, -site_index, -ww_pop) %>% + distinct() + ) %>% + left_join(train_data %>% select(ww, site_index, t), by = c("t", "site_index")) + + +state_rt <- all_draws %>% + spread_draws(rt[t]) %>% + rename(value = rt) %>% + group_by(t) %>% + mutate( + draw = row_number(), + name = "rt" + ) %>% + select(name, t, value, draw) %>% + left_join(train_data %>% select(forecast_date, t, date, period) %>% distinct()) + +state_rt_summary <- state_rt %>% + group_by(t) %>% + summarise( + exp_rt_median = quantile(value, 0.5), + exp_rt_lb = quantile(value, 0.025), + exp_rt_ub = quantile(value, 0.975) + ) %>% + left_join( + train_data %>% + select(forecast_date, t, date) %>% + distinct() + ) + +state_rt <- state_rt %>% + filter(draw %in% sampled_draws) + + +ggplot(site_level_rt_summary %>% filter(date <= forecast_date)) + + geom_line(aes(x = date, y = site_level_rt_median, color = as.factor(site))) + + geom_ribbon( + aes( + x = date, + ymin = site_level_rt_lb, ymax = site_level_rt_ub, + fill = as.factor(site) + ), + alpha = 0.2 + ) + + geom_line( + data = state_rt_summary %>% filter(date <= forecast_date), + aes(x = date, y = exp_rt_median), color = "black" + ) + + geom_ribbon( + data = state_rt_summary %>% filter(date <= forecast_date), + aes( + x = date, ymin = exp_rt_lb, + ymax = exp_rt_ub + ), fill = "black", alpha = 0.1 + ) + + facet_wrap(~site) + + ylab("Estimated R(t)") + + xlab("") + + ggtitle("Estimated site level R(t)") + + theme_bw() + + +# Cross-sectional R(t) d +d <- 3 +site_level_rt_end <- site_level_rt %>% + filter(date == forecast_date - days(d)) %>% + left_join( + site_level_rt %>% + filter(date == forecast_date - days(d)) %>% + group_by(site) %>% + summarise(median_rt = quantile(value, 0.5)), + by = "site" + ) +state_rt_end <- state_rt %>% + filter(date == forecast_date - days(d)) %>% + left_join( + state_rt %>% + filter(date == forecast_date - days(d)) %>% + summarise(median_rt = quantile(value, 0.5)), + by = "t" + ) + + +ggplot() + + geom_density( + data = site_level_rt_end, + aes(x = value, fill = as.factor(site), group = as.factor(site)), alpha = 0.3 + ) + + geom_vline(data = site_level_rt_end, aes(xintercept = median_rt, color = as.factor(site))) + + geom_density( + data = state_rt_end, + aes(x = value), fill = "black", alpha = 0.3 + ) + + facet_wrap(~site) + + geom_vline(data = state_rt_end, aes(xintercept = median_rt), color = "black") + + ggtitle("R(t) estimates at forecast date by site and overall") +``` + + + + +Site-level ww genome copies per mL vs observations +```{r} +ww_draws_w_data <- ww_draws %>% + filter(draw %in% sampled_draws) %>% + left_join(lab_site_map, by = "lab_site_index") %>% + left_join(train_data %>% select( + date, location, pop, 
daily_hosp_admits_for_eval, + daily_hosp_admits, t, hosp_reporting_delay, + period, + forecast_date, day_of_week, include_ww + ) %>% + distinct()) %>% + left_join( + train_data %>% select( + ww, t, lab_site_index, ww_pop, below_LOD, lod_sewage, + flag_as_ww_outlier + ), + by = c("t", "lab_site_index") + ) + +sites_to_plot <- unique(lab_site_map$lab_wwtp_unique_id) + +ggplot(ww_draws_w_data %>% filter( + date <= forecast_date + days(7), + lab_wwtp_unique_id %in% sites_to_plot +)) + + geom_line( + aes( + x = date, y = value, group = draw, + color = as.factor(lab_wwtp_unique_id) + ), + linewidth = 0.1, alpha = 0.1, show.legend = FALSE + ) + + geom_line( + data = ww_draws_w_data %>% filter( + period != "forecast", + lab_wwtp_unique_id %in% sites_to_plot + ), + aes( + x = date, y = value, group = draw, + color = as.factor(lab_wwtp_unique_id) + ), + linewidth = 0.1, alpha = 0.3, show.legend = FALSE + ) + + geom_point( + data = ww_draws_w_data %>% filter( + period != "forecast", + lab_wwtp_unique_id %in% sites_to_plot + ), + aes(x = date, y = lod_sewage, group = draw), color = "navy" + ) + + geom_point(aes(x = date, y = ww), size = 1, shape = 21) + + geom_point( + data = ww_draws_w_data %>% filter(flag_as_ww_outlier == 1), + aes(x = date, y = ww), size = 1.5, shape = 21, color = "purple" + ) + + geom_point( + data = ww_draws_w_data %>% filter(below_LOD == 1), + aes(x = date, y = ww), size = 1.5, shape = 21, color = "red" + ) + + geom_point( + data = ww_draws_w_data %>% filter( + period != "forecast", + lab_wwtp_unique_id %in% sites_to_plot + ), + aes(x = date, y = ww), + fill = "black", size = 1, shape = 21 + ) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + facet_wrap(~lab_wwtp_unique_id, scales = "free") + + theme_bw() + + xlab("") + + ylab("Genome copies/mL") + +ggplot(ww_draws_w_data %>% filter( + date <= forecast_date + days(7), + lab_wwtp_unique_id %in% sites_to_plot +)) + + geom_line( + aes( + x = date, y = log(value), group = draw, + color = as.factor(lab_wwtp_unique_id) + ), + linewidth = 0.1, alpha = 0.1, show.legend = FALSE + ) + + geom_line( + data = ww_draws_w_data %>% filter( + period != "forecast", + lab_wwtp_unique_id %in% sites_to_plot + ), + aes( + x = date, y = log(value), group = draw, + color = as.factor(lab_wwtp_unique_id) + ), + linewidth = 0.1, alpha = 0.3, show.legend = FALSE + ) + + geom_point( + data = ww_draws_w_data %>% filter( + period != "forecast", + lab_wwtp_unique_id %in% sites_to_plot + ), + aes(x = date, y = log(lod_sewage), group = draw), color = "navy" + ) + + geom_point(aes(x = date, y = log(ww)), size = 1, shape = 21) + + geom_point( + data = ww_draws_w_data %>% filter(flag_as_ww_outlier == 1), + aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "purple" + ) + + geom_point( + data = ww_draws_w_data %>% filter(below_LOD == 1), + aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "red" + ) + + geom_point( + data = ww_draws_w_data %>% filter( + period != "forecast", + lab_wwtp_unique_id %in% sites_to_plot + ), + aes(x = date, y = log(ww)), + fill = "black", size = 1, shape = 21 + ) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + facet_wrap(~lab_wwtp_unique_id, scales = "free") + + theme_bw() + + xlab("") + + ylab("log(genome copies/mL)") + +sites_to_plot <- ww_draws_w_data %>% + ungroup() %>% + select(site) %>% + filter(!is.na(site)) %>% + unique() %>% + pull() +ggplot(ww_draws_w_data %>% filter( + date <= forecast_date + days(7), + site %in% sites_to_plot +)) + + geom_line( + aes( + x = 
date, y = log(value), group = draw, + color = as.factor(site) + ), + linewidth = 0.1, alpha = 0.1, show.legend = FALSE + ) + + geom_line( + data = ww_draws_w_data %>% filter( + period != "forecast", + site %in% sites_to_plot + ), + aes( + x = date, y = log(value), group = draw, + color = as.factor(site) + ), + linewidth = 0.1, alpha = 0.3, show.legend = FALSE + ) + + geom_point(aes(x = date, y = log(ww)), size = 1, shape = 21) + + geom_point( + data = ww_draws_w_data %>% filter(flag_as_ww_outlier == 1), + aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "purple" + ) + + geom_point( + data = ww_draws_w_data %>% filter(below_LOD == 1), + aes(x = date, y = log(ww)), size = 1.5, shape = 21, color = "red" + ) + + geom_point( + data = ww_draws_w_data %>% filter( + period != "forecast", + site %in% sites_to_plot + ), + aes(x = date, y = log(ww)), + fill = "black", size = 1, shape = 21 + ) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + facet_wrap(~site, scales = "free") + + theme_bw() + + xlab("") + + ylab("log(genome copies/mL)") + + ggtitle("Observed and estimated site-level wastewater concentrations") +``` + +```{r} +ww_site_modifier_draws <- all_draws %>% + spread_draws(ww_site_mod[site_index]) %>% + # sample_draws(ndraws = n_draws) %>% + rename(value = ww_site_mod) %>% + mutate( + draw = `.draw`, + name = "ww site multiplier" + ) %>% + select(name, site_index, value, draw) %>% + left_join(site_map) + +i0_ww_site <- all_draws %>% + spread_draws(i0_site_over_n[site_index]) %>% + rename(value = i0_site_over_n) %>% + mutate( + draw = `.draw`, + name = "i0/N ww site" + ) %>% + select(name, site_index, value, draw) %>% + left_join(site_map) + +growth_site <- all_draws %>% + spread_draws(growth_site[site_index]) %>% + rename(value = growth_site) %>% + mutate( + draw = `.draw`, + name = "initial growth rate" + ) %>% + select(name, site_index, value, draw) %>% + left_join(site_map) + +i_at_site <- all_draws %>% + spread_draws(site_i0_over_n_start[site_index]) %>% + rename(value = site_i0_over_n_start) %>% + mutate( + draw = `.draw`, + name = "site i0 over N at start" + ) %>% + select(name, site_index, value, draw) %>% + left_join(site_map) + +ggplot(ww_site_modifier_draws) + + geom_density(aes( + x = exp(value), fill = as.factor(site), + group = site + ), alpha = 0.5) + + theme_bw() + + xlab("Site level multiplier") + +ggplot(i0_ww_site) + + geom_density(aes( + x = value, fill = as.factor(site), + group = site + ), alpha = 0.5) + + theme_bw() + + xlab("log(i0/N) in each site") + +ggplot(i_at_site) + + geom_density(aes( + x = log(value), fill = as.factor(site), + group = site + ), alpha = 0.5) + + theme_bw() + + xlab("log(i0/N) in each site at start of R(t)") + +ggplot(growth_site) + + geom_density(aes( + x = value, fill = as.factor(site), + group = site + ), alpha = 0.5) + + theme_bw() + + xlab("Initial growth rate in each site") + +ww_site_sigma_draws <- all_draws %>% + spread_draws(sigma_ww_site[site_index]) %>% + # sample_draws(ndraws = n_draws) %>% + rename(value = sigma_ww_site) %>% + mutate( + draw = row_number(), + name = "ww site sigma" + ) %>% + select(name, site_index, value, draw) %>% + left_join(site_map) + +ww_site_sigma_summary <- ww_site_sigma_draws %>% + group_by(site) %>% + summarise( + sigma_median = quantile(value, 0.5), + sigma_lb = quantile(value, 0.025), + sigma_ub = quantile(value, 0.975) + ) %>% + left_join(ww_site_sigma_draws %>% select(name, site, ww_pop) %>% distinct(), + by = "site" + ) + + +ggplot(ww_site_sigma_draws) + + 
geom_density(aes( + x = value, fill = as.factor(site), + group = site + ), alpha = 0.5) + + scale_x_log10() + + theme_bw() + + xlab("Site level standard deviation") + +theme_set(theme_ggdist()) +ggplot(data = ww_site_sigma_draws, aes(y = as.factor(ww_pop), x = value)) + + stat_halfeye() + + ylab("Catchment area population") + + xlab("Phi") + + ggtitle("WW site-level stdev vs population size") + + +ggplot(ww_site_sigma_summary) + + geom_linerange(aes( + x = ww_pop, ymin = sigma_lb, + ymax = sigma_ub + )) + + geom_point(aes(x = ww_pop, y = sigma_median)) + + scale_x_log10() + + theme_bw() + + xlab("Population served by WW catchment") + + ylab("phi of WW site (~1/error)") + +hosp_draws_sampled <- hosp_draws %>% + filter(draw %in% sampled_draws) %>% + left_join(train_data, by = c("t")) + +ggplot(hosp_draws_sampled %>% filter( + date <= forecast_date + days(10) +)) + + geom_line(aes(x = date, y = value, group = draw), + linewidth = 0.1, alpha = 0.1, + color = "darkred", show.legend = FALSE + ) + + geom_line( + data = hosp_draws_sampled %>% filter( + period != "forecast" + ), + aes(x = date, y = value, group = draw), + linewidth = 0.1, alpha = 0.3, + color = "darkred", show.legend = FALSE + ) + + geom_point(aes(x = date, y = daily_hosp_admits_for_eval), + size = 1, shape = 21 + ) + + geom_point( + data = hosp_draws_sampled %>% filter( + period != "forecast" + ), + aes(x = date, y = daily_hosp_admits_for_eval), fill = "black", + size = 1, shape = 21 + ) + + geom_vline(aes(xintercept = forecast_date), linetype = "dashed") + + theme_bw() + + xlab("") + + ylab("Hospital admissions") +``` +
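+
+As a quick numerical companion to the plots above, here is a minimal sketch (using
+the `ww_site_sigma_summary` table constructed above) of the rank correlation
+between the posterior median site-level sigma and catchment population, which we
+expect to be weak under the individual-shedding formulation:
+```{r}
+# Spearman correlation between posterior median sigma_ww_site and the
+# population served by each wastewater catchment
+cor(
+  ww_site_sigma_summary$sigma_median,
+  ww_site_sigma_summary$ww_pop,
+  method = "spearman",
+  use = "complete.obs"
+)
+```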
+We no longer expect the sigma inherent to the WW site to be correlated with population size,
+because in theory this is taken care of by the sum of negative binomials in the
+site-level expected true genomes (a sum over the individual-level dispersion of
+those infected at that time point). The site-level phi is therefore just additional
+variability introduced by the site, independent of population shedding.
diff --git a/scratch/toy_example_ind_variation.Rmd b/scratch/toy_example_ind_variation.Rmd
new file mode 100644
index 00000000..051fc3a1
--- /dev/null
+++ b/scratch/toy_example_ind_variation.Rmd
@@ -0,0 +1,683 @@
+---
+title: "Individual variation toy examples"
+author: "Kaitlyn Johnson"
+date: "2024-05-13"
+output: html_document
+---
+
+The purpose of this document is to set up a toy example of a model that tries
+to estimate the mean and coefficient of variation of an individual gamma-distributed
+random variable, using observations of the sum of a known number of such individual
+components. This relies on the following:
+$$
+Y = \sum_{i=1}^{N} X_i, \quad X_i \overset{iid}{\sim} \mathrm{Gamma}\left(\alpha = \frac{1}{cv^2},\ \beta = \frac{1}{\mu \, cv^2}\right) \implies Y \sim \mathrm{Gamma}\left(\alpha = \frac{N}{cv^2},\ \beta = \frac{1}{\mu \, cv^2}\right)
+$$
+We will first implement this on the linear scale, then on the log scale, and then
+using a non-centered normal approximation adapted from
+[EpiSewer](https://github.com/adrian-lison/EpiSewer/blob/main/README.md).
+
+```{r setup, include=FALSE}
+knitr::opts_chunk$set(echo = TRUE)
+library(rstan)
+library(ggplot2)
+library(tidybayes)
+library(bayesplot)
+```
+
+
+Write a function to generate the data with a specified individual mean, individual
+coefficient of variation, the population sizes of the individual components we're
+summing over, and the number of draws from each population size. We'll start
+with a simple example where individual components have a mean of 1 and coefficient
+of variation of 2, and we have observations of the sum of their components from
+population sizes of 5, 20, and 100.
+```{r}
+get_data <- function(mu = 1,
+                     cv = 2,
+                     pop = c(5, 20, 100),
+                     n_draws = 10,
+                     prior_mu = c(0, 1),
+                     prior_cv = c(0, 1),
+                     sigma_obs = 0,
+                     prior_sigma_obs = c(0, 1)) {
+  pop_vector <- rep(pop, each = n_draws)
+  y <- rep(0, n_draws * length(pop))
+  for (i in seq_along(pop_vector)) {
+    alpha <- pop_vector[i] / cv^2
+    beta <- 1 / (mu * (cv^2))
+    y[i] <- rgamma(1, alpha, beta)
+    # same as y[i] <- sum(rgamma(pop_vector[i], 1 / cv^2, 1 / (mu * (cv^2))))
+    if (sigma_obs > 0) {
+      y[i] <- rlnorm(1, log(y[i]), sigma_obs)
+    }
+  }
+
+  data <- list(
+    N = length(pop_vector),
+    y = y,
+    pop_vector = pop_vector,
+    prior_mu = prior_mu,
+    prior_cv = prior_cv,
+    prior_sigma_obs = prior_sigma_obs
+  )
+
+  return(data)
+}
+
+standata <- get_data()
+```
+# Linear scale implementation
+Next we will write a linear-scale Stan model to fit these data. We will use
+`rstan` for this example because it interfaces well with Rmd and gives us the
+diagnostic outputs in the format we want.
+```{stan, output.var = "model_linear"}
+
+data {
+  int N;
+  vector[N] y;
+  vector[N] pop_vector;
+  real prior_mu[2];
+  real prior_cv[2];
+}
+
+// The parameters accepted by the model. Our model
+// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation
+// in the individual R.V.s
+parameters {
+  real mu; // mean of individual component
+  real cv; // coefficient of variation in individual data
+}
+
+
+// The model to be estimated.
+
+model {
+  // Assume we have a truncated half-normal prior with a standard deviation of 2 times the mean
+  // not sure if that's reasonable...
+  mu ~ normal(prior_mu[1], prior_mu[2]);
+  cv ~ normal(prior_cv[1], prior_cv[2]);
+
+  // Formula for the sum of N gammas: Y ~ gamma(N*alpha, beta)
+  y ~ gamma(pop_vector ./ cv^2, 1 / (mu * (cv^2)));
+
+}
+
+```
+Fit the model using `rstan`.
+```{r}
+fit <- sampling(model_linear,
+  standata,
+  warmup = 500,
+  iter = 2000,
+  chains = 4,
+  cores = 4,
+  seed = 42,
+  init = 0,
+  control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+Analyze the outputs using `bayesplot` and `tidybayes`.
+In this example we will make pairs plots, trace plots, and also plots
+that compare the posterior estimates to the known true parameters.
+```{r}
+# Extract posterior draws for later use
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+  np = np, pars = c("mu", "cv"),
+  off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv"), np = np) +
+  xlab("Post-warmup iteration")
+
+
+params <- fit |>
+  tidybayes::spread_draws(mu, cv) |>
+  dplyr::mutate(
+    mu_true = 1,
+    cv_true = 2
+  )
+
+ggplot(params, aes(x = mu)) +
+  stat_halfeye() +
+  geom_vline(aes(xintercept = mu_true)) +
+  theme_bw() +
+  xlab("Mean") +
+  ylab("")
+
+ggplot(params, aes(x = cv)) +
+  stat_halfeye() +
+  geom_vline(aes(xintercept = cv_true)) +
+  theme_bw() +
+  xlab("Coefficient of Variation") +
+  ylab("")
+```
+
+
+The estimate of the `cv` is a little off, but overall the convergence diagnostics
+look pretty good. Let's try to get closer to the wastewater problem by making the
+orders of magnitude of `mu` (the number of genomes shed) and `pop` (the number of
+infected individuals in the population) more realistic.
+```{r}
+cv_set <- 2
+mu_set <- exp(7)
+standata <- get_data(
+  mu = mu_set,
+  cv = cv_set,
+  prior_mu = c(1.2 * mu_set, 2 * mu_set), # make priors imperfect but on the same scale
+  prior_cv = c(0.9 * cv_set, 1),
+  n_draws = 10,
+  pop = c(10, 1000, 1e5)
+)
+df <- data.frame(y = standata$y, pop = rep(standata$pop_vector))
+# Quick plot of the data
+ggplot(df) +
+  geom_point(aes(x = pop, y = y)) +
+  scale_x_continuous(trans = "log") +
+  scale_y_continuous(trans = "log")
+
+
+fit <- sampling(model_linear,
+  standata,
+  warmup = 500,
+  iter = 2000,
+  chains = 4,
+  cores = 4,
+  seed = 42,
+  init = 0,
+  control = list(adapt_delta = 0.99, max_treedepth = 10)
+)
+```
+
+Look at the outputs from the fit of the model.
+```{r}
+# Extract posterior draws for later use
+posterior <- as.array(fit)
+np <- nuts_params(fit)
+
+mcmc_pairs(posterior,
+  np = np, pars = c("mu", "cv"),
+  off_diag_args = list(size = 0.75)
+)
+color_scheme_set("mix-brightblue-gray")
+mcmc_trace(posterior, pars = c("mu", "cv"), np = np) +
+  xlab("Post-warmup iteration")
+
+
+params <- fit |>
+  tidybayes::spread_draws(mu, cv) |>
+  dplyr::mutate(
+    mu_true = mu_set,
+    cv_true = cv_set
+  )
+
+ggplot(params, aes(x = mu)) +
+  stat_halfeye() +
+  geom_vline(aes(xintercept = mu_true)) +
+  theme_bw() +
+  xlab("Mean") +
+  ylab("")
+
+ggplot(params, aes(x = cv)) +
+  stat_halfeye() +
+  geom_vline(aes(xintercept = cv_true)) +
+  theme_bw() +
+  xlab("Coefficient of Variation") +
+  ylab("")
+```
+
+These look pretty good!
+Because of the risk of overflow, though, we probably still want to
+estimate things on the log scale: in practice this quantity will be
+the total number of genomes shed in the sewershed, which could be
+quite large.
+
+# Log scale implementation using `expgamma` function
+Let's see if we can improve the estimates by log-transforming the data and
+rescaling, estimating the mean and the relative variation in shedding.
+
+```{r}
+# First write a function to get the data in log scale
+get_log_scale_data <- function(log_mu = 1,
+                               cv = 2,
+                               pop = c(5, 20, 100),
+                               n_draws = 10,
+                               log_prior_mu = c(1, 2),
+                               prior_cv = c(0, 1)) {
+  pop_vector <- rep(pop, each = n_draws)
+  log_y <- rep(0, n_draws * length(pop))
+  mu <- exp(log_mu)
+  for (i in seq_along(pop_vector)) {
+    alpha <- pop_vector[i] / cv^2
+    beta <- 1 / (mu * (cv^2))
+    log_y[i] <- log(rgamma(1, alpha, beta))
+  }
+
+  data <- list(
+    N = length(pop_vector),
+    log_y = log_y,
+    pop_vector = pop_vector,
+    log_prior_mu = log_prior_mu,
+    prior_cv = prior_cv
+  )
+
+  return(data)
+}
+```
+Now generate the log-scaled data.
+```{r}
+mu_set <- exp(7)
+cv_set <- 2
+standata <- get_log_scale_data(
+  log_mu = log(mu_set),
+  cv = cv_set,
+  log_prior_mu = c(0.5 * log(mu_set), 2 * log(mu_set)),
+  prior_cv = c(0, cv_set),
+  pop = c(10, 1000, 1e5)
+)
+df <- data.frame(log_y = standata$log_y, pop = rep(standata$pop_vector))
+# Quick plot of the data
+ggplot(df) +
+  geom_point(aes(x = pop, y = log_y)) +
+  scale_x_continuous(trans = "log")
+```
+
+Modify the Stan model to fit the log data and to estimate separately a relative
+shedding variation and the mean on the log scale.
+```{stan, output.var = "log_model"}
+functions{
+real expgamma_lpdf(vector xs,
+                   vector shapes_k,
+                   vector scales_theta){
+
+  return(sum(
+    -shapes_k .* log(scales_theta) -
+    lgamma(shapes_k) +
+    shapes_k .* xs - (exp(xs) ./ scales_theta)));
+}
+real expgamma_lpdf(vector x, real shape_k, real scale_theta){
+
+  return(sum(
+    -shape_k * log(scale_theta) -
+    lgamma(shape_k) +
+    shape_k * x - (exp(x) / scale_theta)));
+}
+
+real expgamma_lpdf(real x, real shape_k, real scale_theta){
+
+  return(
+    -shape_k * log(scale_theta) -
+    lgamma(shape_k) +
+    shape_k * x - (exp(x) / scale_theta));
+}
+
+}
+
+
+data {
+  int N;
+  vector[N] log_y;
+  vector[N] pop_vector;
+  real log_prior_mu[2];
+  real prior_cv[2];
+}
+
+
+// The parameters accepted by the model. Our model
+// accepts two parameters 'log_mu' and 'cv', the log mean and coefficient of
+// variation in the individual R.V.s
+parameters {
+  real log_mu; // log of the mean of the individual component
+  real cv; // coefficient of variation in individual data
+}
+
+transformed parameters{
+  real mu = exp(log_mu); // We can transform the mean before passing to the gamma
+}
+
+
+// The model to be estimated.
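+// Note on the likelihood used below: expgamma_lpdf (defined in the functions
+// block) is the log density of X = log(Y) for Y ~ Gamma(shape = k, scale = theta),
+// obtained by the change of variables y = exp(x):
+//   log f_X(x) = -k * log(theta) - lgamma(k) + k * x - exp(x) / theta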
+model { + log_mu ~ normal(log_prior_mu[1], log_prior_mu[2]); // lognormal prior + cv ~ normal(prior_cv[1], prior_cv[2]); + + //Formula for the sum of N gammas: Y ~ gamma(N*alpha, beta) + for(i in 1:N){ + log_y[i] ~ expgamma(pop_vector[i]./cv^2, mu * (cv^2)); +} + + +} + +``` + +```{r} +fit <- sampling(log_model, + standata, + warmup = 500, + iter = 2000, + chains = 4, + cores = 4, + seed = 42, + init = 0, + control = list(adapt_delta = 0.99, max_treedepth = 10) +) +``` +Look at the parameter estimates +```{r} +posterior <- as.array(fit) +np <- nuts_params(fit) + +mcmc_pairs(posterior, + np = np, pars = c("mu", "cv"), + off_diag_args = list(size = 0.75) +) +color_scheme_set("mix-brightblue-gray") +mcmc_trace(posterior, pars = c("mu", "cv"), np = np) + + xlab("Post-warmup iteration") + + +params <- fit |> + tidybayes::spread_draws(mu, cv) |> + dplyr::mutate( + mu_true = mu_set, + cv_true = cv_set + ) + +ggplot(params, aes(x = mu)) + + stat_halfeye() + + geom_vline(aes(xintercept = mu_true)) + + theme_bw() + + xlab("Mean") + + ylab("") + +ggplot(params, aes(x = cv)) + + stat_halfeye() + + geom_vline(aes(xintercept = cv_true)) + + theme_bw() + + xlab("Coefficient of Variation") + + ylab("") +``` + +These look ok but not great. + +# Non-centered parameterization +This relies on an approximation to the sum of iid gammas as implemented in +EpiSewer. + +The non-centered parameterization re-expresses the model and likelihood as: +$$ +\begin{aligned} +\zeta \sim Normal(0,1) \\ + +Y = N\mu + \zeta \sqrt{N}cv +\end{aligned} +$$ +Where $y$ are the observations of the summed distributions over the population, +$N$ is the population size (so analogous to number of infections in the +wastewater model), $\mu$ is the individual mean and $cv$ is the individual +coefficient of variation. + +With observation noise, we can write this as: +$$ +\begin{aligned} + \zeta \sim Normal(0,1) \\ +\overline{y} = N\mu + \zeta \sqrt{N}cv \\ +Y \sim Normal(\overline{y}, \sigma) +\end{aligned} +$$ +Where $\overline{y}$ is the expected value of the observations and $\sigma$ is the observation model noise. +We'll start with the first implementation without observation noise. To compute the log likelihood, +we have to account for a [change of variables](https://mc-stan.org/docs/stan-users-guide/reparameterization.html#changes-of-variables) to solve for $\zeta$ + +```{stan, output.var = "ncp_model_no_noise"} + +functions { + +// Non-centered paramaterization of a normal approximation for the +// sum of N i.i.d. Gamma distributed RVs with mean 1 and a specified cv +vector gamma_sum_approx(real cv, vector N, vector noise_noncentered) { + // sqrt(N) * cv is the standard deviation of the sum of Gamma distributions + return N + noise_noncentered .* sqrt(N) * cv; +} + + +} + +data { + int N; + vector[N] y; + vector[N] pop_vector; + real prior_mu[2]; + real prior_cv[2]; +} + +// The parameters accepted by the model. Our model +// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation +// in the individual R.V.s +parameters { + real mu; // mean of individual component + real cv; // coefficent of variation in individual data +} + + +// The model to be estimated. +model { + // Assume we have a truncated half N prior with a standard deviation of 2 times the mean + // not sure if thats reasonable... 
+ mu ~ normal(prior_mu[1], prior_mu[2]); + cv ~ normal( prior_cv[1], prior_cv[2]); + (y- mu * pop_vector)./(mu * sqrt(pop_vector) * cv) ~ std_normal(); + target+= -sum(log(mu * sqrt(pop_vector) * cv)); + +} + +``` + +Diagnostic warning from PARSER can be ignored (we did indeed adjust for the change of variables) + +```{r} +cv_set <- 0.1 +mu_set <- 5 +standata <- get_data( + mu = mu_set, + cv = cv_set, + prior_mu = c(0.5 * mu_set, 2 * mu_set), # make priors imperfect but on the same scale + prior_cv = c(0, 2), + n_draws = 10, + pop = c(10, 1000, 1e5) +) +df <- data.frame(y = standata$y, pop = rep(standata$pop_vector)) +df <- data.frame(y = standata$y, pop = rep(standata$pop_vector)) +# Quick plot of the data +ggplot(df) + + geom_point(aes(x = pop, y = y)) + + scale_x_continuous(trans = "log") + + scale_y_continuous(trans = "log") + + +fit <- sampling(ncp_model_no_noise, + standata, + warmup = 500, + iter = 2000, + chains = 4, + cores = 4, + seed = 42, + init = 0, + control = list(adapt_delta = 0.99, max_treedepth = 10) +) +``` + +```{r} +posterior <- as.array(fit) +np <- nuts_params(fit) + +mcmc_pairs(posterior, + np = np, pars = c("mu", "cv"), + off_diag_args = list(size = 0.75) +) +color_scheme_set("mix-brightblue-gray") +mcmc_trace(posterior, pars = c("mu", "cv"), np = np) + + xlab("Post-warmup iteration") + + +params <- fit |> + tidybayes::spread_draws(mu, cv) |> + dplyr::mutate( + mu_true = mu_set, + cv_true = cv_set + ) + +ggplot(params, aes(x = mu)) + + stat_halfeye() + + geom_vline(aes(xintercept = mu_true)) + + theme_bw() + + xlab("Mean") + + ylab("") + +ggplot(params, aes(x = cv)) + + stat_halfeye() + + geom_vline(aes(xintercept = cv_true)) + + theme_bw() + + xlab("Coefficient of Variation") + + ylab("") +``` + +This looks good! We definitely needed more iterations to get sufficient ESS. + + +## With observation noise +Now let's add observation noise (lognormal). We can here use the normal transformed parameters approach again for the non-centered parameterization. + +```{stan, output.var = "ncp_model_noise"} + +functions { + +// Non-centered paramaterization of a normal approximation for the +// sum of N i.i.d. Gamma distributed RVs with mean 1 and a specified cv +vector gamma_sum_approx(real cv, vector N, vector noise_noncentered) { + // sqrt(N) * cv is the standard deviation of the sum of Gamma distributions + return N + noise_noncentered .* sqrt(N) * cv; +} + + +} + +data { + int N; + vector[N] y; + vector[N] pop_vector; + real prior_mu[2]; + real prior_cv[2]; + real prior_sigma_obs[2]; +} + +transformed data{ + vector[N] log_y = log(y); +} + +// The parameters accepted by the model. Our model +// accepts two parameters 'mu' and 'cv', the mean and coefficient of variation +// in the individual R.V.s +parameters { + real mu; // mean of individual component + real cv; // coefficent of variation in individual data + vector[N] zeta_raw; + real sigma_obs; // variance of the likelihood +} + +transformed parameters{ + vector[N] exp_obs = mu * gamma_sum_approx(cv, pop_vector, zeta_raw); +} + +// The model to be estimated. 
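+// Note: unlike the no-noise model above, no Jacobian adjustment is needed here.
+// zeta_raw is a parameter with a standard normal prior sampled directly, and
+// exp_obs is a deterministic transform of it that enters the likelihood only as
+// the mean of the lognormal observation model.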
+model { + mu ~ normal(prior_mu[1], prior_mu[2]); + cv ~ normal(prior_cv[1], prior_cv[2]); + zeta_raw ~ normal(0,1); + sigma_obs ~ normal(prior_sigma_obs[1], prior_sigma_obs[2]); + + //Likelihood: mean = mu*N + sqrt(N)*cv *N(0.1), sd = sigma_obs + //y ~ lognormal(log(exp_obs), sigma_obs); // alternate way of writing it + log_y ~ normal(log(exp_obs), sigma_obs); + + +} +``` + +```{r} +mu_set <- exp(7) +cv_set <- 0.6 +sigma_obs_set <- 0.4 +standata <- get_data( + mu = mu_set, + pop = c(1e1, 1e3, 1e6), + n_draws = 30, # make sure sufficient pop size + cv = cv_set, + sigma_obs = sigma_obs_set, + prior_mu = c(0.5 * mu_set, 2 * mu_set), + prior_cv = c(0.9 * cv_set, 1), + prior_sigma_obs = c(0, 2) +) +df <- data.frame(y = standata$y, pop = rep(standata$pop_vector)) +# Quick plot of the data +ggplot(df) + + geom_point(aes(x = pop, y = y)) + + scale_x_continuous(trans = "log") + + scale_y_continuous(trans = "log") + + +fit <- sampling(ncp_model_noise, + standata, + warmup = 500, + iter = 2000, + chains = 4, + cores = 4, + seed = 42, + init = 0, + control = list(adapt_delta = 0.99, max_treedepth = 10) +) +``` +```{r} +posterior <- as.array(fit) +np <- nuts_params(fit) + +mcmc_pairs(posterior, + np = np, pars = c("mu", "cv", "sigma_obs"), + off_diag_args = list(size = 0.75) +) +color_scheme_set("mix-brightblue-gray") +mcmc_trace(posterior, pars = c("mu", "cv", "sigma_obs"), np = np) + + xlab("Post-warmup iteration") + + +params <- fit |> + tidybayes::spread_draws(mu, cv, sigma_obs) |> + dplyr::mutate( + mu_true = mu_set, + cv_true = cv_set, + sigma_obs_true = sigma_obs_set + ) + +ggplot(params, aes(x = mu)) + + stat_halfeye() + + geom_vline(aes(xintercept = mu_true)) + + theme_bw() + + xlab("Mean") + + ylab("") + +ggplot(params, aes(x = cv)) + + stat_halfeye() + + geom_vline(aes(xintercept = cv_true)) + + theme_bw() + + xlab("Coefficient of Variation") + + ylab("") + +ggplot(params, aes(x = sigma_obs)) + + stat_halfeye() + + geom_vline(aes(xintercept = sigma_obs_true)) + + theme_bw() + + xlab("Observation noise") + + ylab("") +``` + +This is also looking good.
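+
+As a final check, a minimal sketch summarising convergence and parameter recovery
+directly from the `rstan` fit of this last model:
+```{r}
+# Rhat, effective sample size, and posterior quantiles for the parameters of
+# interest, alongside the simulated true values
+fit_summary <- summary(fit, pars = c("mu", "cv", "sigma_obs"))$summary
+round(fit_summary[, c("50%", "2.5%", "97.5%", "n_eff", "Rhat")], 3)
+c(mu_true = mu_set, cv_true = cv_set, sigma_obs_true = sigma_obs_set)
+```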