Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2024-09-08 update : adding repeated forecast func. #57

Merged
merged 9 commits into from
Sep 14, 2024
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ makedocs(;
"AGENT-BASED SIMULATION DATA" => "tutorials/agent_based_simulation_data.md",
"UCIWWEIHR FITTING MODEL W/OUT FORECASTING" => "tutorials/uciwwiehr_model_fitting_no_forecast.md",
"UCIWWEIHR FITTING MODEL W/ FORECASTING" => "tutorials/uciwwiehr_model_fitting_forecast.md",
"UCIWWEIHR REPEATED FORECASTING" => "tutorials/uciwweihr_model_repeated_forecasts.md",
]
,
"REFERENCE" => "reference.md",
Expand Down
1 change: 1 addition & 0 deletions docs/src/tutorial_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ Future Description.
- [Generating simulated data with an agent based model.](@ref agent_based_simulation_data)
- [Generating posterior distribution samples with UCIWWEIHR ODE compartmental based model without forecasting.](@ref uciwwiehr_model_fitting_no_forecast)
- [Generating posterior distribution samples with UCIWWEIHR ODE compartmental based model with forecasting.](@ref uciwwiehr_model_fitting_with_forecast)
- [Generating repeated forecasts using the UCIWWEIHR model.](@ref uciwwiehr_model_repeated_forecasts)


133 changes: 133 additions & 0 deletions docs/src/tutorials/uciwweihr_model_repeated_forecasts.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
```@setup tutorial_forecast
using Plots, StatsPlots; gr()
Plots.reset_defaults()

```

# [Generating Repeated Forecasts Using the UCIWWEIHR model.](@id uciwwiehr_model_repeated_forecasts)

Here we show how we can construct repeated forecasts using the UCIWWEIHR model. We start with generating out data using `generate_simulation_data_uciwweihr`'s alternate parameterization where we do prespecify the effective reproduction number and hospitalization probability.



## 1. Data Generation.

Here we simulate a dataset, one with 175 time points.

``` @example tutorial_forecast
using UCIWWEIHR
# Running simulation function with presets
rt_custom = vcat(
range(1, stop=1.8, length=7*4),
fill(1.8, 7*2),
range(1.8, stop=1, length=7*8),
range(0.98, stop=0.8, length=7*2),
range(0.8, stop=1.1, length=7*6),
range(1.1, stop=0.97, length=7*3)
)
w_custom = vcat(
range(0.3, stop=0.38, length=7*5),
fill(0.38, 7*2),
range(0.38, stop=0.25, length=7*8),
range(0.25, stop=0.28, length=7*2),
range(0.28, stop=0.34, length=7*6),
range(0.34, stop=0.28, length=7*2)
)
params = create_uciwweihr_sim_params(
time_points = length(rt_custom),
Rt = rt_custom,
w = w_custom
)
df = generate_simulation_data_uciwweihr(params)
first(df, 5)
```

## 2. Constructing Repeat Forecasts.

We use the `repeated_forecast` function to generate forecasts for a given number of weeks, for a given number of time points. Along with this we need to specify presets. Output of this function is an array with the first index controlling which result we are looking at. The next contains a `uciwweihr_gq_pp` output.

``` @example tutorial_forecast
data_hosp = df.hosp
data_wastewater = df.log_ww_conc
obstimes_hosp = df.obstimes
obstimes_wastewater = df.obstimes
max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater))
param_change_times = 1:7:max_obstime # Change every week
priors_only = false
n_samples = 200
n_forecast_weeks = 2
forecast_points = [
param_change_times[end-5],
param_change_times[end-4],
param_change_times[end-3],
param_change_times[end-2]
]

model_params = create_uciwweihr_model_params()

rep_results = repeated_forecast(
data_hosp,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
n_samples = n_samples,
params = model_params,
n_forecast_weeks = 2,
forecast_points = forecast_points
)

first(rep_results, 2)
```

## 3. Visualizing Results Of Repeated Forecasts.

We can take a look at these forecasts using the `uciwweihr_visualizer` function. We can also add certain parameters to ensure we only see the plots we want.

```@example tutorial_forecast
for res_index in 1:length(forecast_points)
uciwweihr_visualizer(
data_hosp,
data_wastewater,
n_forecast_weeks,
obstimes_hosp,
obstimes_wastewater,
param_change_times,
2024,
true,
model_params;
pp_samples = rep_results[res_index][2][1],
gq_samples = rep_results[res_index][2][2],
obs_data_hosp = data_hosp,
obs_data_wastewater = data_wastewater,
actual_rt_vals = df.rt,
actual_w_t = df.wt,
actual_non_time_varying_vals = params,
bayes_dist_type = "Posterior",
mcmcdaigs = false,
time_varying_plots = false,
non_time_varying_plots = false,
pred_param_plots = true,
save_plots = true,
plot_name_to_save_pred_param = "mcmc_pred_parameter_plots_rep_res"*string(res_index)
)
end
```

### 3.1. Forecast Point 1.

![Plot 1](plots/mcmc_pred_parameter_plots_rep_res1.png)

### 3.2. Forecast Point 2.

![Plot 2](plots/mcmc_pred_parameter_plots_rep_res2.png)

### 3.3. Forecast Point 3.

![Plot 3](plots/mcmc_pred_parameter_plots_rep_res3.png)

### 3.4. Forecast Point 4.

