diff --git a/src/generate_simulation_data_uciwweihr.jl b/src/generate_simulation_data_uciwweihr.jl index 85e2033..aa8a8fd 100644 --- a/src/generate_simulation_data_uciwweihr.jl +++ b/src/generate_simulation_data_uciwweihr.jl @@ -24,26 +24,27 @@ Struct for holding parameters used in the UCIWWEIHR ODE compartmental model simu - `w_init::Float64`: Initial value of the time-varying hospitalization rate, NOT USER SPECIFIED `create_uciwweihr_params` TAKES CARE OF THIS. """ struct uciwweihr_sim_params - time_points::Int64 - seed::Int64 - E_init::Int64 - I_init::Int64 - H_init::Int64 - gamma::Float64 - nu::Float64 - epsilon::Float64 - rho_gene::Float64 - tau::Float64 - df::Float64 - sigma_hosp::Float64 - Rt::Union{Float64, Vector{Float64}} - sigma_Rt::Float64 - w::Union{Float64, Vector{Float64}} - sigma_w::Float64 - rt_init::Float64 - w_init::Float64 + time_points::Union{Int64, Nothing} + seed::Union{Int64, Nothing} + E_init::Union{Int64, Nothing} + I_init::Union{Int64, Nothing} + H_init::Union{Int64, Nothing} + gamma::Union{Float64, Nothing} + nu::Union{Float64, Nothing} + epsilon::Union{Float64, Nothing} + rho_gene::Union{Float64, Nothing} + tau::Union{Float64, Nothing} + df::Union{Float64, Nothing} + sigma_hosp::Union{Float64, Nothing} + Rt::Union{Float64, Vector{Float64}, Nothing} + sigma_Rt::Union{Float64, Nothing} + w::Union{Float64, Vector{Float64}, Nothing} + sigma_w::Union{Float64, Nothing} + rt_init::Union{Float64, Nothing} + w_init::Union{Float64, Nothing} end + """ create_uciwweihr_sim_params(; kwargs...) diff --git a/src/mcmcdiags_vis.jl b/src/mcmcdiags_vis.jl index c6938ca..0e9a7e1 100644 --- a/src/mcmcdiags_vis.jl +++ b/src/mcmcdiags_vis.jl @@ -57,7 +57,7 @@ function mcmcdiags_vis(; ) push!(cat_plots, plt) - if !isnothing(actual_non_time_varying_vals) + if !isnothing(actual_non_time_varying_vals.time_points) actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(param)), digits=3) scatter!(plt, [1], Float64[actual_param_value], label = "Actual Value : $actual_param_value", diff --git a/src/non_time_varying_param_vis.jl b/src/non_time_varying_param_vis.jl index 776990e..436e590 100644 --- a/src/non_time_varying_param_vis.jl +++ b/src/non_time_varying_param_vis.jl @@ -15,7 +15,6 @@ Used in the `uciwweihr_visualizer` to create visuals for non-time varying parame - `forecast_weeks`: An integer to indicate the number of weeks to forecast, used only if user desired priors next to posteriors. - `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output. - `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is any parameter not in this list : ["alpha_t", "w_t", "rt_vals", "log_genes_mean", "H"] -- `bayes_dist_type`: A string to indicate if user is using Posterior or Prior distribution. Default is "Posterior". - `actual_non_time_varying_vals::uciwweihr_sim_params`: A uciwweihr_sim_params object of actual non-time varying parameter values if user has access to them. Default is nothing. - `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder. - `plot_name_to_save`: A string to indicate the name of the plot to save. Default is "mcmc_nontime_varying_parameter_plots". @@ -30,7 +29,6 @@ function non_time_varying_param_vis( forecast_weeks; gq_samples=nothing, desired_params=nothing, - bayes_dist_type="Posterior", actual_non_time_varying_vals::uciwweihr_sim_params = nothing, save_plots::Bool=false, plot_name_to_save = "mcmc_nontime_varying_parameter_plots" @@ -71,7 +69,7 @@ function non_time_varying_param_vis( alpha = 0.1) histogram!(plt, curr_param_chain_df[:, curr_param], label = "Chain $chain (Posterior)", - title = "$bayes_dist_type $curr_param", + title = "$curr_param", bins = 50, normalize = :probability, xlabel = "Value for $curr_param", @@ -79,7 +77,7 @@ function non_time_varying_param_vis( alpha = 0.7, color = :blue, legend = :topright) - if !isnothing(actual_non_time_varying_vals) + if !isnothing(actual_non_time_varying_vals.time_points) actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3) vline!(plt, [actual_param_value], @@ -129,7 +127,6 @@ function non_time_varying_param_vis( forecast_weeks; gq_samples=nothing, desired_params=nothing, - bayes_dist_type="Posterior", actual_non_time_varying_vals::uciwweihr_sim_params = nothing, save_plots::Bool=false, plot_name_to_save = "mcmc_nontime_varying_parameter_plots" @@ -174,7 +171,7 @@ function non_time_varying_param_vis( alpha = 0.1) histogram!(plt, curr_param_chain_df[:, curr_param], label = "Chain $chain (Posterior)", - title = "$bayes_dist_type $curr_param", + title = "$curr_param", bins = 50, normalize = :probability, xlabel = "Value for $curr_param", @@ -182,7 +179,7 @@ function non_time_varying_param_vis( alpha = 0.7, color = :blue, legend = :topright) - if !isnothing(actual_non_time_varying_vals) + if !isnothing(actual_non_time_varying_vals.time_points) actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3) vline!(plt, [actual_param_value], @@ -223,7 +220,6 @@ end function non_time_varying_param_vis(; gq_samples=nothing, desired_params=nothing, - bayes_dist_type="Posterior", actual_non_time_varying_vals::uciwweihr_sim_params = nothing, save_plots::Bool=false, plot_name_to_save = "mcmc_nontime_varying_parameter_plots" @@ -237,7 +233,7 @@ function non_time_varying_param_vis(; curr_param_chain_df = filter(row -> row.chain == chain, gq_samples) plt = histogram(curr_param_chain_df[:, curr_param], label = "Chain $chain", - title = "$bayes_dist_type $curr_param", + title = "$curr_param", bins = 50, normalize = :probability, xlabel = "Probability", @@ -247,7 +243,7 @@ function non_time_varying_param_vis(; titlefont = font(10), legendfont = font(8), legend = :topright) - if !isnothing(actual_non_time_varying_vals) + if !isnothing(actual_non_time_varying_vals.time_points) actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3) vline!(plt, [actual_param_value], diff --git a/src/uciwweihr_model.jl b/src/uciwweihr_model.jl index ca18e7a..05f9cdf 100644 --- a/src/uciwweihr_model.jl +++ b/src/uciwweihr_model.jl @@ -147,8 +147,10 @@ The defaults for this fuction will follow those of the default simulation in gen gamma = gamma, nu = nu, w_t = w_t, + sigma_w = sigma_w, epsilon = epsilon, rt_vals = rt_vals, + sigma_Rt = sigma_Rt, rho_gene = rho_gene, tau = tau, df = df, @@ -274,8 +276,10 @@ The defaults for this fuction will follow those of the default simulation in gen gamma = gamma, nu = nu, w_t = w_t, + sigma_w = sigma_w, epsilon = epsilon, rt_vals = rt_vals, + sigma_Rt = sigma_Rt, sigma_hosp = sigma_hosp, H = sol_hosp, rt_init = rt_init, diff --git a/src/uciwweihr_visualizer.jl b/src/uciwweihr_visualizer.jl index 3117f76..23e1a7b 100644 --- a/src/uciwweihr_visualizer.jl +++ b/src/uciwweihr_visualizer.jl @@ -48,11 +48,12 @@ function uciwweihr_visualizer( obs_data_wastewater = nothing, actual_rt_vals = nothing, actual_w_t = nothing, - actual_non_time_varying_vals::uciwweihr_sim_params = nothing, + actual_non_time_varying_vals::uciwweihr_sim_params = uciwweihr_sim_params(ntuple(x->nothing, fieldcount(uciwweihr_sim_params))...), desired_params = [ ["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rt_init", "w_init"], + ["sigma_w", "sigma_Rt"], ["rho_gene", "tau", "df"], ["sigma_hosp"] ], @@ -117,7 +118,6 @@ function uciwweihr_visualizer( forecast_weeks; gq_samples = gq_samples, desired_params = desired_params, - bayes_dist_type = bayes_dist_type, actual_non_time_varying_vals = actual_non_time_varying_vals, save_plots = save_plots, plot_name_to_save = plot_name_to_save_non_time_varying @@ -164,11 +164,12 @@ function uciwweihr_visualizer( obs_data_wastewater = nothing, actual_rt_vals = nothing, actual_w_t = nothing, - actual_non_time_varying_vals::uciwweihr_sim_params = nothing, + actual_non_time_varying_vals::uciwweihr_sim_params = uciwweihr_sim_params(ntuple(x->nothing, fieldcount(uciwweihr_sim_params))...), desired_params = [ ["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rt_init", "w_init"], + ["sigma_w", "sigma_Rt"], ["rho_gene", "tau", "df"], ["sigma_hosp"] ], @@ -237,7 +238,6 @@ function uciwweihr_visualizer( forecast_weeks; gq_samples = gq_samples, desired_params = desired_params, - bayes_dist_type = bayes_dist_type, actual_non_time_varying_vals = actual_non_time_varying_vals, save_plots = save_plots, plot_name_to_save = plot_name_to_save_non_time_varying @@ -276,11 +276,12 @@ function uciwweihr_visualizer( obs_data_wastewater = nothing, actual_rt_vals = nothing, actual_w_t = nothing, - actual_non_time_varying_vals::uciwweihr_sim_params = nothing, + actual_non_time_varying_vals::uciwweihr_sim_params = uciwweihr_sim_params(ntuple(x->nothing, fieldcount(uciwweihr_sim_params))...), desired_params = [ ["E_init", "I_init", "H_init"], ["gamma", "nu", "epsilon"], ["rt_init", "w_init"], + ["sigma_w", "sigma_Rt"], ["rho_gene", "tau", "df"], ["sigma_hosp"] ], @@ -331,7 +332,6 @@ function uciwweihr_visualizer( non_time_varying_param_vis( gq_samples = gq_samples, desired_params = desired_params, - bayes_dist_type = bayes_dist_type, actual_non_time_varying_vals = actual_non_time_varying_vals, save_plots = save_plots, plot_name_to_save = plot_name_to_save_non_time_varying