From 8f965c3abdaece2ab715d8260f7867329f9e494a Mon Sep 17 00:00:00 2001 From: cbernalz Date: Mon, 19 Aug 2024 20:14:04 -0700 Subject: [PATCH] 2024-08-19 update : modularize visualizer & determin rt / wt. --- docs/src/tutorials/uciwweihr_model_fitting.md | 27 ++- .../tutorials/uciwweihr_simulation_data.md | 87 ++++++++- src/UCIWWEIHR.jl | 7 + src/generate_simulation_data_uciwweihr.jl | 181 +++++++++++------- src/mcmcdiags_vis.jl | 67 +++++++ src/time_varying_param_vis.jl | 71 +++++++ src/uciwweihr_visualizer.jl | 104 ++-------- 7 files changed, 377 insertions(+), 167 deletions(-) create mode 100644 src/mcmcdiags_vis.jl create mode 100644 src/time_varying_param_vis.jl diff --git a/docs/src/tutorials/uciwweihr_model_fitting.md b/docs/src/tutorials/uciwweihr_model_fitting.md index 75d20fc..de6061d 100644 --- a/docs/src/tutorials/uciwweihr_model_fitting.md +++ b/docs/src/tutorials/uciwweihr_model_fitting.md @@ -6,15 +6,36 @@ Plots.reset_defaults() # [Generating Posterior Distribution Samples with UCIWWEIHR ODE compartmental based model.](@id uciwwiehr_model_fitting) -This package has a way to sample from a posterior or prior that is defined in the future paper using the `uciwweihr_fit.jl` and `uciwweihr_model.jl`. We can then generate desired quantities and forecast for a given time period with the posterior predictive distribution, using `uciwweihr_gq_pp.jl`. We first generate data using the `generate_simulation_data_uciwweihr` function which is a non-mispecified version of the model. +This package has a way to sample from a posterior or prior that is defined in the future paper using the `uciwweihr_fit.jl` and `uciwweihr_model.jl`. We can then generate desired quantities and forecast for a given time period with the posterior predictive distribution, using `uciwweihr_gq_pp.jl`. 
We first generate data using the `generate_simulation_data_uciwweihr` function, which is a non-misspecified version of the model; we will also be using prespecified effective reproduction number curves and prespecified hospitalization probability curves. ## 1. Data Generation. ``` @example tutorial using UCIWWEIHR -# Running simulation function with defaults -df = generate_simulation_data_uciwweihr() +# Running simulation function with presets +rt_custom = vcat( + range(1, stop=1.8, length=7*4), + fill(1.8, 7*2), + range(1.8, stop=1, length=7*8), + range(0.98, stop=0.8, length=7*2), + range(0.8, stop=1.1, length=7*6), + range(1.1, stop=0.97, length=7*3) +) +w_custom = vcat( + range(0.3, stop=0.38, length=7*5), + fill(0.38, 7*2), + range(0.38, stop=0.25, length=7*8), + range(0.25, stop=0.28, length=7*2), + range(0.28, stop=0.34, length=7*6), + range(0.34, stop=0.28, length=7*2) +) +params = create_uciwweihr_params( + time_points = length(rt_custom), + Rt = rt_custom, + w = w_custom +) +df = generate_simulation_data_uciwweihr(params) first(df, 5) ``` diff --git a/docs/src/tutorials/uciwweihr_simulation_data.md b/docs/src/tutorials/uciwweihr_simulation_data.md index 62d3e49..4ca66d3 100644 --- a/docs/src/tutorials/uciwweihr_simulation_data.md +++ b/docs/src/tutorials/uciwweihr_simulation_data.md @@ -8,15 +8,16 @@ This package provides a way to also simulate data using the UCIWWEIHR ODE compar using UCIWWEIHR using Plots # Running simulation function with defaults -df = generate_simulation_data_uciwweihr() +params = create_uciwweihr_params() +df = generate_simulation_data_uciwweihr(params) first(df, 5) ``` -## 2. Visualizing UCIWWEIHR model results. +## 1.2 Visualizing UCIWWEIHR model results. -Here we can make simple plots to visualize the data generated using the [Plots](https://docs.juliaplots.org/stable/) package. +Here we can make simple plots to visualize the data generated using the [Plots](https://docs.juliaplots.org/stable/) package. -### 2.1. 
Concentration of pathogen genome in wastewater(WW). +### 1.2.1. Concentration of pathogen genome in wastewater(WW). ```@example tutorial plot(df.obstimes, df.log_ww_conc, label=nothing, @@ -25,7 +26,7 @@ plot(df.obstimes, df.log_ww_conc, title="Plot of Conc. of Pathogen Genome in WW Over Time") ``` -### 2.2. Hospitalizations. +### 1.2.2. Hospitalizations. ```@example tutorial plot(df.obstimes, df.hosp, label=nothing, @@ -34,7 +35,7 @@ plot(df.obstimes, df.hosp, title="Plot of Hosp Over Time") ``` -### 2.3. Reproductive number. +### 1.2.3. Reproductive number. ```@example tutorial plot(df.obstimes, df.rt, label=nothing, @@ -43,7 +44,79 @@ plot(df.obstimes, df.rt, title="Plot of Rt Over Time") ``` -### 2.4. Hospitalization rate. +### 1.2.4. Hospitalization rate. +```@example tutorial +plot(df.obstimes, df.wt, + label=nothing, + xlabel="Obstimes", + ylabel="Rt", + title="Plot of Hospitalization Rate Over Time") +``` + +## 2. Alternate Functionality. +We can also use a prespecified effective reproduction number curve or a prespecified hospitalization probability curve. Any combination of prespecified or random walk curves can be used. Here we provide an example of using both a prespecified effective reproduction number curve and a prespecified hospitalization probability curve. 
+ +``` @example tutorial +using UCIWWEIHR +using Plots +# Running simulation function with prespecified Rt and hospitalization probability +rt_custom = vcat( + range(1, stop=1.8, length=7*4), + fill(1.8, 7*2), + range(1.8, stop=1, length=7*8), + range(0.98, stop=0.8, length=7*2), + range(0.8, stop=1.1, length=7*6), + range(1.1, stop=0.97, length=7*3) +) +w_custom = vcat( + range(0.3, stop=0.38, length=7*5), + fill(0.38, 7*2), + range(0.38, stop=0.25, length=7*8), + range(0.25, stop=0.28, length=7*2), + range(0.28, stop=0.34, length=7*6), + range(0.34, stop=0.28, length=7*2) +) +params = create_uciwweihr_params( + time_points = length(rt_custom), + Rt = rt_custom, + w = w_custom +) +df = generate_simulation_data_uciwweihr(params) +first(df, 5) +``` + +## 2.2 Visualizing UCIWWEIHR model results. + +We can visualize these results using the [Plots](https://docs.juliaplots.org/stable/) package. + +### 2.2.1. Concentration of pathogen genome in wastewater(WW). +```@example tutorial +plot(df.obstimes, df.log_ww_conc, + label=nothing, + xlabel="Obstimes", + ylabel="Conc. of Pathogen Genome in WW", + title="Plot of Conc. of Pathogen Genome in WW Over Time") +``` + +### 2.2.2. Hospitalizations. +```@example tutorial +plot(df.obstimes, df.hosp, + label=nothing, + xlabel="Obstimes", + ylabel="Hosp", + title="Plot of Hosp Over Time") +``` + +### 2.2.3. Reproductive number. +```@example tutorial +plot(df.obstimes, df.rt, + label=nothing, + xlabel="Obstimes", + ylabel="Rt", + title="Plot of Rt Over Time") +``` + +### 2.2.4. Hospitalization rate. 
```@example tutorial plot(df.obstimes, df.wt, label=nothing, diff --git a/src/UCIWWEIHR.jl b/src/UCIWWEIHR.jl index 00da2da..a43708b 100644 --- a/src/UCIWWEIHR.jl +++ b/src/UCIWWEIHR.jl @@ -32,8 +32,13 @@ include("uciwweihr_fit.jl") include("uciwweihr_gq_pp.jl") include("uciwweihr_visualizer.jl") include("helper_functions.jl") +include("mcmcdiags_vis.jl") +include("time_varying_param_vis.jl") export eihr_ode +export uciwweihr_sim_params +export create_uciwweihr_params +export generate_random_walk export generate_simulation_data_uciwweihr export generate_simulation_data_agent export NegativeBinomial2 @@ -46,5 +51,7 @@ export ChainsCustomIndexs export save_plots_to_docs export startswith_any export calculate_quantiles +export mcmcdiags_vis +export time_varying_param_vis end \ No newline at end of file diff --git a/src/generate_simulation_data_uciwweihr.jl b/src/generate_simulation_data_uciwweihr.jl index 80963e3..3a601a9 100644 --- a/src/generate_simulation_data_uciwweihr.jl +++ b/src/generate_simulation_data_uciwweihr.jl @@ -1,67 +1,121 @@ """ -## Generating Simulation Data for UCIWWEIHR ODE Compartmental Based Model + uciwweihr_sim_params +Struct for holding parameters used in the UCIWWEIHR ODE compartmental model simulation. -To generate simulation data using the UCIWWEIHR ODE compartmental based model, you can use the `generate_simulation_data_uciwweihr` function defined in the `UCIWWEIHR.jl` package. This function allows you to customize various parameters for the simulation. +# Fields +- `time_points::Int64`: Number of time points for the simulation. +- `seed::Int64`: Seed for random number generation. +- `E_init::Int64`: Initial number of exposed individuals. +- `I_init::Int64`: Initial number of infected individuals. +- `H_init::Int64`: Initial number of hospitalized individuals. +- `gamma::Float64`: Rate of incubation. +- `nu::Float64`: Rate of leaving the infected compartment. +- `epsilon::Float64`: Rate of hospitalization recovery. 
+- `rho_gene::Float64`: Contribution of infected individual's pathogen genome into wastewater. +- `tau::Float64`: Scale/variation of the log concentration of pathogen genome in wastewater. +- `df::Float64`: Degrees of freedom for generalized t-distribution for log concentration of pathogen genome in wastewater. +- `sigma_hosp::Float64`: Standard deviation for the negative binomial distribution for hospital data. +- `Rt::Union{Float64, Vector{Float64}}`: Initial value or time series of the time-varying reproduction number. +- `sigma_Rt::Float64`: Standard deviation for random walk of time-varying reproduction number. +- `w::Union{Float64, Vector{Float64}}`: Initial value or time series of the time-varying hospitalization rate. +- `sigma_w::Float64`: Standard deviation for random walk of time-varying hospitalization rate. +""" +struct uciwweihr_sim_params + time_points::Int64 + seed::Int64 + E_init::Int64 + I_init::Int64 + H_init::Int64 + gamma::Float64 + nu::Float64 + epsilon::Float64 + rho_gene::Float64 + tau::Float64 + df::Float64 + sigma_hosp::Float64 + Rt::Union{Float64, Vector{Float64}} + sigma_Rt::Float64 + w::Union{Float64, Vector{Float64}} + sigma_w::Float64 +end -### Function Signature +""" + create_uciwweihr_params(; kwargs...) +Creates a `uciwweihr_sim_params` struct with the option to either use a predetermined `Rt` and `w` or generate them as random walks. # Arguments -- time_points::Int64: Number of time points wanted for simulation. Default value is 150. -- seed::Int64: Seed for random number generation. Default value is 1. -- E_init::Int64: Initial number of exposed individuals. Default value is 200. -- I_init::Int64: Initial number of infected individuals. Default value is 100. -- H_init::Int64: Initial number of hospitalized individuals. Default value is 20. -- gamma::Float64: Rate of incubation. Default value is 1/4. -- nu::Float64: Rate of leaving the infected compartment. Default value is 1/7. 
-- epsilon::Float64: Rate of hospitalization recovery. Default value is 1/5. -- rho_gene::Float64: Contribution of infected individual's pathogen genome into wastewater. Default value is 0.011. -- tau::Float64: Scale/variation of the log concentration of pathogen genome in wastewater. Default value is 0.1. -- df::Float64: Degrees of freedom for generalized t distribution for log concentration of pathogen genome in wastewater. Default value is 29. -- sigma_hosp::Float64: Standard deviation for the negative binomial distribution for hospital data. Default value is 800. -- Rt_init::Float64: Initial value of the time-varying reproduction number. Default value is 1. -- sigma_Rt::Float64: Standard deviation for random walk of time-varying reproduction number. Default value is sqrt(0.02). -- w_init::Float64: Initial value of the time-varying hospitalization rate. Default value is 0.35. -- sigma_w::Float64: Standard deviation for random walk of time-varying hospitalization rate. Default value is sqrt(0.02). +- `kwargs...`: Named arguments corresponding to the fields in `uciwweihr_sim_params`. # Returns -- df::DataFrame: A DataFrame containing the simulation data with columns `obstimes`, `log_ww_conc`, `hosp`, and `rt`. +- `params::uciwweihr_sim_params`: A struct with simulation parameters. 
""" -function generate_simulation_data_uciwweihr( - time_points::Int64=150, seed::Int64=1, - E_init::Int64=200, I_init::Int64=100, H_init::Int64=20, - gamma::Float64=1/4, nu::Float64=1/7, epsilon::Float64=1/5, - rho_gene::Float64=0.011, tau::Float64=0.1, df::Float64=29.0, - sigma_hosp::Float64=800.0, - Rt_init::Float64=1.0, sigma_Rt::Float64=sqrt(0.001), - w_init::Float64=0.35, sigma_w::Float64=sqrt(0.001), -) - +function create_uciwweihr_params(; time_points::Int64=150, seed::Int64=1, + E_init::Int64=200, I_init::Int64=100, H_init::Int64=20, + gamma::Float64=1/4, nu::Float64=1/7, epsilon::Float64=1/5, + rho_gene::Float64=0.011, tau::Float64=0.1, df::Float64=29.0, + sigma_hosp::Float64=800.0, + Rt::Union{Float64, Vector{Float64}}=1.0, + sigma_Rt::Float64=sqrt(0.001), + w::Union{Float64, Vector{Float64}}=0.35, + sigma_w::Float64=sqrt(0.001)) + Random.seed!(seed) - - # Rt and W SETUP-------------------------- - Rt_t_no_init = Float64[] # Pre-defined vector - w_no_init = Float64[] # Pre-defined vector - log_Rt_t = log(Rt_init) - log_w_t = log(w_init) + + Rt_t = isa(Rt, Float64) ? generate_random_walk(time_points, sigma_Rt, Rt) : Rt + w_t = isa(w, Float64) ? generate_random_walk(time_points, sigma_w, w) : w + + return uciwweihr_sim_params(time_points, seed, E_init, I_init, H_init, gamma, nu, epsilon, + rho_gene, tau, df, sigma_hosp, Rt_t, sigma_Rt, w_t, sigma_w) +end + +""" + generate_random_walk(time_points::Int64, sigma::Float64, init_val::Float64) +Generates a random walk time series. + +# Arguments +- `time_points::Int64`: Number of time points. +- `sigma::Float64`: Standard deviation of the random walk. +- `init_val::Float64`: Initial value of the random walk. + +# Returns +- `walk::Vector{Float64}`: Generated random walk. 
+""" +function generate_random_walk(time_points::Int64, sigma::Float64, init_val::Float64) + walk = Float64[] + log_val = log(init_val) for _ in 1:time_points - log_Rt_t = log_Rt_t + rand(Normal(0, sigma_Rt)) - log_w_t = log_w_t + rand(Normal(0, sigma_w)) - push!(Rt_t_no_init, exp(log_Rt_t)) - push!(w_no_init, exp(log_w_t)) + log_val = rand(Normal(0, sigma)) + log_val + push!(walk, exp(log_val)) end - alpha_t_no_init = Rt_t_no_init * nu - alpha_init = Rt_init * nu + return walk +end + +""" + generate_simulation_data_uciwweihr(params::uciwweihr_sim_params) + +Generates simulation data for the UCIWWEIHR ODE compartmental model. + +# Arguments +- `params::uciwweihr_sim_params`: Struct containing parameters for the simulation. + +# Returns +- `df::DataFrame`: A DataFrame containing the simulation data with columns `obstimes`, `log_ww_conc`, `hosp`, `rt`, and `wt`. +""" +function generate_simulation_data_uciwweihr(params::uciwweihr_sim_params) + time_points = params.time_points + + alpha_t = params.Rt .* params.nu + u0 = [params.E_init, params.I_init, params.H_init] + p0 = [alpha_t[1], params.gamma, params.nu, params.w[1], params.epsilon] - # ODE SETUP-------------------------- - prob = ODEProblem{true}(eihr_ode!, zeros(3), (0.0, time_points), ones(5)) + prob = ODEProblem(eihr_ode!, u0, (0.0, time_points), p0) + function param_affect_beta!(integrator) - ind_t = searchsortedfirst(collect(1:time_points), integrator.t) # Find the index of collect(1:time_points) that contains the current timestep - integrator.p[1] = alpha_t_no_init[ind_t] # Replace alpha with a new value from alpha_t_no_init - integrator.p[4] = w_no_init[ind_t] # Replace w with a new value from w_no_init + ind_t = searchsortedfirst(1:time_points, integrator.t) + integrator.p[1] = alpha_t[ind_t] + integrator.p[4] = params.w[ind_t] end - param_callback = PresetTimeCallback(collect(1:time_points), param_affect_beta!, save_positions=(false, false)) - u0 = [E_init, I_init, H_init] - p0 = [alpha_init, gamma, nu, w_init, 
epsilon] + param_callback = PresetTimeCallback(1:time_points, param_affect_beta!, save_positions=(false, false)) extra_ode_precision = false abstol = extra_ode_precision ? 1e-11 : 1e-9 reltol = extra_ode_precision ? 1e-8 : 1e-6 @@ -72,26 +126,21 @@ function generate_simulation_data_uciwweihr( return end sol_array = Array(sol) - I_comp_sol = clamp.(sol_array[2,2:end],1, 1e10) - H_comp_sol = clamp.(sol_array[3,2:end], 1, 1e10) - - # Log Gene SETUP-------------------------- - log_genes_mean = log.(I_comp_sol) .+ log(rho_gene) # first entry is the initial conditions, we want 2:end - data_wastewater = zeros(time_points) - data_hosp = zeros(time_points) - for t_i in 1:time_points - data_wastewater[t_i] = rand(GeneralizedTDist(log_genes_mean[t_i], tau, df)) - data_hosp[t_i] = rand(NegativeBinomial2(H_comp_sol[t_i], sigma_hosp)) - end + I_comp_sol = clamp.(sol_array[2, 2:end], 1, 1e10) + H_comp_sol = clamp.(sol_array[3, 2:end], 1, 1e10) + + # Log Gene Setup + log_genes_mean = log.(I_comp_sol) .+ log(params.rho_gene) + data_wastewater = [rand(GeneralizedTDist(log_genes_mean[t], params.tau, params.df)) for t in 1:time_points] + data_hosp = [rand(NegativeBinomial2(H_comp_sol[t], params.sigma_hosp)) for t in 1:time_points] df = DataFrame( obstimes = 1:time_points, log_ww_conc = data_wastewater, hosp = data_hosp, - rt = Rt_t_no_init, - wt = w_no_init - ); - return df + rt = params.Rt, + wt = params.w + ) - + return df end diff --git a/src/mcmcdiags_vis.jl b/src/mcmcdiags_vis.jl new file mode 100644 index 0000000..e93d871 --- /dev/null +++ b/src/mcmcdiags_vis.jl @@ -0,0 +1,67 @@ +""" + mcmcdiags_vis(...) +Visualizer for MCMC diagnostics of the UCIWWEIHR model: for each chain it draws a trace plot (value drawn against iteration) for every requested parameter and prints that parameter's effective sample size. + +# Arguments +- `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. 
+- `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is [["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rt_init", "w_init"], ["rho_gene", "tau", "df"], ["sigma_hosp"]]. +- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. +""" +function mcmcdiags_vis(; + gq_samples=nothing, + desired_params=[ + ["E_init", "I_init", "H_init"], + ["gamma", "nu", "epsilon"], + ["rt_init", "w_init"], + ["rho_gene", "tau", "df"], + ["sigma_hosp"] + ], + save_plots::Bool=false + ) + + # Posterior/Prior Samples + ## MCMC evaluation + cat_plots = [] + for chain in unique(gq_samples.chain) + for param_group in desired_params + eff_sample_sizes = Dict{String, Float64}() + for param in param_group + if param in names(gq_samples) + size_temp = round(ess(gq_samples[gq_samples.chain .== chain, param])) + eff_sample_sizes[param] = size_temp + println("Effective Sample Size for $param for Chain $chain: $size_temp") + else + println("PARAMETER $param NOT IN GENERATED QUANTITIES!!!") + end + end + long_df = stack(gq_samples[gq_samples.chain .== chain, :], Not([:iteration, :chain]), variable_name=:name, value_name=:value) + df_filtered = filter(row -> row.name in param_group, long_df) + + for param in param_group + if param in names(gq_samples) + df_param = filter(row -> row.name == param, df_filtered) + title = "MCMC Diagnosis Plot for Chain $chain, $param (ESS: $(eff_sample_sizes[param]))" + plt = plot(df_param.iteration, df_param.value, + legend=false, + title=title, + xlabel="Iteration", ylabel="Value Drawn", + color = :black, lw = 2 + ) + push!(cat_plots, plt) + else + println("PARAMETER $param NOT IN GENERATED QUANTITIES!!!") + end + end + end + end + if !isempty(cat_plots) + plt = plot(cat_plots..., layout=(length(unique(gq_samples.chain)) * length(desired_params[1]), length(desired_params)), size = (1000, 1000)) + display(plt) + if save_plots + save_plots_to_docs(plt, 
"mcmc_diagnosis_plots") + end + else + println("NO PLOTS TO DISPLAY!!!") + end + +end diff --git a/src/time_varying_param_vis.jl b/src/time_varying_param_vis.jl new file mode 100644 index 0000000..99de412 --- /dev/null +++ b/src/time_varying_param_vis.jl @@ -0,0 +1,71 @@ +""" + time_varying_param_vis(...) +Visualizer for the time-varying parameters of the UCIWWEIHR model: for each chain it plots the median and quantile ribbons of each requested parameter, overlaying the actual values when they are supplied. + +# Arguments +- `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. +- `actual_rt_vals`: An array of actual Rt values, if the user has access to them, assumed to be on a daily scale. This typically will come from some simulation. Default is nothing. +- `actual_w_t`: An array of actual w_t values, if the user has access to them, assumed to be on a daily scale. This typically will come from some simulation. Default is nothing. +- `time_varying_params`: A list of time varying parameters to visualize. Default is ["rt_vals", "w_t"]. +- `quantiles`: A list of quantiles to calculate for plotting uncertainty. Default is [0.5, 0.8, 0.95]. +- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. 
+""" +function time_varying_param_vis(; + gq_samples=nothing, + actual_rt_vals=nothing, + actual_w_t=nothing, + time_varying_params = ["rt_vals", "w_t"], + quantiles = [0.5, 0.8, 0.95], + save_plots::Bool=false + ) + + # Plotting time varying parameters + var_prefixs = time_varying_params + time_varying_plots = [] + column_names = names(gq_samples) + for var_prefix in var_prefixs + time_varying_param = filter(name -> startswith_any(name, [var_prefix]), column_names) + time_varying_subset_df = gq_samples[:, [time_varying_param..., "iteration", "chain"]] + chains = unique(time_varying_subset_df.chain) + for chain in chains + medians, lower_bounds, upper_bounds = calculate_quantiles(time_varying_subset_df, chain, var_prefix, quantiles) + ribbon_colors = generate_colors(length(quantiles)) + daily_medians = repeat(medians, inner=7) + daily_lower_bounds = repeat(lower_bounds, inner=7) + daily_upper_bounds = repeat(upper_bounds, inner=7) + daily_x = 1:length(daily_medians) + plt = plot(title = "Quantiles for Chain $chain for $var_prefix", + xlabel = "Time Points (daily scale)", + ylabel = "Value for $var_prefix") + for (i, q) in enumerate(quantiles) + daily_upper_bounds_temp = map(x -> x[i], daily_upper_bounds) + daily_lower_bounds_temp = map(x -> x[i], daily_lower_bounds) + plot!(plt, daily_x, daily_medians, ribbon = (daily_upper_bounds_temp .- daily_medians, daily_medians .- daily_lower_bounds_temp), + fillalpha = 0.2, + label = "$(@sprintf("%.0f", q*100))% Quantile", + color = ribbon_colors[i], + fillcolor = ribbon_colors[i]) + end + plot!(plt, daily_x, daily_medians, label = "Median", color = :black, lw = 2) + if !isnothing(actual_rt_vals) && var_prefix == "rt_vals" + scatter!(plt, 1:length(actual_rt_vals), actual_rt_vals, label = "Actual Rt Values", color = :red, lw = 2, marker = :circle) + end + if !isnothing(actual_w_t) && var_prefix == "w_t" + scatter!(plt, 1:length(actual_w_t), actual_w_t, label = "Actual w_t Values", color = :red, lw = 2, marker = :circle) + end 
+ push!(time_varying_plots, plt) + end + end + + if !isempty(time_varying_plots) + chains = unique(gq_samples.chain) + plt = plot(time_varying_plots..., layout = (length(var_prefixs), length(chains)), size = (1000, 1000)) + display(plt) + if save_plots + save_plots_to_docs(plt, "mcmc_time_varying_parameter_plots") + end + else + println("NO TIME VARYING PARAMETER PLOTS TO DISPLAY!!!") + end + +end diff --git a/src/uciwweihr_visualizer.jl b/src/uciwweihr_visualizer.jl index 74a2f1b..fea9182 100644 --- a/src/uciwweihr_visualizer.jl +++ b/src/uciwweihr_visualizer.jl @@ -44,102 +44,24 @@ function uciwweihr_visualizer(; # Posterior/Prior Samples ## MCMC evaluation if mcmcdaigs - cat_plots = [] - for chain in unique(gq_samples.chain) - for param_group in desired_params - eff_sample_sizes = Dict{String, Float64}() - for param in param_group - if param in names(gq_samples) - size_temp = round(ess(gq_samples[gq_samples.chain .== chain, param])) - eff_sample_sizes[param] = size_temp - println("Effective Sample Size for $param for Chain $chain: $size_temp") - else - println("PARAMETER $param NOT IN GENERATED QUANTITIES!!!") - end - end - long_df = stack(gq_samples[gq_samples.chain .== chain, :], Not([:iteration, :chain]), variable_name=:name, value_name=:value) - df_filtered = filter(row -> row.name in param_group, long_df) - - for param in param_group - if param in names(gq_samples) - df_param = filter(row -> row.name == param, df_filtered) - title = "MCMC Diagnosis Plot for Chain $chain, $param (ESS: $(eff_sample_sizes[param]))" - plt = plot(df_param.iteration, df_param.value, - legend=false, - title=title, - xlabel="Iteration", ylabel="Value Drawn", - color = :black, lw = 2 - ) - push!(cat_plots, plt) - else - println("PARAMETER $param NOT IN GENERATED QUANTITIES!!!") - end - end - end - end - if !isempty(cat_plots) - plt = plot(cat_plots..., layout=(length(unique(gq_samples.chain)) * length(desired_params[1]), length(desired_params)), size = (1000, 1000)) - display(plt) - 
if save_plots - save_plots_to_docs(plt, "mcmc_diagnosis_plots") - end - else - println("NO PLOTS TO DISPLAY!!!") - end + mcmcdiags_vis( + gq_samples = gq_samples, + desired_params = desired_params, + save_plots = save_plots + ) else println("MCMC Diagnostics Plots are not requested.") end if time_varying_plots - - # Plotting time varying parameters - var_prefixs = time_varying_params - time_varying_plots = [] - column_names = names(gq_samples) - for var_prefix in var_prefixs - time_varying_param = filter(name -> startswith_any(name, [var_prefix]), column_names) - time_varying_subset_df = gq_samples[:, [time_varying_param..., "iteration", "chain"]] - chains = unique(time_varying_subset_df.chain) - for chain in chains - medians, lower_bounds, upper_bounds = calculate_quantiles(time_varying_subset_df, chain, var_prefix, quantiles) - ribbon_colors = generate_colors(length(quantiles)) - daily_medians = repeat(medians, inner=7) - daily_lower_bounds = repeat(lower_bounds, inner=7) - daily_upper_bounds = repeat(upper_bounds, inner=7) - daily_x = 1:length(daily_medians) - plt = plot(title = "Quantiles for Chain $chain for $var_prefix", - xlabel = "Time Points (daily scale)", - ylabel = "Value for $var_prefix") - for (i, q) in enumerate(quantiles) - daily_upper_bounds_temp = map(x -> x[i], daily_upper_bounds) - daily_lower_bounds_temp = map(x -> x[i], daily_lower_bounds) - plot!(plt, daily_x, daily_medians, ribbon = (daily_upper_bounds_temp .- daily_medians, daily_medians .- daily_lower_bounds_temp), - fillalpha = 0.2, - label = "$(@sprintf("%.0f", q*100))% Quantile", - color = ribbon_colors[i], - fillcolor = ribbon_colors[i]) - end - plot!(plt, daily_x, daily_medians, label = "Median", color = :black, lw = 2) - if !isnothing(actual_rt_vals) && var_prefix == "rt_vals" - scatter!(plt, 1:length(actual_rt_vals), actual_rt_vals, label = "Actual Rt Values", color = :red, lw = 2, marker = :circle) - end - if !isnothing(actual_w_t) && var_prefix == "w_t" - scatter!(plt, 
1:length(actual_w_t), actual_w_t, label = "Actual w_t Values", color = :red, lw = 2, marker = :circle) - end - push!(time_varying_plots, plt) - end - end - - if !isempty(time_varying_plots) - chains = unique(gq_samples.chain) - plt = plot(time_varying_plots..., layout = (length(var_prefixs), length(chains)), size = (1000, 1000)) - display(plt) - if save_plots - save_plots_to_docs(plt, "mcmc_time_varying_parameter_plots") - end - else - println("NO TIME VARYING PARAMETER PLOTS TO DISPLAY!!!") - end + time_varying_param_vis( + gq_samples = gq_samples, + actual_rt_vals = actual_rt_vals, + actual_w_t = actual_w_t, + time_varying_params = time_varying_params, + quantiles = quantiles, + save_plots = save_plots + ) else println("MCMC time varying parameter results are not requested.") end