Skip to content

Commit

Permalink
Merge pull request #56 from cbernalz/54-allow-for-misaligned-timepoin…
Browse files Browse the repository at this point in the history
…ts-for-ww-and-hosp-data

2024-09-01 update : adding misaligned timepoints for WW and Hosp data.
  • Loading branch information
cbernalz authored Sep 2, 2024
2 parents fa4d2a5 + 806a52c commit 1cacb57
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 102 deletions.
23 changes: 14 additions & 9 deletions docs/src/tutorials/uciwwiehr_model_fitting_forecast.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,26 @@ first(df_ext, 5)

## 2. Sampling from the Posterior Distribution and Posterior Predictive Distribution.

Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we setup some presets, where we need to use `create_uciwweihr_model_params()` to get default parameters for the model. Then we have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model. The difference here is that we set `forecast = true` and `forecast_weeks = 4` to forecast 4 weeks into the future.
Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we setup some presets, where we need to use `create_uciwweihr_model_params()` to get default parameters for the model. Then we have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model. The difference here is that we set `forecast = true` and `forecast_weeks = 4` to forecast 4 weeks into the future. One other thing to note, is that we allow misalignment of hospital and wastewater data's observed times. For this tutorial, we use the same observed points.

``` @example tutorial_forecast
data_hosp = df.hosp
data_wastewater = df.log_ww_conc
obstimes = df.obstimes
param_change_times = 1:7:length(obstimes) # Change every week
obstimes_hosp = df.obstimes
obstimes_wastewater = df.obstimes
max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater))
param_change_times = 1:7:max_obstime # Change every week
priors_only = false
n_samples = 200
n_samples = 50
forecast = true
forecast_weeks = 4
model_params = create_uciwweihr_model_params()
samples = uciwweihr_fit(
data_hosp,
data_wastewater;
obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times,
priors_only,
n_samples,
Expand All @@ -59,8 +62,9 @@ samples = uciwweihr_fit(
model_output = uciwweihr_gq_pp(
samples,
data_hosp,
data_wastewater;
obstimes = obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times = param_change_times,
params = model_params,
forecast = forecast,
Expand All @@ -87,7 +91,8 @@ uciwweihr_visualizer(
data_hosp,
data_wastewater,
forecast_weeks,
obstimes,
obstimes_hosp,
obstimes_wastewater,
param_change_times,
2024,
forecast,
Expand Down
23 changes: 14 additions & 9 deletions docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,26 @@ first(df, 5)

## 2. Sampling from the Posterior Distribution and Posterior Predictive Distribution.

Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we setup some presets, where we need to use `create_uciwweihr_model_params()` to get default parameters for the model. Then we have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model.
Here we sample from the posterior distribution using the `uciwweihr_fit.jl` function. First, we setup some presets, where we need to use `create_uciwweihr_model_params()` to get default parameters for the model. Then we have an array where index 1 contains the posterior/prior predictive samples, index 2 contains the posterior/prior generated quantities samples, and index 3 contains the original sampled parameters for the model. Again, we can allow misalignment of hospital and wastewater data's observed times. For this tutorial, we use the same observed points.

``` @example tutorial
data_hosp = df.hosp
data_wastewater = df.log_ww_conc
obstimes = df.obstimes
param_change_times = 1:7:length(obstimes) # Change every week
obstimes_hosp = df.obstimes
obstimes_wastewater = df.obstimes
max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater))
param_change_times = 1:7:max_obstime # Change every week
priors_only = false
n_samples = 200
n_samples = 50
forecast = false
forecast_weeks = 0
model_params = create_uciwweihr_model_params()
samples = uciwweihr_fit(
data_hosp,
data_wastewater;
obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times,
priors_only,
n_samples,
Expand All @@ -66,8 +69,9 @@ samples = uciwweihr_fit(
model_output = uciwweihr_gq_pp(
samples,
data_hosp,
data_wastewater;
obstimes = obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times = param_change_times,
params = model_params
)
Expand All @@ -92,7 +96,8 @@ uciwweihr_visualizer(
data_hosp,
data_wastewater,
forecast_weeks,
obstimes,
obstimes_hosp,
obstimes_wastewater,
param_change_times,
2024,
forecast,
Expand Down
18 changes: 18 additions & 0 deletions src/helper_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ function generate_names(vn_str::String, val::AbstractArray{<:AbstractArray})
return results
end

function generate_names(vn_str::String, val::Vector{Any})
# Added by Christian Bernal Zelaya.
results = String[]
for idx in 1:length(val)
push!(results, "$(vn_str)[$idx]")
end
return results
end

flatten(val::Real) = [val;]
function flatten(val::AbstractArray{<:Real})
return mapreduce(vcat, CartesianIndices(val)) do i
Expand All @@ -66,6 +75,15 @@ function flatten(val::AbstractArray{<:AbstractArray})
flatten(val[i])
end
end
function flatten(val::Vector{Any})
# Added by Christian Bernal Zelaya.
results = []
for item in val
append!(results, flatten(item))
end
return results
end


function vectup2chainargs(ts::AbstractVector{<:NamedTuple})
ks = keys(first(ts))
Expand Down
36 changes: 24 additions & 12 deletions src/non_time_varying_param_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
Used in the `uciwweihr_visualizer` to create visuals for non-time varying parameters.
# Arguments
- `build_params::uciwweihr_model_params`: A struct of model parameters used to build `gq_samples`, used only if user desired priors next to posteriors.
- `data_hosp`: Hospitalization data, used only if user desired priors next to posteriors.
- `data_wastewater`: Wastewater data, if model does not use this do not specify this, if user desires priors next to plot (do not specify if you do not want prior plots).
- `obstimes_hosp`: An array of time points for hospital data, used only if user desired priors next to posteriors.
- `obstimes_wastewater`: An array of time points for wastewater data, used only if user desired priors next to posteriors.
- `param_change_times`: An array of time points where the parameters change, used only if user desired priors next to posteriors.
- `seed`: An integer to set the seed for reproducibility, used only if user desired priors next to posteriors.
- `forecast`: A boolean to indicate if user wants to forecast, used only if user desired priors next to posteriors.
- `forecast_weeks`: An integer to indicate the number of weeks to forecast, used only if user desired priors next to posteriors.
- `gq_samples`: Generated quantities samples from the posterior/prior distribution, index 2 in uciwweihr_gq_pp output.
- `desired_params`: A list of lists of parameters to visualize. Each list will be visualized in a separate plot. Default is any parameter not in this list : ["alpha_t", "w_t", "rt_vals", "log_genes_mean", "H"]
- `bayes_dist_type`: A string to indicate if user is using Posterior or Prior distribution. Default is "Posterior".
Expand All @@ -14,7 +23,7 @@ Used in the `uciwweihr_visualizer` to create visuals for non-time varying parame
function non_time_varying_param_vis(
build_params::uciwweihr_model_params,
data_hosp,
obstimes,
obstimes_hosp,
param_change_times,
seed,
forecast,
Expand All @@ -28,17 +37,17 @@ function non_time_varying_param_vis(
)
println("Generating non-time varying parameter plots (with priors and w/out wastewater)...")
samples = uciwweihr_fit(
data_hosp;
obstimes = obstimes,
data_hosp,
obstimes_hosp;
param_change_times = param_change_times,
priors_only = true,
seed = seed,
params = build_params
)
prior_model_output = uciwweihr_gq_pp(
samples,
data_hosp;
obstimes,
data_hosp,
obstimes_hosp;
param_change_times,
seed = seed,
params = build_params,
Expand Down Expand Up @@ -96,7 +105,7 @@ function non_time_varying_param_vis(
size = (1500, 1500))
display(final_plot)
if save_plots
savefig(final_plot, plot_name_to_save)
save_plots_to_docs(final_plot, plot_name_to_save)
end
else
println("NO NON-TIME VARYING PARAMETER PLOTS TO DISPLAY!!!")
Expand All @@ -111,7 +120,8 @@ function non_time_varying_param_vis(
build_params::uciwweihr_model_params,
data_hosp,
data_wastewater,
obstimes,
obstimes_hosp,
obstimes_wastewater,
param_change_times,
seed,
forecast,
Expand All @@ -126,8 +136,9 @@ function non_time_varying_param_vis(
println("Generating non-time varying parameter plots (with priors and with wastewater)...")
samples = uciwweihr_fit(
data_hosp,
data_wastewater;
obstimes = obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times = param_change_times,
priors_only = true,
seed = seed,
Expand All @@ -136,8 +147,9 @@ function non_time_varying_param_vis(
prior_model_output = uciwweihr_gq_pp(
samples,
data_hosp,
data_wastewater;
obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times,
seed = seed,
params = build_params,
Expand Down Expand Up @@ -195,7 +207,7 @@ function non_time_varying_param_vis(
size = (1500, 1500))
display(final_plot)
if save_plots
savefig(final_plot, plot_name_to_save)
save_plots_to_docs(final_plot, plot_name_to_save)
end
else
println("NO NON-TIME VARYING PARAMETER PLOTS TO DISPLAY!!!")
Expand Down
26 changes: 15 additions & 11 deletions src/time_varying_param_vis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Used in the `uciwweihr_visualizer` to create visuals for time varying parameters
- `build_params::uciwweihr_model_params`: A struct of model parameters used to build `gq_samples`, used only if user desired priors next to posteriors.
- `data_hosp`: Hospitalization data, used only if user desired priors next to posteriors.
- `data_wastewater`: Wastewater data, if model does not use this do not specify this, if user desires priors next to plot (do not specify if you do not want prior plots).
- `obstimes`: An array of time points for the data, used only if user desired priors next to posteriors.
- `obstimes_hosp`: An array of time points for hospital data, used only if user desired priors next to posteriors.
- `obstimes_wastewater`: An array of time points for wastewater data, used only if user desired priors next to posteriors.
- `param_change_times`: An array of time points where the parameters change, used only if user desired priors next to posteriors.
- `seed`: An integer to set the seed for reproducibility, used only if user desired priors next to posteriors.
- `forecast`: A boolean to indicate if user wants to forecast, used only if user desired priors next to posteriors.
Expand All @@ -23,7 +24,7 @@ Used in the `uciwweihr_visualizer` to create visuals for time varying parameters
function time_varying_param_vis(
build_params::uciwweihr_model_params,
data_hosp,
obstimes,
obstimes_hosp,
param_change_times,
seed,
forecast,
Expand All @@ -38,17 +39,17 @@ function time_varying_param_vis(
)
println("Generating time varying parameter plots (with priors and w/out wastewater)...")
samples = uciwweihr_fit(
data_hosp;
obstimes = obstimes,
data_hosp,
obstimes_hosp;
param_change_times = param_change_times,
priors_only = true,
seed = seed,
params = build_params
)
prior_model_output = uciwweihr_gq_pp(
samples,
data_hosp;
obstimes,
data_hosp,
obstimes_hosp;
param_change_times,
seed = seed,
params = build_params,
Expand Down Expand Up @@ -157,7 +158,8 @@ function time_varying_param_vis(
build_params::uciwweihr_model_params,
data_hosp,
data_wastewater,
obstimes,
obstimes_hosp,
obstimes_wastewater,
param_change_times,
seed,
forecast,
Expand All @@ -173,8 +175,9 @@ function time_varying_param_vis(
println("Generating time varying parameter plots (with priors and with wastewater)...")
samples = uciwweihr_fit(
data_hosp,
data_wastewater;
obstimes = obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times = param_change_times,
priors_only = true,
seed = seed,
Expand All @@ -183,8 +186,9 @@ function time_varying_param_vis(
prior_model_output = uciwweihr_gq_pp(
samples,
data_hosp,
data_wastewater;
obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times,
seed = seed,
params = build_params,
Expand Down
28 changes: 15 additions & 13 deletions src/uciwweihr_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,29 @@ The defaults for this fuction will follow those of the default simulation in gen
"""
function uciwweihr_fit(
data_hosp,
data_wastewater;
obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times,
priors_only::Bool=false,
n_samples::Int64=500, n_chains::Int64=1, seed::Int64=2024,
params::uciwweihr_model_params
)
println("Using uciwweihr_model with wastewater!!!")
obstimes = convert(Vector{Float64}, obstimes)
param_change_times = convert(Vector{Float64}, param_change_times)
obstimes_hosp = convert(Vector{Int64}, obstimes_hosp)
obstimes_wastewater = convert(Vector{Int64}, obstimes_wastewater)
param_change_times = convert(Vector{Int64}, param_change_times)


my_model = uciwweihr_model(
data_hosp,
data_wastewater;
obstimes,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
param_change_times,
params
)


# Sample Posterior
if priors_only
Random.seed!(seed)
Expand All @@ -53,21 +55,21 @@ function uciwweihr_fit(
end

function uciwweihr_fit(
data_hosp;
obstimes,
data_hosp,
obstimes_hosp;
param_change_times,
priors_only::Bool=false,
n_samples::Int64=500, n_chains::Int64=1, seed::Int64=2024,
params::uciwweihr_model_params
)
println("Using uciwweihr_model without wastewater!!!")
obstimes = convert(Vector{Float64}, obstimes)
param_change_times = convert(Vector{Float64}, param_change_times)
obstimes_hosp = convert(Vector{Int64}, obstimes_hosp)
param_change_times = convert(Vector{Int64}, param_change_times)


my_model = uciwweihr_model(
data_hosp;
obstimes,
data_hosp,
obstimes_hosp;
param_change_times,
params
)
Expand Down
Loading

0 comments on commit 1cacb57

Please sign in to comment.