![Plot 4](plots/mcmc_pred_parameter_plots_rep_res4.png)


### [Tutorial Contents](@ref tutorial_home)
2 changes: 1 addition & 1 deletion docs/src/tutorials/uciwwiehr_model_fitting_forecast.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ obstimes_wastewater = df.obstimes
max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater))
param_change_times = 1:7:max_obstime # Change every week
priors_only = false
n_samples = 50
n_samples = 200
forecast = true
forecast_weeks = 4

Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ obstimes_wastewater = df.obstimes
max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater))
param_change_times = 1:7:max_obstime # Change every week
priors_only = false
n_samples = 50
n_samples = 200
forecast = false
forecast_weeks = 0

Expand Down
3 changes: 3 additions & 0 deletions src/UCIWWEIHR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include("time_varying_param_vis.jl")
include("non_time_varying_param_vis.jl")
include("predictive_param_vis.jl")
include("uciwweihr_visualizer.jl")
include("repeated_forecast.jl")

export eihr_ode
export uciwweihr_sim_params
Expand All @@ -63,5 +64,7 @@ export mcmcdiags_vis
export time_varying_param_vis
export non_time_varying_param_vis
export predictive_param_vis
export repeated_forecast
export is_time_varying_above_n

end
32 changes: 31 additions & 1 deletion src/helper_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,34 @@ function repeat_last_n_elements(x::Vector{T}, n::Int, w::Int) where T

return x_new
end
end
end

"""
is_time_varying_above_n(name, n)

Checks if the time varying parameter is above a given time point.
"""
function is_time_varying_above_n(name::Symbol, n::Int)
name_str = string(name)
#println("Checking parameter: ", name_str)

if occursin(r"\[\d+\]", name_str)
#println("Pattern matched")
m = match(r"\d+", name_str)
number = parse(Int, m.match)

if number !== nothing
#println("Extracted time point string: ", number)
return number > n
else
#println("No match found")
end
else
#println("Not a time-varying parameter")
end

return false
end



82 changes: 82 additions & 0 deletions src/repeated_forecast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
repeated_forecast(...)
This is the function to make repreated forecast for a given forecast time span, `n_forecast_weeks`, and for given time points, `forecast_points`.
Plots can be made for these forecasts. The output is an array of `uciwweihr_gq_pp` results for each `forecast_points`.

# Arguments
- `data_hosp`: The hospitalization data.
- `data_wastewater`: The wastewater data.
- `obstimes_hosp`: The time points for the hospitalization data.
- `obstimes_wastewater`: The time points for the wastewater data.
- `n_samples`: The number of samples to draw from the posterior.
- `param_change_times`: The time points where the parameters change.
- `params::uciwweihr_model_params`: The model parameters.
- `n_forecast_weeks`: The number of weeks to forecast.
- `forecast_points`: The time points to forecast, thees points should be present in obstimes_hosp.

# Returns
- An array of `uciwweihr_gq_pp` resuts and timeseries used for building for each `forecast_points`.
"""
function repeated_forecast(
data_hosp,
data_wastewater,
obstimes_hosp,
obstimes_wastewater;
n_samples::Int64,
params::uciwweihr_model_params,
n_forecast_weeks::Int64,
forecast_points::Vector{Int64}
)
results = []
for max_point in forecast_points
index_hosp = findfirst(x -> x == max_point, obstimes_hosp)
if index_hosp === nothing
error("THE FORECAST POINT SHOUDL BE PRESENT IN OBSTIMES_HOSP!!!")
end
index_ww = findfirst(x -> x == max_point, obstimes_wastewater)
if index_ww === nothing
index_ww = findfirst(x -> x < point, obstimes_wastewater)
if index_ww === nothing
error("FINDING THE INDEX FOR WW FORECAST POINT FAILED!!!")
end
end
max_week = Int(ceil(max_point / 7))

temp_data_hosp = data_hosp[1:index_hosp]
temp_data_wastewater = data_wastewater[1:index_ww]
temp_obstimes_hosp = obstimes_hosp[1:index_hosp]
temp_obstimes_wastewater = obstimes_wastewater[1:index_ww]
temp_param_change_times = 1:1:max_week
temp_build_object = [
temp_data_hosp,
temp_data_wastewater,
temp_obstimes_hosp,
temp_obstimes_wastewater,
temp_param_change_times
]

samples = uciwweihr_fit(
temp_data_hosp,
temp_data_wastewater,
temp_obstimes_hosp,
temp_obstimes_wastewater;
param_change_times = temp_param_change_times,
priors_only = false,
n_samples = n_samples,
params = params
)
model_output = uciwweihr_gq_pp(
samples,
temp_data_hosp,
temp_data_wastewater,
temp_obstimes_hosp,
temp_obstimes_wastewater;
param_change_times = temp_param_change_times,
params = params,
forecast = true,
forecast_weeks = n_forecast_weeks
)
push!(results, [temp_build_object, model_output])
end
return(results)
end
1 change: 0 additions & 1 deletion src/uciwweihr_gq_pp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ function uciwweihr_gq_pp(
indices_to_keep = .!isnothing.(generated_quantities(my_model, samples))
samples_randn = ChainsCustomIndex(samples, indices_to_keep)


Random.seed!(seed)
predictive_randn = predict(my_model_forecast_missing, samples_randn)

Expand Down
Loading