From 1dbb3722428c35357eab20877fc224bdda0aa097 Mon Sep 17 00:00:00 2001 From: cbernalz Date: Fri, 23 Aug 2024 23:51:33 -0700 Subject: [PATCH] 2024-08-23 update : added all result visuals & updated docs. --- docs/make.jl | 3 +- docs/src/tutorial_index.md | 2 + .../uciwwiehr_model_fitting_forecast.md | 110 ++++++++++++++++++ ...=> uciwwiehr_model_fitting_no_forecast.md} | 26 ++++- src/UCIWWEIHR.jl | 9 +- src/generate_simulation_data_uciwweihr.jl | 38 +++++- src/helper_functions.jl | 24 +++- src/mcmcdiags_vis.jl | 40 +++++-- src/non_time_varying_param_vis.jl | 71 +++++++++++ src/predictive_param_vis.jl | 74 ++++++++++++ src/time_varying_param_vis.jl | 7 +- src/uciwweihr_gq_pp.jl | 2 +- src/uciwweihr_model.jl | 12 +- src/uciwweihr_visualizer.jl | 64 +++++++--- 14 files changed, 435 insertions(+), 47 deletions(-) create mode 100644 docs/src/tutorials/uciwwiehr_model_fitting_forecast.md rename docs/src/tutorials/{uciwweihr_model_fitting.md => uciwwiehr_model_fitting_no_forecast.md} (83%) create mode 100644 src/non_time_varying_param_vis.jl create mode 100644 src/predictive_param_vis.jl diff --git a/docs/make.jl b/docs/make.jl index 6bebd39..c672f6a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,7 +22,8 @@ makedocs(; "GETTING STARTED" => "tutorials/getting_started.md", "UCIWWEIHR SIMULATION DATA" => "tutorials/uciwweihr_simulation_data.md", "AGENT-BASED SIMULATION DATA" => "tutorials/agent_based_simulation_data.md", - "UCIWWEIHR FITTING MODEL" => "tutorials/uciwweihr_model_fitting.md", + "UCIWWEIHR FITTING MODEL W/OUT FORECASTING" => "tutorials/uciwwiehr_model_fitting_no_forecast.md", + "UCIWWEIHR FITTING MODEL W/ FORECASTING" => "tutorials/uciwwiehr_model_fitting_forecast.md", ] , "NEWS" => "news.md", diff --git a/docs/src/tutorial_index.md b/docs/src/tutorial_index.md index 79df2c9..355a04a 100644 --- a/docs/src/tutorial_index.md +++ b/docs/src/tutorial_index.md @@ -7,5 +7,7 @@ Future Description. 
- [Getting Started](@ref getting_started) - [Generating simulated data with UCIWWEIHR ODE compartmental based model.](@ref uciwweihr_simulation_data) - [Generating simulated data with an agent based model.](@ref agent_based_simulation_data) +- [Generating posterior distribution samples with UCIWWEIHR ODE compartmental based model without forecasting.](@ref uciwwiehr_model_fitting_no_forecast) +- [Generating posterior distribution samples with UCIWWEIHR ODE compartmental based model with forecasting.](@ref uciwwiehr_model_fitting_with_forecast) diff --git a/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md b/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md new file mode 100644 index 0000000..1a76cf9 --- /dev/null +++ b/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md @@ -0,0 +1,110 @@ +```@setup tutorial +using Plots, StatsPlots; gr() +Plots.reset_defaults() + +``` + +# [Generating Posterior Distribution Samples with UCIWWEIHR ODE Compartmental Based Model with Forecasting.](@id uciwwiehr_model_fitting_with_forecast) + +Here we extend the [previous tutorial](@ref uciwwiehr_model_fitting_no_forecast) to include forecasting capabilities. We start with generating our data using `generate_simulation_data_uciwweihr`'s alternate parameterization where we do not prespecify the effective reproduction number and hospitalization probability but instead perform a log-normal random walk and a logit-normal random walk respectively. We then sample from the posterior distribution using the `uciwweihr_fit.jl` function. We then generate desired quantities and forecast for a given time period with the posterior predictive distribution, using `uciwweihr_gq_pp.jl`. + + +## 1. Data Generation. + +Here we generate two datasets, one with 150 time points and one with 178 time points. We will use the 150 time point dataset for fitting and the 178 time point dataset for forecast evaluation. 
+ +``` @example tutorial +using UCIWWEIHR +# Running simulation function with presets +params = create_uciwweihr_params( + time_points = 150 +) +df = generate_simulation_data_uciwweihr(params) + +params_ext = create_uciwweihr_params( + time_points = 178 +) +df_ext = generate_simulation_data_uciwweihr(params_ext) +first(df, 5) +first(df_ext, 5) +``` + +## 2. Sampling from the Posterior Distribution and Posterior Predictive Distribution. + +Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we set up some presets, then have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model. The difference here is that we set `forecast = true` and `forecast_weeks = 4` to forecast 4 weeks into the future. + +``` @example tutorial +data_hosp = df.hosp +data_wastewater = df.log_ww_conc +obstimes = df.obstimes +param_change_times = 1:7:length(obstimes) # Change every week +priors_only = false +n_samples = 200 +forecast = true +forecast_weeks = 4 + +samples = uciwweihr_fit( + data_hosp, + data_wastewater, + obstimes, + param_change_times, + priors_only, + n_samples +) +model_output = uciwweihr_gq_pp( + samples = samples, + data_hosp = data_hosp, + data_wastewater = data_wastewater, + obstimes = obstimes, + param_change_times = param_change_times, + forecast = forecast, + forecast_weeks = forecast_weeks +) + +first(model_output[1][:,1:5], 5) +``` + +``` @example tutorial +first(model_output[2][:,1:5], 5) +``` + +``` @example tutorial +first(model_output[3][:,1:5], 5) +``` + +## 3. MCMC Diagnostic Plots/Results Along with Posterior Predictive Distribution. + +We can again look at model diagnostics, posterior distribution of time or non-time varying parameters, and the posterior predictive distribution extended for forecasting. 
+ +```@example tutorial +uciwweihr_visualizer( + pp_samples = model_output[1], + gq_samples = model_output[2], + data_hosp = df_ext.hosp, + data_wastewater = df_ext.log_ww_conc, + actual_rt_vals = df_ext.rt, + actual_w_t = df_ext.wt, + actual_non_time_varying_vals = params, + forecast_weeks = forecast_weeks, + bayes_dist_type = "Posterior", + save_plots = false +) +``` + +### 3.1. MCMC Diagnostic Plots. + +![Plot 1](plots/mcmc_diagnosis_plots.png) + +### 3.2. Time Varying Parameter Results Plot. + +![Plot 2](plots/mcmc_time_varying_parameter_plots.png) + +### 3.3. Non-Time Varying Parameter Results Plot. +![Plot 3](plots/mcmc_nontime_varying_parameter_plots.png) + +### 3.4. Posterior Predictive Distribution Plot. + +![Plot 4](plots/mcmc_pred_parameter_plots.png) + + +### [Tutorial Contents](@ref tutorial_home) \ No newline at end of file diff --git a/docs/src/tutorials/uciwweihr_model_fitting.md b/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md similarity index 83% rename from docs/src/tutorials/uciwweihr_model_fitting.md rename to docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md index de6061d..1c8f964 100644 --- a/docs/src/tutorials/uciwweihr_model_fitting.md +++ b/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md @@ -4,7 +4,7 @@ Plots.reset_defaults() ``` -# [Generating Posterior Distribution Samples with UCIWWEIHR ODE compartmental based model.](@id uciwwiehr_model_fitting) +# [Generating Posterior Distribution Samples with UCIWWEIHR ODE Compartmental Based Model without Forecasting.](@id uciwwiehr_model_fitting_no_forecast) This package has a way to sample from a posterior or prior that is defined in the future paper using the `uciwweihr_fit.jl` and `uciwweihr_model.jl`. We can then generate desired quantities and forecast for a given time period with the posterior predictive distribution, using `uciwweihr_gq_pp.jl`. 
We first generate data using the `generate_simulation_data_uciwweihr` function which is a non-mispecified version of the model, we will also be using prespecified effective reporduction curves and prespecified hospitalization probability curves. @@ -49,7 +49,7 @@ data_wastewater = df.log_ww_conc obstimes = df.obstimes param_change_times = 1:7:length(obstimes) # Change every week priors_only = false -n_samples = 50 +n_samples = 200 samples = uciwweihr_fit( data_hosp, @@ -83,10 +83,17 @@ first(model_output[3][:,1:5], 5) We also provide a very basic way to visualize some MCMC diagnostics along with effective sample sizes of desired generated quantities(does not include functionality for time-varying quantities). Along with this, we can also visualize the posterior predictive distribution with actual observed values, which can be used to examine forecasts generated by the model. ```@example tutorial -uciwweihr_visualizer(gq_samples = model_output[2], - actual_rt_vals = df.rt, - actual_w_t = df.wt, - save_plots = true) +uciwweihr_visualizer( + pp_samples = model_output[1], + gq_samples = model_output[2], + data_hosp = data_hosp, + data_wastewater = data_wastewater, + actual_rt_vals = df.rt, + actual_w_t = df.wt, + actual_non_time_varying_vals = params, + bayes_dist_type = "Posterior", + save_plots = true +) ``` ### 3.1. MCMC Diagnostic Plots. @@ -97,5 +104,12 @@ uciwweihr_visualizer(gq_samples = model_output[2], ![Plot 2](plots/mcmc_time_varying_parameter_plots.png) +### 3.3. Non-Time Varying Parameter Results Plot. +![Plot 3](plots/mcmc_nontime_varying_parameter_plots.png) + +### 3.4. Posterior Predictive Distribution Plot. 
+ +![Plot 4](plots/mcmc_pred_parameter_plots.png) + ### [Tutorial Contents](@ref tutorial_home) \ No newline at end of file diff --git a/src/UCIWWEIHR.jl b/src/UCIWWEIHR.jl index a43708b..b36933c 100644 --- a/src/UCIWWEIHR.jl +++ b/src/UCIWWEIHR.jl @@ -30,15 +30,18 @@ include("generalizedtdist.jl") include("uciwweihr_model.jl") include("uciwweihr_fit.jl") include("uciwweihr_gq_pp.jl") -include("uciwweihr_visualizer.jl") include("helper_functions.jl") include("mcmcdiags_vis.jl") include("time_varying_param_vis.jl") +include("non_time_varying_param_vis.jl") +include("predictive_param_vis.jl") +include("uciwweihr_visualizer.jl") export eihr_ode export uciwweihr_sim_params export create_uciwweihr_params export generate_random_walk +export generate_logit_normal_random_walk export generate_simulation_data_uciwweihr export generate_simulation_data_agent export NegativeBinomial2 @@ -50,8 +53,12 @@ export uciwweihr_visualizer export ChainsCustomIndexs export save_plots_to_docs export startswith_any +export generate_colors export calculate_quantiles +export repeat_last_n_elements export mcmcdiags_vis export time_varying_param_vis +export non_time_varying_param_vis +export predictive_param_vis end \ No newline at end of file diff --git a/src/generate_simulation_data_uciwweihr.jl b/src/generate_simulation_data_uciwweihr.jl index 3a601a9..ff821d3 100644 --- a/src/generate_simulation_data_uciwweihr.jl +++ b/src/generate_simulation_data_uciwweihr.jl @@ -1,6 +1,7 @@ """ uciwweihr_sim_params -Struct for holding parameters used in the UCIWWEIHR ODE compartmental model simulation. + +Struct for holding parameters used in the UCIWWEIHR ODE compartmental model simulation. Use `create_uciwweihr_params` to create an instance of this struct. # Fields - `time_points::Int64`: Number of time points for the simulation. 
@@ -19,6 +20,8 @@ Struct for holding parameters used in the UCIWWEIHR ODE compartmental model simu - `sigma_Rt::Float64`: Standard deviation for random walk of time-varying reproduction number. - `w::Union{Float64, Vector{Float64}}`: Initial value or time series of the time-varying hospitalization rate. - `sigma_w::Float64`: Standard deviation for random walk of time-varying hospitalization rate. +- `rt_init::Float64`: Initial value of the time-varying reproduction number, NOT USER SPECIFIED `create_uciwweihr_params` TAKES CARE OF THIS. +- `w_init::Float64`: Initial value of the time-varying hospitalization rate, NOT USER SPECIFIED `create_uciwweihr_params` TAKES CARE OF THIS. """ struct uciwweihr_sim_params time_points::Int64 @@ -37,10 +40,13 @@ struct uciwweihr_sim_params sigma_Rt::Float64 w::Union{Float64, Vector{Float64}} sigma_w::Float64 + rt_init::Float64 + w_init::Float64 end """ create_uciwweihr_params(; kwargs...) + Creates a `uciwweihr_sim_params` struct with the option to either use a predetermined `Rt` and `w` or generate them as random walks. # Arguments @@ -60,16 +66,19 @@ function create_uciwweihr_params(; time_points::Int64=150, seed::Int64=1, sigma_w::Float64=sqrt(0.001)) Random.seed!(seed) + rt_init = isa(Rt, Float64) ? Rt : Rt[1] + w_init = isa(w, Float64) ? w : w[1] Rt_t = isa(Rt, Float64) ? generate_random_walk(time_points, sigma_Rt, Rt) : Rt - w_t = isa(w, Float64) ? generate_random_walk(time_points, sigma_w, w) : w + w_t = isa(w, Float64) ? generate_logit_normal_random_walk(time_points, sigma_w, w) : w return uciwweihr_sim_params(time_points, seed, E_init, I_init, H_init, gamma, nu, epsilon, - rho_gene, tau, df, sigma_hosp, Rt_t, sigma_Rt, w_t, sigma_w) + rho_gene, tau, df, sigma_hosp, Rt_t, sigma_Rt, w_t, sigma_w, rt_init, w_init) end """ generate_random_walk(time_points::Int64, sigma::Float64, init_val::Float64) + Generates a random walk time series. 
# Arguments @@ -90,6 +99,29 @@ function generate_random_walk(time_points::Int64, sigma::Float64, init_val::Floa return walk end +""" + generate_logit_normal_random_walk(time_points::Int64, sigma::Float64, init_val::Float64) + +Generates a logit-normal random walk time series. + +# Arguments +- `time_points::Int64`: Number of time points. +- `sigma::Float64`: Standard deviation of the random walk in logit space. +- `init_val::Float64`: Initial value of the random walk on the probability scale. + +# Returns +- `walk::Vector{Float64}`: Generated random walk on the probability scale. +""" +function generate_logit_normal_random_walk(time_points::Int64, sigma::Float64, init_val::Float64) + walk = Float64[] + logit_val = logit(init_val) + for _ in 1:time_points + logit_val = rand(Normal(0, sigma)) + logit_val + push!(walk, logistic(logit_val)) + end + return walk +end + """ generate_simulation_data(params::UCIWWEIHRParams) diff --git a/src/helper_functions.jl b/src/helper_functions.jl index c28714a..ce221d5 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -110,7 +110,6 @@ end Saves plots to docs/plots directory. -Function created by Christian Bernal Zelaya. """ function save_plots_to_docs(plot, filename; format = "png") doc_loc = "plots" @@ -130,7 +129,6 @@ end Checks if the name of time varying paramter starts with any of the patterns. -Function created by Christian Bernal Zelaya. """ function startswith_any(name, patterns) for pattern in patterns @@ -147,7 +145,6 @@ end Calculate quantiles for a given chain and variable prefix. Quantiles can be any user desired quantile. -Function created by Christian Bernal Zelaya. """ function calculate_quantiles(df, chain, var_prefix, quantiles) df_chain = filter(row -> row.chain == chain, df) @@ -167,9 +164,28 @@ end Generates a vector with colors for ribbons in plots. -Function created by Christian Bernal Zelaya. 
""" function generate_colors(number_of_colors) alpha_values = range(0.1, stop=0.7, length=number_of_colors) return [RGBA(colorant"blue", alpha) for alpha in alpha_values] +end + + +""" + repeat_last_n_elements(x::Vector{T}, n::Int, w::Int) where T + +Modifies a given array so that the last n elements are repeated w times. + +""" +function repeat_last_n_elements(x::Vector{T}, n::Int, w::Int) where T + if n == 0 + return x + else + n = min(n, length(x)) + last_n_elements = x[end-n+1:end] + repeated_elements = [elem for elem in last_n_elements for _ in 1:w] + x_new = vcat(x, repeated_elements) + + return x_new + end end \ No newline at end of file diff --git a/src/mcmcdiags_vis.jl b/src/mcmcdiags_vis.jl index e93d871..84e6e96 100644 --- a/src/mcmcdiags_vis.jl +++ b/src/mcmcdiags_vis.jl @@ -5,6 +5,7 @@ Default visualizer for results of the UCIWWEIHR model, includes posterior/priors # Arguments - `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. - `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is [["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rho_gene", "tau", "df"], ["sigma_hosp"]]. +- `actual_non_time_varying_vals::uciwweihr_sim_params`: A uciwweihr_sim_params object of actual non-time varying parameter values if user has access to them. Default is nothing. - `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. 
""" function mcmcdiags_vis(; @@ -16,6 +17,7 @@ function mcmcdiags_vis(; ["rho_gene", "tau", "df"], ["sigma_hosp"] ], + actual_non_time_varying_vals::uciwweihr_sim_params = nothing, save_plots::Bool=false ) @@ -40,14 +42,28 @@ function mcmcdiags_vis(; for param in param_group if param in names(gq_samples) df_param = filter(row -> row.name == param, df_filtered) - title = "MCMC Diagnosis Plot for Chain $chain, $param (ESS: $(eff_sample_sizes[param]))" plt = plot(df_param.iteration, df_param.value, - legend=false, - title=title, - xlabel="Iteration", ylabel="Value Drawn", - color = :black, lw = 2 + label = "Chain $chain", + title = "Trace of $param", + xlabel = "Iteration", ylabel="Value Drawn", + color = :black, lw = 2, + xguidefont = font(8), + yguidefont = font(8), + titlefont = font(10), + legendfont = font(8), + legend = :topright ) push!(cat_plots, plt) + + if !isnothing(actual_non_time_varying_vals) + actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(param)), digits=3) + scatter!(plt, [1], Float64[actual_param_value], + label = "Actual Value : $actual_param_value", + color = :red, + markersize = 5, + marker = :circle, + legendfont = font(8)) + end else println("PARAMETER $param NOT IN GENERATED QUANTITIES!!!") end @@ -55,13 +71,21 @@ function mcmcdiags_vis(; end end if !isempty(cat_plots) - plt = plot(cat_plots..., layout=(length(unique(gq_samples.chain)) * length(desired_params[1]), length(desired_params)), size = (1000, 1000)) - display(plt) + num_plots = length(cat_plots) + num_chains = length(unique(gq_samples.chain)) + num_params = length(desired_params) + layout_rows = num_chains * length(desired_params[1]) + layout_cols = num_params + + final_plot = plot(cat_plots..., + layout = (layout_rows, layout_cols), + size = (1500, 1500)) + display(final_plot) if save_plots save_plots_to_docs(plt, "mcmc_diagnosis_plots") end else - println("NO PLOTS TO DISPLAY!!!") + println("NO MCMC DIAGNOSIS PLOTS TO DISPLAY!!!") end end diff --git 
a/src/non_time_varying_param_vis.jl b/src/non_time_varying_param_vis.jl new file mode 100644 index 0000000..1ec5437 --- /dev/null +++ b/src/non_time_varying_param_vis.jl @@ -0,0 +1,71 @@ +""" + non_time_varying_param_vis(...) + +Used in the `uciwweihr_visualizer` to create visuals for non-time varying parameters. + +# Arguments +- `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. +- `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is any parameter not in this list : ["alpha_t", "w_t", "rt_vals", "log_genes_mean", "H"] +- `bayes_dist_type`: A string to indicate if user is using Posterior or Prior distribution. Default is "Posterior". +- `actual_non_time_varying_vals::uciwweihr_sim_params`: A uciwweihr_sim_params object of actual non-time varying parameter values if user has access to them. Default is nothing. +- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. 
+""" +function non_time_varying_param_vis(; + gq_samples=nothing, + desired_params=nothing, + bayes_dist_type="Posterior", + actual_non_time_varying_vals::uciwweihr_sim_params = nothing, + save_plots::Bool=false + ) + # Plotting non time varying parameters + non_time_varying_plots = [] + for param_group in desired_params + for curr_param in param_group + chains = unique(gq_samples.chain) + for chain in chains + curr_param_chain_df = filter(row -> row.chain == chain, gq_samples) + plt = histogram(curr_param_chain_df[:, curr_param], + label = "Chain $chain", + title = "$bayes_dist_type $curr_param", + bins = 50, + normalize = :probability, + xlabel = "Probability", + ylabel = "Value for $curr_param", + xguidefont = font(8), + yguidefont = font(8), + titlefont = font(10), + legendfont = font(8), + legend = :topright) + if !isnothing(actual_non_time_varying_vals) + actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3) + scatter!(plt, Float64[actual_param_value], [0.002], + label = "Actual Value : $actual_param_value", + color = :red, + markersize = 5, + marker = :circle, + legendfont = font(8)) + end + push!(non_time_varying_plots, plt) + end + end + end + if !isempty(non_time_varying_plots) + num_plots = length(non_time_varying_plots) + num_chains = length(unique(gq_samples.chain)) + num_params = length(desired_params) + layout_rows = num_chains * length(desired_params[1]) + layout_cols = num_params + + final_plot = plot(non_time_varying_plots..., + layout = (layout_rows, layout_cols), + size = (1500, 1500)) + display(final_plot) + if save_plots + save_plots_to_docs(plt, "mcmc_nontime_varying_parameter_plots") + end + else + println("NO NON-TIME VARYING PARAMETER PLOTS TO DISPLAY!!!") + end + + +end diff --git a/src/predictive_param_vis.jl b/src/predictive_param_vis.jl new file mode 100644 index 0000000..81077ba --- /dev/null +++ b/src/predictive_param_vis.jl @@ -0,0 +1,74 @@ +""" + predictive_param_vis(...) 
+Used in the `uciwweihr_visualizer` to create visuals for wastewater data and hospitalization data. + +# Arguments +- `pp_samples`: A DataFrame of posterior or prior predictive samples. +- `data_wastewater`: An array of actual wastewater values if the user has access to them, assumed to be on the time scale of the observed time points. Default is nothing. +- `data_hosp`: An array of actual hospitalization values if the user has access to them, assumed to be on the time scale of the observed time points. Default is nothing. +- `forecast_weeks`: An integer of the number of weeks forecasted. Default is 0. +- `vars_to_pred`: A list of variables to predict. Default is ["data_wastewater", "data_hosp"]. +- `quantiles`: A list of quantiles to calculate for plotting uncertainty. Default is [0.5, 0.8, 0.95]. +- `bayes_dist_type`: A string to indicate if user is using Posterior or Prior distribution. Default is "Posterior". +- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. 
+""" +function predictive_param_vis(; + pp_samples = nothing, + data_wastewater = nothing, + data_hosp = nothing, + forecast_weeks = 0, + vars_to_pred = ["data_wastewater", "data_hosp"], + quantiles = [0.5, 0.8, 0.95], + bayes_dist_type = "Posterior", + save_plots::Bool = false + ) + # Plotting wastewater and hosp predictive posterior / prior + pred_plots = [] + column_names = names(pp_samples) + for var_prefix in vars_to_pred + pred_var_names = filter(name -> startswith_any(name, [var_prefix]), column_names) + pred_var_df = pp_samples[:, [pred_var_names..., "iteration", "chain"]] + chains = unique(pred_var_df.chain) + for chain in chains + pred_var_elem = pred_var_df[:, [pred_var_names..., "chain"]] + medians, lower_bounds, upper_bounds = calculate_quantiles(pred_var_elem, chain, var_prefix, quantiles) + ribbon_colors = generate_colors(length(quantiles)) + preped_medians = repeat_last_n_elements(medians, forecast_weeks, 7) + preped_lower_bounds = repeat_last_n_elements(lower_bounds, forecast_weeks, 7) + preped_upper_bounds = repeat_last_n_elements(upper_bounds, forecast_weeks, 7) + time_index = 1:length(preped_medians) + plt = plot(title = "$bayes_dist_type Quantiles for Chain $chain for $var_prefix", + xlabel = "Time Points (daily scale)", + ylabel = "Value for $var_prefix") + for (i, q) in enumerate(quantiles) + preped_upper_bounds_temp = map(x -> x[i], preped_upper_bounds) + preped_lower_bounds_temp = map(x -> x[i], preped_lower_bounds) + plot!(plt, time_index, preped_medians, ribbon = (preped_upper_bounds_temp .- preped_medians, preped_medians .- preped_lower_bounds_temp), + fillalpha = 0.2, + label = "$(@sprintf("%.0f", q*100))% Quantile", + color = ribbon_colors[i], + fillcolor = ribbon_colors[i]) + end + plot!(plt, time_index, preped_medians, label = "Median", color = :black, lw = 2) + if !isnothing(data_wastewater) && var_prefix == "data_wastewater" + scatter!(plt, 1:length(data_wastewater), data_wastewater, label = "Actual Wastewater Values", color = 
:red, lw = 2, marker = :circle) + end + if !isnothing(data_hosp) && var_prefix == "data_hosp" + scatter!(plt, 1:length(data_hosp), data_hosp, label = "Actual Hosp Values", color = :red, lw = 2, marker = :circle) + end + push!(pred_plots, plt) + end + end + + if !isempty(pred_plots) + chains = unique(pp_samples.chain) + plt = plot(pred_plots..., layout = (length(pred_plots), length(chains)), size = (1000, 1000)) + display(plt) + if save_plots + save_plots_to_docs(plt, "mcmc_pred_parameter_plots") + end + else + println("NO TIME VARYING PARAMETER PLOTS TO DISPLAY!!!") + end + +end diff --git a/src/time_varying_param_vis.jl b/src/time_varying_param_vis.jl index 99de412..370cdc0 100644 --- a/src/time_varying_param_vis.jl +++ b/src/time_varying_param_vis.jl @@ -1,6 +1,6 @@ """ time_varying_param_vis(...) -Default visualizer for results of the UCIWWEIHR model, includes posterior/priors of generated quantities and posterior predictive samples for forecasting. Forecasting plots will have the observed data alongside. +Used in the `uciwweihr_visualizer` to create visuals for time varying parameters. # Arguments - `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. 
@@ -20,10 +20,9 @@ function time_varying_param_vis(; ) # Plotting time varying parameters - var_prefixs = time_varying_params time_varying_plots = [] column_names = names(gq_samples) - for var_prefix in var_prefixs + for var_prefix in time_varying_params time_varying_param = filter(name -> startswith_any(name, [var_prefix]), column_names) time_varying_subset_df = gq_samples[:, [time_varying_param..., "iteration", "chain"]] chains = unique(time_varying_subset_df.chain) @@ -59,7 +58,7 @@ function time_varying_param_vis(; if !isempty(time_varying_plots) chains = unique(gq_samples.chain) - plt = plot(time_varying_plots..., layout = (length(var_prefixs), length(chains)), size = (1000, 1000)) + plt = plot(time_varying_plots..., layout = (length(time_varying_params), length(chains)), size = (1000, 1000)) display(plt) if save_plots save_plots_to_docs(plt, "mcmc_time_varying_parameter_plots") diff --git a/src/uciwweihr_gq_pp.jl b/src/uciwweihr_gq_pp.jl index 8c057db..57ca348 100644 --- a/src/uciwweihr_gq_pp.jl +++ b/src/uciwweihr_gq_pp.jl @@ -47,7 +47,7 @@ The defaults for this fuction will follow those of the default simulation in gen - Samples from the posterior or prior distribution. 
""" -function uciwweihr_gq_pp( +function uciwweihr_gq_pp(; samples, data_hosp, data_wastewater, diff --git a/src/uciwweihr_model.jl b/src/uciwweihr_model.jl index f4ef547..7df816d 100644 --- a/src/uciwweihr_model.jl +++ b/src/uciwweihr_model.jl @@ -55,7 +55,7 @@ The defaults for this fuction will follow those of the default simulation in gen sigma_hosp_sd::Float64=50.0, sigma_hosp_mean::Float64=500.0, Rt_init_sd::Float64=0.3, Rt_init_mean::Float64=0.2, sigma_Rt_sd::Float64=0.2, sigma_Rt_mean::Float64=-3.0, - w_init_sd::Float64=0.1, w_init_mean::Float64=log(0.35), + w_init_sd::Float64=0.04, w_init_mean::Float64=logit(0.35), sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5 ) @@ -94,7 +94,7 @@ The defaults for this fuction will follow those of the default simulation in gen w_params_non_centered ~ MvNormal(zeros(l_param_change_times + 2), I) # +2 for sigma and init sigma_w_non_centered = w_params_non_centered[1] w_init_non_centered = w_params_non_centered[2] - log_w_steps_non_centered = w_params_non_centered[3:end] + logit_w_steps_non_centered = w_params_non_centered[3:end] # TRANSFORMATIONS-------------------- @@ -117,10 +117,12 @@ The defaults for this fuction will follow those of the default simulation in gen alpha_t_no_init = exp.(log(Rt_init) .+ cumsum(log_Rt_steps_non_centered) * sigma_Rt) * nu alpha_init = Rt_init * nu alpha_t = vcat(alpha_init, alpha_t_no_init) - # Non-constant Hosp Rate w - w_init = exp(w_init_non_centered * w_init_sd + w_init_mean) + # Non-constant Hosp Prob w + w_init_logit = w_init_non_centered * w_init_sd + w_init_mean sigma_w = exp(sigma_w_non_centered * sigma_w_sd + sigma_w_mean) - w_no_init = exp.(log(w_init) .+ cumsum(log_w_steps_non_centered) * sigma_w) + logit_w_no_init = w_init_logit .+ cumsum(logit_w_steps_non_centered) * sigma_w + w_init = logistic(w_init_logit) + w_no_init = logistic.(logit_w_no_init) w_t = vcat(w_init, w_no_init) diff --git a/src/uciwweihr_visualizer.jl b/src/uciwweihr_visualizer.jl index fea9182..bfacc6d 
100644 --- a/src/uciwweihr_visualizer.jl +++ b/src/uciwweihr_visualizer.jl @@ -7,27 +7,31 @@ Default visualizer for results of the UCIWWEIHR model, includes posterior/priors - `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. - `data_hosp`: An array of hospital data. - `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. -- `obstimes`: An array of timepoints for observed hosp/wastewater. -- `param_change_times`: An array of timepoints where the parameters change. - `actual_rt_vals`: An array of actual Rt values if user has access to them assumed to be on a daily scale. This typically will come from some simulation. Default is nothing. - `actual_w_t`: An array of actual w_t values if user has access to them assumed to be on a daily scale. This typically will come from some simulation. Default is nothing. +- `actual_non_time_varying_vals::uciwweihr_sim_params`: A uciwweihr_sim_params object of actual non-time varying parameter values if user has access to them. Default is nothing. +- `forecast_weeks`: Number of weeks forecasted. Default is 0. - `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is [["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rho_gene", "tau", "df"], ["sigma_hosp"]]. - `time_varying_params`: A list of time varying parameters to visualize. Default is ["rt_vals", "w_t"]. +- `var_to_pred`: A list of variables to predict. Default is ["data_wastewater", "data_hosp"]. - `quantiles`: A list of quantiles to calculate for ploting uncertainty. Default is [0.5, 0.8, 0.95]. +- `bayes_dist_type`: A string to indicate if user is using Posterior or Prior distribution ("Posterior" / "Prior"). - `mcmcdaigs::Bool=true`: A boolean to indicate if user wants to visualize mcmc diagnosis plots and Effective Sample Size(ESS). 
- `time_varying_plots::Bool=true`: A boolean to indicate if user wants to visualize time varying parameters. +- `non_time_varying_plots::Bool=true`: A boolean to indicate if user wants to visualize non-time varying parameters. +- `pred_param_plots::Bool=true`: A boolean to indicate if user wants to visualize posterior (or prior) predictive parameter results. - `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. """ function uciwweihr_visualizer(; - pp_samples=nothing, - gq_samples=nothing, - data_hosp=nothing, - data_wastewater=nothing, - obstimes=nothing, - param_change_times=nothing, - actual_rt_vals=nothing, - actual_w_t=nothing, - desired_params=[ + pp_samples = nothing, + gq_samples = nothing, + data_hosp = nothing, + data_wastewater = nothing, + actual_rt_vals = nothing, + actual_w_t = nothing, + actual_non_time_varying_vals::uciwweihr_sim_params = nothing, + forecast_weeks = 0, + desired_params = [ ["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rt_init", "w_init"], @@ -35,10 +39,14 @@ function uciwweihr_visualizer(; ["sigma_hosp"] ], time_varying_params = ["rt_vals", "w_t"], + var_to_pred = ["data_wastewater", "data_hosp"], quantiles = [0.5, 0.8, 0.95], - mcmcdaigs::Bool=true, - time_varying_plots::Bool=true, - save_plots::Bool=false + bayes_dist_type = nothing, + mcmcdaigs::Bool = true, + time_varying_plots::Bool = true, + non_time_varying_plots::Bool = true, + pred_param_plots::Bool = true, + save_plots::Bool = false ) # Posterior/Prior Samples @@ -47,6 +55,7 @@ function uciwweihr_visualizer(; mcmcdiags_vis( gq_samples = gq_samples, desired_params = desired_params, + actual_non_time_varying_vals = actual_non_time_varying_vals, save_plots = save_plots ) else @@ -66,6 +75,33 @@ function uciwweihr_visualizer(; println("MCMC time varying parameter results are not requested.") end + if non_time_varying_plots + non_time_varying_param_vis( + gq_samples = gq_samples, + desired_params = 
desired_params, + bayes_dist_type = bayes_dist_type, + actual_non_time_varying_vals = actual_non_time_varying_vals, + save_plots = save_plots + ) + else + println("MCMC non-time varying parameter results are not requested.") + end + + if pred_param_plots + predictive_param_vis( + pp_samples = pp_samples, + data_wastewater = data_wastewater, + data_hosp = data_hosp, + forecast_weeks = forecast_weeks, + vars_to_pred = var_to_pred, + quantiles = quantiles, + bayes_dist_type = bayes_dist_type, + save_plots = save_plots + ) + else + println("MCMC posterior (or prior) predictive parameter results are not requested.") + end + end