-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
"2024-07-29 update : adding visualizer along with model fitting and g…
…q_pp functions."
- Loading branch information
Showing
13 changed files
with
618 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
/Manifest.toml | ||
/docs/Manifest.toml | ||
/docs/build/ | ||
src/testfile.jl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
[deps] | ||
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" | ||
UCIWWEIHR = "176850df-a79a-485c-b76d-f2c16e00fafb" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
```@setup tutorial | ||
using Plots, StatsPlots; gr() | ||
Plots.reset_defaults() | ||
``` | ||
|
||
# [Generating Posterior Distribution Samples with UCIWWEIHR ODE compartmental based model.](@id uciwwiehr_model_fitting) | ||
|
||
This package has a way to sample from a posterior or prior that is defined in the future paper using the `uciwweihr_fit.jl` and `uciwweihr_model.jl`. We can then generate desired quantities and forecast for a given time period with the posterior predictive distribution, using `uciwweihr_gq_pp.jl`. We first generate data using the `generate_simulation_data_uciwweihr` function which is a non-mispecified version of the model. | ||
|
||
|
||
## 1. Data Generation. | ||
|
||
``` @example tutorial | ||
using UCIWWEIHR | ||
# Running simulation function with defaults | ||
df = generate_simulation_data_uciwweihr() | ||
first(df, 5) | ||
``` | ||
|
||
## 2. Sampling from the Posterior Distribution and Posterior Predictive Distribution. | ||
|
||
Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we setup some presets, then have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model. | ||
|
||
``` @example tutorial | ||
data_hosp = df.hosp | ||
data_wastewater = df.log_ww_conc | ||
obstimes = df.obstimes | ||
param_change_times = 1:7:length(obstimes) # Change every week | ||
priors_only = false | ||
n_samples = 50 | ||
samples = uciwweihr_fit( | ||
data_hosp, | ||
data_wastewater, | ||
obstimes, | ||
param_change_times, | ||
priors_only, | ||
n_samples | ||
) | ||
model_output = uciwweihr_gq_pp( | ||
samples, | ||
data_hosp, | ||
data_wastewater, | ||
obstimes, | ||
param_change_times | ||
) | ||
first(model_output[1][:,1:5], 5) | ||
``` | ||
|
||
``` @example tutorial | ||
first(model_output[2][:,1:5], 5) | ||
``` | ||
|
||
``` @example tutorial | ||
first(model_output[3][:,1:5], 5) | ||
``` | ||
|
||
## 3. MCMC Diagnostic Plots/Results Along with Posterior Predictive Distribution. | ||
|
||
We also provide a very basic way to visualize some MCMC diagnostics along with effective sample sizes of desired generated quantities(does not include functionality for time-varying quantities). Along with this, we can also visualize the posterior predictive distribution with actual observed values, which can be used to examine forecasts generated by the model. | ||
|
||
```@example tutorial | ||
uciwweihr_visualizer(gq_samples = model_output[2], save_plots = true) | ||
``` | ||
![Plot 1](plots/mcmc_diagnosis_plots.png) | ||
|
||
|
||
### [Tutorial Contents](@ref tutorial_home) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
""" | ||
power(a,b) | ||
Raise `a` to the `b` power | ||
""" | ||
function power(a,b) | ||
a^b | ||
end | ||
|
||
""" | ||
ChainsCustomIndex(c::Chains, indices_to_keep::BitMatrix) | ||
Reduce Chains object to only wanted indices. | ||
Function created by Damon Bayer. | ||
""" | ||
function ChainsCustomIndex(c::Chains, indices_to_keep::BitMatrix) | ||
min_length = minimum(mapslices(sum, indices_to_keep, dims = 1)) | ||
v = c.value | ||
new_v = copy(v.data) | ||
new_v_filtered = cat([new_v[indices_to_keep[:, i], :, i][1:min_length, :] for i in 1:size(v, 3)]..., dims = 3) | ||
aa = AxisArray(new_v_filtered; iter = v.axes[1].val[1:min_length], var = v.axes[2].val, chain = v.axes[3].val) | ||
|
||
Chains(aa, c.logevidence, c.name_map, c.info) | ||
end | ||
|
||
# Series of functions for creating correctly scaled parameter draws. | ||
# code snippet shared by @torfjelde | ||
# https://gist.github.com/torfjelde/37be5a672d29e473983b8e82b45c2e41 | ||
generate_names(val) = generate_names("", val) | ||
generate_names(vn_str::String, val::Real) = [vn_str;] | ||
function generate_names(vn_str::String, val::NamedTuple) | ||
return map(keys(val)) do k | ||
generate_names("$(vn_str)$(k)", val[k]) | ||
end | ||
end | ||
function generate_names(vn_str::String, val::AbstractArray{<:Real}) | ||
results = String[] | ||
for idx in CartesianIndices(val) | ||
s = join(idx.I, ",") | ||
push!(results, "$vn_str[$s]") | ||
end | ||
return results | ||
end | ||
|
||
function generate_names(vn_str::String, val::AbstractArray{<:AbstractArray}) | ||
results = String[] | ||
for idx in CartesianIndices(val) | ||
s1 = join(idx.I, ",") | ||
inner_results = map(f("", val[idx])) do s2 | ||
"$vn_str[$s1]$s2" | ||
end | ||
append!(results, inner_results) | ||
end | ||
return results | ||
end | ||
|
||
flatten(val::Real) = [val;] | ||
function flatten(val::AbstractArray{<:Real}) | ||
return mapreduce(vcat, CartesianIndices(val)) do i | ||
val[i] | ||
end | ||
end | ||
function flatten(val::AbstractArray{<:AbstractArray}) | ||
return mapreduce(vcat, CartesianIndices(val)) do i | ||
flatten(val[i]) | ||
end | ||
end | ||
|
||
function vectup2chainargs(ts::AbstractVector{<:NamedTuple}) | ||
ks = keys(first(ts)) | ||
vns = mapreduce(vcat, ks) do k | ||
generate_names(string(k), first(ts)[k]) | ||
end | ||
vals = map(eachindex(ts)) do i | ||
mapreduce(vcat, ks) do k | ||
flatten(ts[i][k]) | ||
end | ||
end | ||
arr_tmp = reduce(hcat, vals)' | ||
arr = reshape(arr_tmp, (size(arr_tmp)..., 1)) # treat as 1 chain | ||
return Array(arr), vns | ||
end | ||
|
||
function vectup2chainargs(ts::AbstractMatrix{<:NamedTuple}) | ||
num_samples, num_chains = size(ts) | ||
res = map(1:num_chains) do chain_idx | ||
vectup2chainargs(ts[:, chain_idx]) | ||
end | ||
|
||
vals = getindex.(res, 1) | ||
vns = getindex.(res, 2) | ||
|
||
# Verify that the variable names are indeed the same | ||
vns_union = reduce(union, vns) | ||
@assert all(isempty.(setdiff.(vns, Ref(vns_union)))) "variable names differ between chains" | ||
|
||
arr = cat(vals...; dims = 3) | ||
|
||
return arr, first(vns) | ||
end | ||
|
||
function MCMCChains.Chains(ts::AbstractArray{<:NamedTuple}) | ||
return MCMCChains.Chains(vectup2chainargs(ts)...) | ||
end | ||
|
||
|
||
""" | ||
save_plots_to_docs(plot, filename; format = "png") | ||
Saves plots to docs/plots directory. | ||
Function created by Christian Bernal Zelaya. | ||
""" | ||
function save_plots_to_docs(plot, filename; format = "png") | ||
doc_loc = "plots" | ||
if !isdir(doc_loc) | ||
mkdir(doc_loc) | ||
end | ||
|
||
file_target_path = joinpath(doc_loc, "$filename.$format") | ||
savefig(plot, file_target_path) | ||
println("Plot saved to $file_target_path") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Fitting UCIWWEIHR model | ||
# ------------------------------------------------- | ||
""" | ||
uciwweihr_fit(...) | ||
This is the sampler for the bayesian semi-parametric model for the wastewater EIHR compartmental model. | ||
The defaults for this fuction will follow those of the default simulation in generate_simulation_data_ww_eihr.jl function. | ||
# Arguments | ||
- `data_hosp`: An array of hospital data. | ||
- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data. | ||
- `obstimes`: An array of timepoints for observed hosp/wastewater. | ||
- `priors_only::Bool=false`: A boolean to indicate if only priors are to be sampled. | ||
- `n_samples::Int64=500`: Number of samples to be drawn. | ||
- `n_chains::Int64=1`: Number of chains to be run. | ||
- `seed::Int64=2024`: Seed for the random number generator. | ||
- `E_init_sd::Float64=50.0`: Standard deviation for the initial number of exposed individuals. | ||
- `E_init_mean::Int64=200`: Mean for the initial number of exposed individuals. | ||
- `I_init_sd::Float64=20.0`: Standard deviation for the initial number of infected individuals. | ||
- `I_init_mean::Int64=100`: Mean for the initial number of infected individuals. | ||
- `H_init_sd::Float64=5.0`: Standard deviation for the initial number of hospitalized individuals. | ||
- `H_init_mean::Int64=20`: Mean for the initial number of hospitalized individuals. | ||
- `gamma_sd::Float64=0.02`: Standard deviation for the rate of incubation. | ||
- `log_gamma_mean::Float64=log(1/4)`: Mean for the rate of incubation on log scale. | ||
- `nu_sd::Float64=0.02`: Standard deviation for the rate of leaving the infected compartment. | ||
- `log_nu_mean::Float64=log(1/7)`: Mean for the rate of leaving the infected compartment on the log scale. | ||
- `epsilon_sd::Float64=0.02`: Standard deviation for the rate of hospitalization recovery. | ||
- `log_epsilon_mean::Float64=log(1/5)`: Mean for the rate of hospitalization recovery on the log scale. | ||
- `rho_gene_sd::Float64=0.02`: Standard deviation for the rho prior. | ||
- `log_rho_gene_mean::Float64=log(0.011)`: Mean for the row prior on log scale. | ||
- `tau_sd::Float64=0.02`: Standard deviation for the scale/variation of the log scale data. | ||
- `log_tau_mean::Float64=log(0.1)`: Mean for the scale/variation of the log scale data on log scale itself. | ||
- `df_shape::Float64=2.0`: Shape parameter for the gamma distribution. | ||
- `df_scale::Float64=10.0`: Scale parameter for the gamma distribution. | ||
- `sigma_hosp_sd::Float64=50.0`: Standard deviation for the negative binomial distribution for hospital data. | ||
- `sigma_hosp_mean::Float64=500.0`: Mean for the negative binomial distribution for hospital data. | ||
- `Rt_init_sd::Float64=0.3`: Standard deviation for the initial value of the time-varying reproduction number. | ||
- `Rt_init_mean::Float64=0.2`: Mean for the initial value of the time-varying reproduction number. | ||
- `sigma_Rt_sd::Float64=0.2`: Standard deviation for normal prior of log time-varying reproduction number standard deviation. | ||
- `sigma_Rt_mean::Float64=-3.0`: Mean for normal prior of log time-varying reproduction number standard deviation. | ||
- `w_init_sd::Float64=0.1`: Standard deviation for the initial value of the time-varying hospitalization rate. | ||
- `w_init_mean::Float64=log(0.35)`: Mean for the initial value of the time-varying hospitalization rate. | ||
- `sigma_w_sd::Float64=0.2`: Standard deviation for normal prior of log time-varying hospitalization rate standard deviation. | ||
- `sigma_w_mean::Float64=-3.5`: Mean for normal prior of time-varying hospitalization rate standard deviation. | ||
- `param_change_times::Array{Float64}`: An array of timepoints where the parameters change. | ||
# Returns | ||
- Samples from the posterior or prior distribution. | ||
""" | ||
function uciwweihr_fit( | ||
data_hosp, | ||
data_wastewater, | ||
obstimes, | ||
param_change_times, | ||
priors_only::Bool=false, | ||
n_samples::Int64=500, n_chains::Int64=1, seed::Int64=2024, | ||
E_init_sd::Float64=50.0, E_init_mean::Int64=200, | ||
I_init_sd::Float64=20.0, I_init_mean::Int64=100, | ||
H_init_sd::Float64=5.0, H_init_mean::Int64=20, | ||
gamma_sd::Float64=0.02, log_gamma_mean::Float64=log(1/4), | ||
nu_sd::Float64=0.02, log_nu_mean::Float64=log(1/7), | ||
epsilon_sd::Float64=0.02, log_epsilon_mean::Float64=log(1/5), | ||
rho_gene_sd::Float64=0.02, log_rho_gene_mean::Float64=log(0.011), | ||
tau_sd::Float64=0.02, log_tau_mean::Float64=log(0.1), | ||
df_shape::Float64=2.0, df_scale::Float64=10.0, | ||
sigma_hosp_sd::Float64=50.0, sigma_hosp_mean::Float64=500.0, | ||
Rt_init_sd::Float64=0.3, Rt_init_mean::Float64=0.2, | ||
sigma_Rt_sd::Float64=0.2, sigma_Rt_mean::Float64=-3.0, | ||
w_init_sd::Float64=0.1, w_init_mean::Float64=log(0.35), | ||
sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5 | ||
) | ||
|
||
|
||
obstimes = convert(Vector{Float64}, obstimes) | ||
param_change_times = convert(Vector{Float64}, param_change_times) | ||
|
||
|
||
my_model = uciwweihr_model( | ||
data_hosp, | ||
data_wastewater, | ||
obstimes, | ||
param_change_times, | ||
E_init_sd, E_init_mean, | ||
I_init_sd, I_init_mean, | ||
H_init_sd, H_init_mean, | ||
gamma_sd, log_gamma_mean, | ||
nu_sd, log_nu_mean, | ||
epsilon_sd, log_epsilon_mean, | ||
rho_gene_sd, log_rho_gene_mean, | ||
tau_sd, log_tau_mean, | ||
df_shape, df_scale, | ||
sigma_hosp_sd, sigma_hosp_mean, | ||
Rt_init_sd, Rt_init_mean, | ||
sigma_Rt_sd, sigma_Rt_mean, | ||
w_init_sd, w_init_mean, | ||
sigma_w_sd, sigma_w_mean | ||
) | ||
|
||
|
||
# Sample Posterior | ||
if priors_only | ||
Random.seed!(seed) | ||
samples = sample(my_model, Prior(), MCMCThreads(), 400, n_chains) | ||
else | ||
Random.seed!(seed) | ||
samples = sample(my_model, NUTS(), MCMCThreads(), n_samples, n_chains) | ||
end | ||
return(samples) | ||
end |
Oops, something went wrong.