From 3f02ad89dd021d3b512c17e1d62c324da7c8437b Mon Sep 17 00:00:00 2001 From: cbernalz Date: Sun, 4 Aug 2024 18:07:37 -0700 Subject: [PATCH] "2024-08-04 update : added timevarying param plots." --- Project.toml | 2 + docs/src/tutorials/uciwweihr_model_fitting.md | 12 +++- .../tutorials/uciwweihr_simulation_data.md | 9 +++ src/UCIWWEIHR.jl | 4 ++ src/generate_simulation_data_uciwweihr.jl | 3 +- src/helper_functions.jl | 50 +++++++++++++ src/uciwweihr_model.jl | 9 ++- src/uciwweihr_visualizer.jl | 70 ++++++++++++++++++- 8 files changed, 151 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 6bfdbd8..c182984 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "1.0.0-DEV" AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -20,6 +21,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/docs/src/tutorials/uciwweihr_model_fitting.md b/docs/src/tutorials/uciwweihr_model_fitting.md index d487694..75d20fc 100644 --- a/docs/src/tutorials/uciwweihr_model_fitting.md +++ b/docs/src/tutorials/uciwweihr_model_fitting.md @@ -62,9 +62,19 @@ first(model_output[3][:,1:5], 5) We also provide a very basic way to visualize some MCMC diagnostics along with effective sample sizes of desired generated quantities(does not include functionality for time-varying quantities). Along with this, we can also visualize the posterior predictive distribution with actual observed values, which can be used to examine forecasts generated by the model. ```@example tutorial -uciwweihr_visualizer(gq_samples = model_output[2], save_plots = true) +uciwweihr_visualizer(gq_samples = model_output[2], + actual_rt_vals = df.rt, + actual_w_t = df.wt, + save_plots = true) ``` + +### 3.1. MCMC Diagnostic Plots. + ![Plot 1](plots/mcmc_diagnosis_plots.png) +### 3.2. Time Varying Parameter Results Plot. + +![Plot 2](plots/mcmc_time_varying_parameter_plots.png) + ### [Tutorial Contents](@ref tutorial_home) \ No newline at end of file diff --git a/docs/src/tutorials/uciwweihr_simulation_data.md b/docs/src/tutorials/uciwweihr_simulation_data.md index 0d0aa28..62d3e49 100644 --- a/docs/src/tutorials/uciwweihr_simulation_data.md +++ b/docs/src/tutorials/uciwweihr_simulation_data.md @@ -43,5 +43,14 @@ plot(df.obstimes, df.rt, title="Plot of Rt Over Time") ``` +### 2.4. Hospitalization rate. +```@example tutorial +plot(df.obstimes, df.wt, + label=nothing, + xlabel="Obstimes", + ylabel="Rt", + title="Plot of Hospitalization Rate Over Time") +``` + ### [Tutorial Contents](@ref tutorial_home) \ No newline at end of file diff --git a/src/UCIWWEIHR.jl b/src/UCIWWEIHR.jl index 6804b77..00da2da 100644 --- a/src/UCIWWEIHR.jl +++ b/src/UCIWWEIHR.jl @@ -19,6 +19,8 @@ using DataFrames using DifferentialEquations using StatsBase using Plots +using Printf +using Colors include("generate_simulation_data_uciwweihr.jl") include("generate_simulation_data_agent.jl") @@ -42,5 +44,7 @@ export uciwweihr_gq_pp export uciwweihr_visualizer export ChainsCustomIndexs export save_plots_to_docs +export startswith_any +export calculate_quantiles end \ No newline at end of file diff --git a/src/generate_simulation_data_uciwweihr.jl b/src/generate_simulation_data_uciwweihr.jl index b5bb2d1..80963e3 100644 --- a/src/generate_simulation_data_uciwweihr.jl +++ b/src/generate_simulation_data_uciwweihr.jl @@ -88,7 +88,8 @@ function generate_simulation_data_uciwweihr( obstimes = 1:time_points, log_ww_conc = data_wastewater, hosp = data_hosp, - rt = Rt_t_no_init + rt = Rt_t_no_init, + wt = w_no_init ); return df diff --git a/src/helper_functions.jl b/src/helper_functions.jl index f720e41..c28714a 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -122,4 +122,54 @@ function save_plots_to_docs(plot, filename; format = "png") savefig(plot, file_target_path) println("Plot saved to $file_target_path") +end + + +""" + startswith_any(name, patterns) + +Checks if the name of time varying paramter starts with any of the patterns. + +Function created by Christian Bernal Zelaya. +""" +function startswith_any(name, patterns) + for pattern in patterns + if startswith(name, pattern) + return true + end + end + return false +end + + +""" + calculate_quantiles(df, chain, var_prefix, quantiles) + +Calculate quantiles for a given chain and variable prefix. Quantiles can be any user desired quantile. + +Function created by Christian Bernal Zelaya. +""" +function calculate_quantiles(df, chain, var_prefix, quantiles) + df_chain = filter(row -> row.chain == chain, df) + column_names = names(df_chain) + var_names = filter(name -> startswith_any(name, [var_prefix]), column_names) + medians = [median(df_chain[:, var]) for var in var_names] + lower_bounds = [quantile(df_chain[:, var], (1 .- quantiles) / 2) for var in var_names] + upper_bounds = [quantile(df_chain[:, var], 1 .- (1 .- quantiles) / 2) for var in var_names] + + + return medians, lower_bounds, upper_bounds +end + + +""" + generate_ribbon_colors(number_of_colors) + +Generates a vector with colors for ribbons in plots. + +Function created by Christian Bernal Zelaya. +""" +function generate_colors(number_of_colors) + alpha_values = range(0.1, stop=0.7, length=number_of_colors) + return [RGBA(colorant"blue", alpha) for alpha in alpha_values] end \ No newline at end of file diff --git a/src/uciwweihr_model.jl b/src/uciwweihr_model.jl index b8cd57b..f4ef547 100644 --- a/src/uciwweihr_model.jl +++ b/src/uciwweihr_model.jl @@ -162,8 +162,9 @@ The defaults for this fuction will follow those of the default simulation in gen # Generated quantities H_comp = sol_array[3, :] - rt_vals = alpha_t / nu - w_t = w_t + rt_vals = alpha_t_no_init / nu + rt_init = alpha_init / nu + w_t = w_no_init return ( E_init, @@ -180,7 +181,9 @@ The defaults for this fuction will follow those of the default simulation in gen df = df, sigma_hosp = sigma_hosp, H = H_comp, - log_genes_mean = log_genes_mean + log_genes_mean = log_genes_mean, + rt_init = rt_init, + w_init = w_init ) diff --git a/src/uciwweihr_visualizer.jl b/src/uciwweihr_visualizer.jl index 1e29711..74a2f1b 100644 --- a/src/uciwweihr_visualizer.jl +++ b/src/uciwweihr_visualizer.jl @@ -9,10 +9,14 @@ Default visualizer for results of the UCIWWEIHR model, includes posterior/priors - `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. - `obstimes`: An array of timepoints for observed hosp/wastewater. - `param_change_times`: An array of timepoints where the parameters change. +- `actual_rt_vals`: An array of actual Rt values if user has access to them assumed to be on a daily scale. This typically will come from some simulation. Default is nothing. +- `actual_w_t`: An array of actual w_t values if user has access to them assumed to be on a daily scale. This typically will come from some simulation. Default is nothing. - `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is [["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rho_gene", "tau", "df"], ["sigma_hosp"]]. +- `time_varying_params`: A list of time varying parameters to visualize. Default is ["rt_vals", "w_t"]. +- `quantiles`: A list of quantiles to calculate for ploting uncertainty. Default is [0.5, 0.8, 0.95]. - `mcmcdaigs::Bool=true`: A boolean to indicate if user wants to visualize mcmc diagnosis plots and Effective Sample Size(ESS). -- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs. - +- `time_varying_plots::Bool=true`: A boolean to indicate if user wants to visualize time varying parameters. +- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. """ function uciwweihr_visualizer(; pp_samples=nothing, @@ -21,13 +25,19 @@ function uciwweihr_visualizer(; data_wastewater=nothing, obstimes=nothing, param_change_times=nothing, + actual_rt_vals=nothing, + actual_w_t=nothing, desired_params=[ ["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], + ["rt_init", "w_init"], ["rho_gene", "tau", "df"], ["sigma_hosp"] ], + time_varying_params = ["rt_vals", "w_t"], + quantiles = [0.5, 0.8, 0.95], mcmcdaigs::Bool=true, + time_varying_plots::Bool=true, save_plots::Bool=false ) @@ -58,6 +68,7 @@ function uciwweihr_visualizer(; legend=false, title=title, xlabel="Iteration", ylabel="Value Drawn", + color = :black, lw = 2 ) push!(cat_plots, plt) else @@ -67,7 +78,7 @@ function uciwweihr_visualizer(; end end if !isempty(cat_plots) - plt = plot(cat_plots..., layout=(length(unique(gq_samples.chain)) * length(desired_params[1]), length(desired_params))) + plt = plot(cat_plots..., layout=(length(unique(gq_samples.chain)) * length(desired_params[1]), length(desired_params)), size = (1000, 1000)) display(plt) if save_plots save_plots_to_docs(plt, "mcmc_diagnosis_plots") @@ -79,6 +90,59 @@ function uciwweihr_visualizer(; println("MCMC Diagnostics Plots are not requested.") end + if time_varying_plots + + # Plotting time varying parameters + var_prefixs = time_varying_params + time_varying_plots = [] + column_names = names(gq_samples) + for var_prefix in var_prefixs + time_varying_param = filter(name -> startswith_any(name, [var_prefix]), column_names) + time_varying_subset_df = gq_samples[:, [time_varying_param..., "iteration", "chain"]] + chains = unique(time_varying_subset_df.chain) + for chain in chains + medians, lower_bounds, upper_bounds = calculate_quantiles(time_varying_subset_df, chain, var_prefix, quantiles) + ribbon_colors = generate_colors(length(quantiles)) + daily_medians = repeat(medians, inner=7) + daily_lower_bounds = repeat(lower_bounds, inner=7) + daily_upper_bounds = repeat(upper_bounds, inner=7) + daily_x = 1:length(daily_medians) + plt = plot(title = "Quantiles for Chain $chain for $var_prefix", + xlabel = "Time Points (daily scale)", + ylabel = "Value for $var_prefix") + for (i, q) in enumerate(quantiles) + daily_upper_bounds_temp = map(x -> x[i], daily_upper_bounds) + daily_lower_bounds_temp = map(x -> x[i], daily_lower_bounds) + plot!(plt, daily_x, daily_medians, ribbon = (daily_upper_bounds_temp .- daily_medians, daily_medians .- daily_lower_bounds_temp), + fillalpha = 0.2, + label = "$(@sprintf("%.0f", q*100))% Quantile", + color = ribbon_colors[i], + fillcolor = ribbon_colors[i]) + end + plot!(plt, daily_x, daily_medians, label = "Median", color = :black, lw = 2) + if !isnothing(actual_rt_vals) && var_prefix == "rt_vals" + scatter!(plt, 1:length(actual_rt_vals), actual_rt_vals, label = "Actual Rt Values", color = :red, lw = 2, marker = :circle) + end + if !isnothing(actual_w_t) && var_prefix == "w_t" + scatter!(plt, 1:length(actual_w_t), actual_w_t, label = "Actual w_t Values", color = :red, lw = 2, marker = :circle) + end + push!(time_varying_plots, plt) + end + end + + if !isempty(time_varying_plots) + chains = unique(gq_samples.chain) + plt = plot(time_varying_plots..., layout = (length(var_prefixs), length(chains)), size = (1000, 1000)) + display(plt) + if save_plots + save_plots_to_docs(plt, "mcmc_time_varying_parameter_plots") + end + else + println("NO TIME VARYING PARAMETER PLOTS TO DISPLAY!!!") + end + else + println("MCMC time varying parameter results are not requested.") + end