Skip to content

Commit

Permalink
"2024-07-29 update : adding visualizer along with model fitting and g…
Browse files Browse the repository at this point in the history
…q_pp functions."
  • Loading branch information
cbernalz committed Jul 30, 2024
1 parent 7dfed2e commit c97ec57
Show file tree
Hide file tree
Showing 13 changed files with 618 additions and 59 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/Manifest.toml
/docs/Manifest.toml
/docs/build/
src/testfile.jl
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
UCIWWEIHR = "176850df-a79a-485c-b76d-f2c16e00fafb"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ makedocs(;
"GETTING STARTED" => "tutorials/getting_started.md",
"UCIWWEIHR SIMULATION DATA" => "tutorials/uciwweihr_simulation_data.md",
"AGENT-BASED SIMULATION DATA" => "tutorials/agent_based_simulation_data.md",
"UCIWWEIHR FITTING MODEL" => "tutorials/uciwweihr_model_fitting.md",
]
,
"NEWS" => "news.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/agent_based_simulation_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ first(df, 5)

## 2. Visualizing SEIHR compartments.

We can also use the [TidierPlots](https://tidierorg.github.io/TidierPlots.jl/stable/) package to visualize the data generated.
We can also use the [Plots](https://docs.juliaplots.org/stable/) package to visualize the data generated.

```@example tutorial
plot(df.Time, df.S, label = "Suseptible",
Expand Down
70 changes: 70 additions & 0 deletions docs/src/tutorials/uciwweihr_model_fitting.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
```@setup tutorial
using Plots, StatsPlots; gr()
Plots.reset_defaults()
```

# [Generating Posterior Distribution Samples with UCIWWEIHR ODE compartmental based model.](@id uciwwiehr_model_fitting)

This package has a way to sample from a posterior or prior that is defined in the future paper using the `uciwweihr_fit.jl` and `uciwweihr_model.jl`. We can then generate desired quantities and forecast for a given time period with the posterior predictive distribution, using `uciwweihr_gq_pp.jl`. We first generate data using the `generate_simulation_data_uciwweihr` function which is a non-mispecified version of the model.


## 1. Data Generation.

``` @example tutorial
using UCIWWEIHR
# Running simulation function with defaults
df = generate_simulation_data_uciwweihr()
first(df, 5)
```

## 2. Sampling from the Posterior Distribution and Posterior Predictive Distribution.

Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we setup some presets, then have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model.

``` @example tutorial
data_hosp = df.hosp
data_wastewater = df.log_ww_conc
obstimes = df.obstimes
param_change_times = 1:7:length(obstimes) # Change every week
priors_only = false
n_samples = 50
samples = uciwweihr_fit(
data_hosp,
data_wastewater,
obstimes,
param_change_times,
priors_only,
n_samples
)
model_output = uciwweihr_gq_pp(
samples,
data_hosp,
data_wastewater,
obstimes,
param_change_times
)
first(model_output[1][:,1:5], 5)
```

``` @example tutorial
first(model_output[2][:,1:5], 5)
```

``` @example tutorial
first(model_output[3][:,1:5], 5)
```

## 3. MCMC Diagnostic Plots/Results Along with Posterior Predictive Distribution.

We also provide a very basic way to visualize some MCMC diagnostics along with effective sample sizes of desired generated quantities(does not include functionality for time-varying quantities). Along with this, we can also visualize the posterior predictive distribution with actual observed values, which can be used to examine forecasts generated by the model.

```@example tutorial
uciwweihr_visualizer(gq_samples = model_output[2], save_plots = true)
```
![Plot 1](plots/mcmc_diagnosis_plots.png)


### [Tutorial Contents](@ref tutorial_home)
2 changes: 1 addition & 1 deletion docs/src/tutorials/uciwweihr_simulation_data.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ first(df, 5)

## 2. Visualizing UCIWWEIHR model results.

Here we can make simple plots to visualize the data generated using the [TidierPlots](https://tidierorg.github.io/TidierPlots.jl/stable/) package.
Here we can make simple plots to visualize the data generated using the [Plots](https://docs.juliaplots.org/stable/) package.

### 2.1. Concentration of pathogen genome in wastewater(WW).
```@example tutorial
Expand Down
13 changes: 11 additions & 2 deletions src/UCIWWEIHR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,25 @@ using Plots

include("generate_simulation_data_uciwweihr.jl")
include("generate_simulation_data_agent.jl")
include("bayes_eihr_model.jl")
include("eihr_ode.jl")
include("negativebinomial2.jl")
include("generalizedtdist.jl")
include("uciwweihr_model.jl")
include("uciwweihr_fit.jl")
include("uciwweihr_gq_pp.jl")
include("uciwweihr_visualizer.jl")
include("helper_functions.jl")

export eihr_ode
export generate_simulation_data_uciwweihr
export generate_simulation_data_agent
export bayes_eihr_model
export NegativeBinomial2
export GeneralizedTDist
export uciwweihr_model
export uciwweihr_fit
export uciwweihr_gq_pp
export uciwweihr_visualizer
export ChainsCustomIndexs
export save_plots_to_docs

end
125 changes: 125 additions & 0 deletions src/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
power(a,b)
Raise `a` to the `b` power
"""
function power(a,b)
a^b
end

"""
ChainsCustomIndex(c::Chains, indices_to_keep::BitMatrix)
Reduce Chains object to only wanted indices.
Function created by Damon Bayer.
"""
function ChainsCustomIndex(c::Chains, indices_to_keep::BitMatrix)
min_length = minimum(mapslices(sum, indices_to_keep, dims = 1))
v = c.value
new_v = copy(v.data)
new_v_filtered = cat([new_v[indices_to_keep[:, i], :, i][1:min_length, :] for i in 1:size(v, 3)]..., dims = 3)
aa = AxisArray(new_v_filtered; iter = v.axes[1].val[1:min_length], var = v.axes[2].val, chain = v.axes[3].val)

Chains(aa, c.logevidence, c.name_map, c.info)
end

# Series of functions for creating correctly scaled parameter draws.
# code snippet shared by @torfjelde
# https://gist.github.com/torfjelde/37be5a672d29e473983b8e82b45c2e41
generate_names(val) = generate_names("", val)
generate_names(vn_str::String, val::Real) = [vn_str;]
function generate_names(vn_str::String, val::NamedTuple)
return map(keys(val)) do k
generate_names("$(vn_str)$(k)", val[k])
end
end
function generate_names(vn_str::String, val::AbstractArray{<:Real})
results = String[]
for idx in CartesianIndices(val)
s = join(idx.I, ",")
push!(results, "$vn_str[$s]")
end
return results
end

function generate_names(vn_str::String, val::AbstractArray{<:AbstractArray})
results = String[]
for idx in CartesianIndices(val)
s1 = join(idx.I, ",")
inner_results = map(f("", val[idx])) do s2
"$vn_str[$s1]$s2"
end
append!(results, inner_results)
end
return results
end

flatten(val::Real) = [val;]
function flatten(val::AbstractArray{<:Real})
return mapreduce(vcat, CartesianIndices(val)) do i
val[i]
end
end
function flatten(val::AbstractArray{<:AbstractArray})
return mapreduce(vcat, CartesianIndices(val)) do i
flatten(val[i])
end
end

function vectup2chainargs(ts::AbstractVector{<:NamedTuple})
ks = keys(first(ts))
vns = mapreduce(vcat, ks) do k
generate_names(string(k), first(ts)[k])
end
vals = map(eachindex(ts)) do i
mapreduce(vcat, ks) do k
flatten(ts[i][k])
end
end
arr_tmp = reduce(hcat, vals)'
arr = reshape(arr_tmp, (size(arr_tmp)..., 1)) # treat as 1 chain
return Array(arr), vns
end

function vectup2chainargs(ts::AbstractMatrix{<:NamedTuple})
num_samples, num_chains = size(ts)
res = map(1:num_chains) do chain_idx
vectup2chainargs(ts[:, chain_idx])
end

vals = getindex.(res, 1)
vns = getindex.(res, 2)

# Verify that the variable names are indeed the same
vns_union = reduce(union, vns)
@assert all(isempty.(setdiff.(vns, Ref(vns_union)))) "variable names differ between chains"

arr = cat(vals...; dims = 3)

return arr, first(vns)
end

function MCMCChains.Chains(ts::AbstractArray{<:NamedTuple})
return MCMCChains.Chains(vectup2chainargs(ts)...)
end


"""
save_plots_to_docs(plot, filename; format = "png")
Saves plots to docs/plots directory.
Function created by Christian Bernal Zelaya.
"""
function save_plots_to_docs(plot, filename; format = "png")
doc_loc = "plots"
if !isdir(doc_loc)
mkdir(doc_loc)
end

file_target_path = joinpath(doc_loc, "$filename.$format")
savefig(plot, file_target_path)
println("Plot saved to $file_target_path")

end
108 changes: 108 additions & 0 deletions src/uciwweihr_fit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Fitting UCIWWEIHR model
# -------------------------------------------------
"""
uciwweihr_fit(...)
This is the sampler for the bayesian semi-parametric model for the wastewater EIHR compartmental model.
The defaults for this fuction will follow those of the default simulation in generate_simulation_data_ww_eihr.jl function.
# Arguments
- `data_hosp`: An array of hospital data.
- `data_wastewater`: An array of pathogen genome concentration in localized wastewater data.
- `obstimes`: An array of timepoints for observed hosp/wastewater.
- `priors_only::Bool=false`: A boolean to indicate if only priors are to be sampled.
- `n_samples::Int64=500`: Number of samples to be drawn.
- `n_chains::Int64=1`: Number of chains to be run.
- `seed::Int64=2024`: Seed for the random number generator.
- `E_init_sd::Float64=50.0`: Standard deviation for the initial number of exposed individuals.
- `E_init_mean::Int64=200`: Mean for the initial number of exposed individuals.
- `I_init_sd::Float64=20.0`: Standard deviation for the initial number of infected individuals.
- `I_init_mean::Int64=100`: Mean for the initial number of infected individuals.
- `H_init_sd::Float64=5.0`: Standard deviation for the initial number of hospitalized individuals.
- `H_init_mean::Int64=20`: Mean for the initial number of hospitalized individuals.
- `gamma_sd::Float64=0.02`: Standard deviation for the rate of incubation.
- `log_gamma_mean::Float64=log(1/4)`: Mean for the rate of incubation on log scale.
- `nu_sd::Float64=0.02`: Standard deviation for the rate of leaving the infected compartment.
- `log_nu_mean::Float64=log(1/7)`: Mean for the rate of leaving the infected compartment on the log scale.
- `epsilon_sd::Float64=0.02`: Standard deviation for the rate of hospitalization recovery.
- `log_epsilon_mean::Float64=log(1/5)`: Mean for the rate of hospitalization recovery on the log scale.
- `rho_gene_sd::Float64=0.02`: Standard deviation for the rho prior.
- `log_rho_gene_mean::Float64=log(0.011)`: Mean for the row prior on log scale.
- `tau_sd::Float64=0.02`: Standard deviation for the scale/variation of the log scale data.
- `log_tau_mean::Float64=log(0.1)`: Mean for the scale/variation of the log scale data on log scale itself.
- `df_shape::Float64=2.0`: Shape parameter for the gamma distribution.
- `df_scale::Float64=10.0`: Scale parameter for the gamma distribution.
- `sigma_hosp_sd::Float64=50.0`: Standard deviation for the negative binomial distribution for hospital data.
- `sigma_hosp_mean::Float64=500.0`: Mean for the negative binomial distribution for hospital data.
- `Rt_init_sd::Float64=0.3`: Standard deviation for the initial value of the time-varying reproduction number.
- `Rt_init_mean::Float64=0.2`: Mean for the initial value of the time-varying reproduction number.
- `sigma_Rt_sd::Float64=0.2`: Standard deviation for normal prior of log time-varying reproduction number standard deviation.
- `sigma_Rt_mean::Float64=-3.0`: Mean for normal prior of log time-varying reproduction number standard deviation.
- `w_init_sd::Float64=0.1`: Standard deviation for the initial value of the time-varying hospitalization rate.
- `w_init_mean::Float64=log(0.35)`: Mean for the initial value of the time-varying hospitalization rate.
- `sigma_w_sd::Float64=0.2`: Standard deviation for normal prior of log time-varying hospitalization rate standard deviation.
- `sigma_w_mean::Float64=-3.5`: Mean for normal prior of time-varying hospitalization rate standard deviation.
- `param_change_times::Array{Float64}`: An array of timepoints where the parameters change.
# Returns
- Samples from the posterior or prior distribution.
"""
function uciwweihr_fit(
data_hosp,
data_wastewater,
obstimes,
param_change_times,
priors_only::Bool=false,
n_samples::Int64=500, n_chains::Int64=1, seed::Int64=2024,
E_init_sd::Float64=50.0, E_init_mean::Int64=200,
I_init_sd::Float64=20.0, I_init_mean::Int64=100,
H_init_sd::Float64=5.0, H_init_mean::Int64=20,
gamma_sd::Float64=0.02, log_gamma_mean::Float64=log(1/4),
nu_sd::Float64=0.02, log_nu_mean::Float64=log(1/7),
epsilon_sd::Float64=0.02, log_epsilon_mean::Float64=log(1/5),
rho_gene_sd::Float64=0.02, log_rho_gene_mean::Float64=log(0.011),
tau_sd::Float64=0.02, log_tau_mean::Float64=log(0.1),
df_shape::Float64=2.0, df_scale::Float64=10.0,
sigma_hosp_sd::Float64=50.0, sigma_hosp_mean::Float64=500.0,
Rt_init_sd::Float64=0.3, Rt_init_mean::Float64=0.2,
sigma_Rt_sd::Float64=0.2, sigma_Rt_mean::Float64=-3.0,
w_init_sd::Float64=0.1, w_init_mean::Float64=log(0.35),
sigma_w_sd::Float64=0.2, sigma_w_mean::Float64=-3.5
)


obstimes = convert(Vector{Float64}, obstimes)
param_change_times = convert(Vector{Float64}, param_change_times)


my_model = uciwweihr_model(
data_hosp,
data_wastewater,
obstimes,
param_change_times,
E_init_sd, E_init_mean,
I_init_sd, I_init_mean,
H_init_sd, H_init_mean,
gamma_sd, log_gamma_mean,
nu_sd, log_nu_mean,
epsilon_sd, log_epsilon_mean,
rho_gene_sd, log_rho_gene_mean,
tau_sd, log_tau_mean,
df_shape, df_scale,
sigma_hosp_sd, sigma_hosp_mean,
Rt_init_sd, Rt_init_mean,
sigma_Rt_sd, sigma_Rt_mean,
w_init_sd, w_init_mean,
sigma_w_sd, sigma_w_mean
)


# Sample Posterior
if priors_only
Random.seed!(seed)
samples = sample(my_model, Prior(), MCMCThreads(), 400, n_chains)
else
Random.seed!(seed)
samples = sample(my_model, NUTS(), MCMCThreads(), n_samples, n_chains)
end
return(samples)
end
Loading

0 comments on commit c97ec57

Please sign in to comment.