Skip to content

Commit

Permalink
2024-10-17 update : fixed issue 58.
Browse files Browse the repository at this point in the history
  • Loading branch information
cbernalz committed Oct 17, 2024
1 parent d61cd9b commit 4bdfbb0
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 35 deletions.
37 changes: 19 additions & 18 deletions src/generate_simulation_data_uciwweihr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,27 @@ Struct for holding parameters used in the UCIWWEIHR ODE compartmental model simu
- `w_init::Float64`: Initial value of the time-varying hospitalization rate, NOT USER SPECIFIED `create_uciwweihr_params` TAKES CARE OF THIS.
"""
struct uciwweihr_sim_params
time_points::Int64
seed::Int64
E_init::Int64
I_init::Int64
H_init::Int64
gamma::Float64
nu::Float64
epsilon::Float64
rho_gene::Float64
tau::Float64
df::Float64
sigma_hosp::Float64
Rt::Union{Float64, Vector{Float64}}
sigma_Rt::Float64
w::Union{Float64, Vector{Float64}}
sigma_w::Float64
rt_init::Float64
w_init::Float64
time_points::Union{Int64, Nothing}
seed::Union{Int64, Nothing}
E_init::Union{Int64, Nothing}
I_init::Union{Int64, Nothing}
H_init::Union{Int64, Nothing}
gamma::Union{Float64, Nothing}
nu::Union{Float64, Nothing}
epsilon::Union{Float64, Nothing}
rho_gene::Union{Float64, Nothing}
tau::Union{Float64, Nothing}
df::Union{Float64, Nothing}
sigma_hosp::Union{Float64, Nothing}
Rt::Union{Float64, Vector{Float64}, Nothing}
sigma_Rt::Union{Float64, Nothing}
w::Union{Float64, Vector{Float64}, Nothing}
sigma_w::Union{Float64, Nothing}
rt_init::Union{Float64, Nothing}
w_init::Union{Float64, Nothing}
end


"""
create_uciwweihr_sim_params(; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/mcmcdiags_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function mcmcdiags_vis(;
)
push!(cat_plots, plt)

if !isnothing(actual_non_time_varying_vals)
if !isnothing(actual_non_time_varying_vals.time_points)
actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(param)), digits=3)
scatter!(plt, [1], Float64[actual_param_value],
label = "Actual Value : $actual_param_value",
Expand Down
16 changes: 6 additions & 10 deletions src/non_time_varying_param_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Used in the `uciwweihr_visualizer` to create visuals for non-time varying parame
- `forecast_weeks`: An integer to indicate the number of weeks to forecast, used only if user desired priors next to posteriors.
- `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output.
- `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is any parameter not in this list : ["alpha_t", "w_t", "rt_vals", "log_genes_mean", "H"]
- `bayes_dist_type`: A string to indicate if user is using Posterior or Prior distribution. Default is "Posterior".
- `actual_non_time_varying_vals::uciwweihr_sim_params`: A uciwweihr_sim_params object of actual non-time varying parameter values if user has access to them. Default is nothing.
- `save_plots::Bool=false`: A boolean to indicate if user wants to save the plots as pngs into a plots folder.
- `plot_name_to_save`: A string to indicate the name of the plot to save. Default is "mcmc_nontime_varying_parameter_plots".
Expand All @@ -30,7 +29,6 @@ function non_time_varying_param_vis(
forecast_weeks;
gq_samples=nothing,
desired_params=nothing,
bayes_dist_type="Posterior",
actual_non_time_varying_vals::uciwweihr_sim_params = nothing,
save_plots::Bool=false,
plot_name_to_save = "mcmc_nontime_varying_parameter_plots"
Expand Down Expand Up @@ -71,15 +69,15 @@ function non_time_varying_param_vis(
alpha = 0.1)
histogram!(plt, curr_param_chain_df[:, curr_param],
label = "Chain $chain (Posterior)",
title = "$bayes_dist_type $curr_param",
title = "$curr_param",
bins = 50,
normalize = :probability,
xlabel = "Value for $curr_param",
ylabel = "Probability",
alpha = 0.7,
color = :blue,
legend = :topright)
if !isnothing(actual_non_time_varying_vals)
if !isnothing(actual_non_time_varying_vals.time_points)
actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3)
vline!(plt,
[actual_param_value],
Expand Down Expand Up @@ -129,7 +127,6 @@ function non_time_varying_param_vis(
forecast_weeks;
gq_samples=nothing,
desired_params=nothing,
bayes_dist_type="Posterior",
actual_non_time_varying_vals::uciwweihr_sim_params = nothing,
save_plots::Bool=false,
plot_name_to_save = "mcmc_nontime_varying_parameter_plots"
Expand Down Expand Up @@ -174,15 +171,15 @@ function non_time_varying_param_vis(
alpha = 0.1)
histogram!(plt, curr_param_chain_df[:, curr_param],
label = "Chain $chain (Posterior)",
title = "$bayes_dist_type $curr_param",
title = "$curr_param",
bins = 50,
normalize = :probability,
xlabel = "Value for $curr_param",
ylabel = "Probability",
alpha = 0.7,
color = :blue,
legend = :topright)
if !isnothing(actual_non_time_varying_vals)
if !isnothing(actual_non_time_varying_vals.time_points)
actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3)
vline!(plt,
[actual_param_value],
Expand Down Expand Up @@ -223,7 +220,6 @@ end
function non_time_varying_param_vis(;
gq_samples=nothing,
desired_params=nothing,
bayes_dist_type="Posterior",
actual_non_time_varying_vals::uciwweihr_sim_params = nothing,
save_plots::Bool=false,
plot_name_to_save = "mcmc_nontime_varying_parameter_plots"
Expand All @@ -237,7 +233,7 @@ function non_time_varying_param_vis(;
curr_param_chain_df = filter(row -> row.chain == chain, gq_samples)
plt = histogram(curr_param_chain_df[:, curr_param],
label = "Chain $chain",
title = "$bayes_dist_type $curr_param",
title = "$curr_param",
bins = 50,
normalize = :probability,
xlabel = "Probability",
Expand All @@ -247,7 +243,7 @@ function non_time_varying_param_vis(;
titlefont = font(10),
legendfont = font(8),
legend = :topright)
if !isnothing(actual_non_time_varying_vals)
if !isnothing(actual_non_time_varying_vals.time_points)
actual_param_value = round(getfield(actual_non_time_varying_vals, Symbol(curr_param)), digits=3)
vline!(plt,
[actual_param_value],
Expand Down
4 changes: 4 additions & 0 deletions src/uciwweihr_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ The defaults for this fuction will follow those of the default simulation in gen
gamma = gamma,
nu = nu,
w_t = w_t,
sigma_w = sigma_w,
epsilon = epsilon,
rt_vals = rt_vals,
sigma_Rt = sigma_Rt,
rho_gene = rho_gene,
tau = tau,
df = df,
Expand Down Expand Up @@ -274,8 +276,10 @@ The defaults for this fuction will follow those of the default simulation in gen
gamma = gamma,
nu = nu,
w_t = w_t,
sigma_w = sigma_w,
epsilon = epsilon,
rt_vals = rt_vals,
sigma_Rt = sigma_Rt,
sigma_hosp = sigma_hosp,
H = sol_hosp,
rt_init = rt_init,
Expand Down
12 changes: 6 additions & 6 deletions src/uciwweihr_visualizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ function uciwweihr_visualizer(
obs_data_wastewater = nothing,
actual_rt_vals = nothing,
actual_w_t = nothing,
actual_non_time_varying_vals::uciwweihr_sim_params = nothing,
actual_non_time_varying_vals::uciwweihr_sim_params = uciwweihr_sim_params(ntuple(x->nothing, fieldcount(uciwweihr_sim_params))...),
desired_params = [
["E_init", "I_init", "H_init"],
["gamma", "nu", "epsilon"],
["rt_init", "w_init"],
["sigma_w", "sigma_Rt"],
["rho_gene", "tau", "df"],
["sigma_hosp"]
],
Expand Down Expand Up @@ -117,7 +118,6 @@ function uciwweihr_visualizer(
forecast_weeks;
gq_samples = gq_samples,
desired_params = desired_params,
bayes_dist_type = bayes_dist_type,
actual_non_time_varying_vals = actual_non_time_varying_vals,
save_plots = save_plots,
plot_name_to_save = plot_name_to_save_non_time_varying
Expand Down Expand Up @@ -164,11 +164,12 @@ function uciwweihr_visualizer(
obs_data_wastewater = nothing,
actual_rt_vals = nothing,
actual_w_t = nothing,
actual_non_time_varying_vals::uciwweihr_sim_params = nothing,
actual_non_time_varying_vals::uciwweihr_sim_params = uciwweihr_sim_params(ntuple(x->nothing, fieldcount(uciwweihr_sim_params))...),
desired_params = [
["E_init", "I_init", "H_init"],
["gamma", "nu", "epsilon"],
["rt_init", "w_init"],
["sigma_w", "sigma_Rt"],
["rho_gene", "tau", "df"],
["sigma_hosp"]
],
Expand Down Expand Up @@ -237,7 +238,6 @@ function uciwweihr_visualizer(
forecast_weeks;
gq_samples = gq_samples,
desired_params = desired_params,
bayes_dist_type = bayes_dist_type,
actual_non_time_varying_vals = actual_non_time_varying_vals,
save_plots = save_plots,
plot_name_to_save = plot_name_to_save_non_time_varying
Expand Down Expand Up @@ -276,11 +276,12 @@ function uciwweihr_visualizer(
obs_data_wastewater = nothing,
actual_rt_vals = nothing,
actual_w_t = nothing,
actual_non_time_varying_vals::uciwweihr_sim_params = nothing,
actual_non_time_varying_vals::uciwweihr_sim_params = uciwweihr_sim_params(ntuple(x->nothing, fieldcount(uciwweihr_sim_params))...),
desired_params = [
["E_init", "I_init", "H_init"],
["gamma", "nu", "epsilon"],
["rt_init", "w_init"],
["sigma_w", "sigma_Rt"],
["rho_gene", "tau", "df"],
["sigma_hosp"]
],
Expand Down Expand Up @@ -331,7 +332,6 @@ function uciwweihr_visualizer(
non_time_varying_param_vis(
gq_samples = gq_samples,
desired_params = desired_params,
bayes_dist_type = bayes_dist_type,
actual_non_time_varying_vals = actual_non_time_varying_vals,
save_plots = save_plots,
plot_name_to_save = plot_name_to_save_non_time_varying
Expand Down

0 comments on commit 4bdfbb0

Please sign in to comment.