From 01c1e95bf880356ebdbb3ad3ccf4a1512645cee3 Mon Sep 17 00:00:00 2001 From: cbernalz Date: Sun, 8 Sep 2024 18:55:21 -0700 Subject: [PATCH] 2024-09-08 update : adding repeated forecast func. --- docs/make.jl | 1 + docs/src/tutorial_index.md | 1 + .../uciwweihr_model_repeated_forecasts.md | 130 ++++++++++++++++++ .../uciwwiehr_model_fitting_forecast.md | 2 +- .../uciwwiehr_model_fitting_no_forecast.md | 2 +- src/UCIWWEIHR.jl | 3 + src/helper_functions.jl | 32 ++++- src/repeated_forecast.jl | 84 +++++++++++ src/uciwweihr_gq_pp.jl | 1 - 9 files changed, 252 insertions(+), 4 deletions(-) create mode 100644 docs/src/tutorials/uciwweihr_model_repeated_forecasts.md create mode 100644 src/repeated_forecast.jl diff --git a/docs/make.jl b/docs/make.jl index d1a59b5..d61dafb 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,6 +24,7 @@ makedocs(; "AGENT-BASED SIMULATION DATA" => "tutorials/agent_based_simulation_data.md", "UCIWWEIHR FITTING MODEL W/OUT FORECASTING" => "tutorials/uciwwiehr_model_fitting_no_forecast.md", "UCIWWEIHR FITTING MODEL W/ FORECASTING" => "tutorials/uciwwiehr_model_fitting_forecast.md", + "UCIWWEIHR REPEATED FORECASTING" => "tutorials/uciwwiehr_model_repeated_forecasting.md", ] , "REFERENCE" => "reference.md", diff --git a/docs/src/tutorial_index.md b/docs/src/tutorial_index.md index 355a04a..53d681f 100644 --- a/docs/src/tutorial_index.md +++ b/docs/src/tutorial_index.md @@ -9,5 +9,6 @@ Future Description. - [Generating simulated data with an agent based model.](@ref agent_based_simulation_data) - [Generating posterior distribution samples with UCIWWEIHR ODE compartmental based model without forecasting.](@ref uciwwiehr_model_fitting_no_forecast) - [Generating posterior distribution samples with UCIWWEIHR ODE compartmental based model with forecasting.](@ref uciwwiehr_model_fitting_with_forecast) +- [Generating repeated forecasts using the UCIWWEIHR model.](@ref uciwweihr_model_repeated_forecasts) diff --git a/docs/src/tutorials/uciwweihr_model_repeated_forecasts.md b/docs/src/tutorials/uciwweihr_model_repeated_forecasts.md new file mode 100644 index 0000000..2ba94eb --- /dev/null +++ b/docs/src/tutorials/uciwweihr_model_repeated_forecasts.md @@ -0,0 +1,130 @@ +```@setup tutorial_forecast +using Plots, StatsPlots; gr() +Plots.reset_defaults() + +``` + +# [Generating Repeated Forecasts Using the UCIWWEIHR model.](@id uciwwiehr_model_repeated_forecasts) + +Here we show how we can construct repeated forecasts using the UCIWWEIHR model. We start with generating out data using `generate_simulation_data_uciwweihr`'s alternate parameterization where we do prespecify the effective reproduction number and hospitalization probability. + + + +## 1. Data Generation. + +Here we simulate a dataset, one with 175 time points. + +``` @example tutorial_forecast +# Running simulation function with presets +rt_custom = vcat( + range(1, stop=1.8, length=7*4), + fill(1.8, 7*2), + range(1.8, stop=1, length=7*8), + range(0.98, stop=0.8, length=7*2), + range(0.8, stop=1.1, length=7*6), + range(1.1, stop=0.97, length=7*3) +) +w_custom = vcat( + range(0.3, stop=0.38, length=7*5), + fill(0.38, 7*2), + range(0.38, stop=0.25, length=7*8), + range(0.25, stop=0.28, length=7*2), + range(0.28, stop=0.34, length=7*6), + range(0.34, stop=0.28, length=7*2) +) +params = create_uciwweihr_sim_params( + time_points = length(rt_custom), + Rt = rt_custom, + w = w_custom +) +df = generate_simulation_data_uciwweihr(params) +``` + +## 2. Constructing Repeat Forecasts. + +We use the `repeated_forecast` function to generate forecasts for a given number of weeks, for a given number of time points. Along with this we need to specify presets. Output of this function is an array with the first index controlling which result we are looking at. The next contains a `uciwweihr_gq_pp` output. + +``` @example tutorial_forecast +data_hosp = df.hosp +data_wastewater = df.log_ww_conc +obstimes_hosp = df.obstimes +obstimes_wastewater = df.obstimes +max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater)) +param_change_times = 1:7:max_obstime # Change every week +priors_only = false +n_samples = 200 +n_forecast_weeks = 2 +forecast_points = [ + param_change_times[end-5], + param_change_times[end-4], + param_change_times[end-3], + param_change_times[end-2] +] + +model_params = create_uciwweihr_model_params() + +rep_results = repeated_forecast( + samples, + data_hosp, + data_wastewater, + obstimes_hosp, + obstimes_wastewater; + n_samples = n_samples, + params = model_params, + n_forecast_weeks = 2, + forecast_points = forecast_points +) +``` + +## 3. Visualizing Results Of Repeated Forecasts. + +We can take a look at these forecasts using the `uciwweihr_visualizer` function. We can also add certain parameters to ensure we only see the plots we want. + +```@example tutorial_forecast +for res_index in 1:length(forecast_points) + uciwweihr_visualizer( + data_hosp, + data_wastewater, + n_forecast_weeks, + obstimes_hosp, + obstimes_wastewater, + param_change_times, + 2024, + true, + model_params; + pp_samples = rep_results[res_index][2][1], + gq_samples = rep_results[res_index][2][2], + obs_data_hosp = data_hosp, + obs_data_wastewater = data_wastewater, + actual_rt_vals = df.rt, + actual_w_t = df.wt, + actual_non_time_varying_vals = params, + bayes_dist_type = "Posterior", + mcmcdaigs = false, + time_varying_plots = false, + non_time_varying_plots = false, + pred_param_plots = true, + save_plots = true, + plot_name_to_save_pred_param = "mcmc_pred_parameter_plots_rep_res"*string(res_index)*".png" + ) +end +``` + +### 3.1. Forecast Point 1. + +![Plot 1](plots/mcmc_pred_parameter_plots_rep_res1.png) + +### 3.2. Forecast Point 2. + +![Plot 2](plots/mcmc_pred_parameter_plots_rep_res2.png) + +### 3.3. Forecast Point 3. + +![Plot 3](plots/mcmc_pred_parameter_plots_rep_res3.png) + +### 3.4. Forecast Point 4. + +![Plot 4](plots/mcmc_pred_parameter_plots_rep_res4.png) + + +### [Tutorial Contents](@ref tutorial_home) \ No newline at end of file diff --git a/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md b/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md index 70bcabf..ca65812 100644 --- a/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md +++ b/docs/src/tutorials/uciwwiehr_model_fitting_forecast.md @@ -44,7 +44,7 @@ obstimes_wastewater = df.obstimes max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater)) param_change_times = 1:7:max_obstime # Change every week priors_only = false -n_samples = 50 +n_samples = 200 forecast = true forecast_weeks = 4 diff --git a/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md b/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md index ee56bd2..98d50e8 100644 --- a/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md +++ b/docs/src/tutorials/uciwwiehr_model_fitting_no_forecast.md @@ -51,7 +51,7 @@ obstimes_wastewater = df.obstimes max_obstime = max(length(obstimes_hosp), length(obstimes_wastewater)) param_change_times = 1:7:max_obstime # Change every week priors_only = false -n_samples = 50 +n_samples = 200 forecast = false forecast_weeks = 0 diff --git a/src/UCIWWEIHR.jl b/src/UCIWWEIHR.jl index 8368691..d17d5c7 100644 --- a/src/UCIWWEIHR.jl +++ b/src/UCIWWEIHR.jl @@ -37,6 +37,7 @@ include("time_varying_param_vis.jl") include("non_time_varying_param_vis.jl") include("predictive_param_vis.jl") include("uciwweihr_visualizer.jl") +include("repeated_forecast.jl") export eihr_ode export uciwweihr_sim_params @@ -63,5 +64,7 @@ export mcmcdiags_vis export time_varying_param_vis export non_time_varying_param_vis export predictive_param_vis +export repeated_forecast +export is_time_varying_above_n end \ No newline at end of file diff --git a/src/helper_functions.jl b/src/helper_functions.jl index 1332265..77a16be 100644 --- a/src/helper_functions.jl +++ b/src/helper_functions.jl @@ -206,4 +206,34 @@ function repeat_last_n_elements(x::Vector{T}, n::Int, w::Int) where T return x_new end -end \ No newline at end of file +end + +""" + is_time_varying_above_n(name, n) + +Checks if the time varying parameter is above a given time point. +""" +function is_time_varying_above_n(name::Symbol, n::Int) + name_str = string(name) + #println("Checking parameter: ", name_str) + + if occursin(r"\[\d+\]", name_str) + #println("Pattern matched") + m = match(r"\d+", name_str) + number = parse(Int, m.match) + + if number !== nothing + #println("Extracted time point string: ", number) + return number > n + else + #println("No match found") + end + else + #println("Not a time-varying parameter") + end + + return false +end + + + diff --git a/src/repeated_forecast.jl b/src/repeated_forecast.jl new file mode 100644 index 0000000..6c4115f --- /dev/null +++ b/src/repeated_forecast.jl @@ -0,0 +1,84 @@ +""" + repeated_forecast(...) +This is the function to make repreated forecast for a given forecast time span, `n_forecast_weeks`, and for given time points, `forecast_points`. +Plots can be made for these forecasts. The output is an array of `uciwweihr_gq_pp` results for each `forecast_points`. + +# Arguments +- `samples`: The MCMC samples from the model fit. +- `data_hosp`: The hospitalization data. +- `data_wastewater`: The wastewater data. +- `obstimes_hosp`: The time points for the hospitalization data. +- `obstimes_wastewater`: The time points for the wastewater data. +- `n_samples`: The number of samples to draw from the posterior. +- `param_change_times`: The time points where the parameters change. +- `params::uciwweihr_model_params`: The model parameters. +- `n_forecast_weeks`: The number of weeks to forecast. +- `forecast_points`: The time points to forecast, thees points should be present in obstimes_hosp. + +# Returns +- An array of `uciwweihr_gq_pp` resuts and timeseries used for building for each `forecast_points`. +""" +function repeated_forecast( + samples, + data_hosp, + data_wastewater, + obstimes_hosp, + obstimes_wastewater; + n_samples::Int64, + params::uciwweihr_model_params, + n_forecast_weeks::Int64, + forecast_points::Vector{Int64} +) + results = [] + for max_point in forecast_points + index_hosp = findfirst(x -> x == max_point, obstimes_hosp) + if index_hosp === nothing + error("THE FORECAST POINT SHOUDL BE PRESENT IN OBSTIMES_HOSP!!!") + end + index_ww = findfirst(x -> x == max_point, obstimes_wastewater) + if index_ww === nothing + index_ww = findfirst(x -> x < point, obstimes_wastewater) + if index_ww === nothing + error("FINDING THE INDEX FOR WW FORECAST POINT FAILED!!!") + end + end + max_week = Int(ceil(max_point / 7)) + + temp_data_hosp = data_hosp[1:index_hosp] + temp_data_wastewater = data_wastewater[1:index_ww] + temp_obstimes_hosp = obstimes_hosp[1:index_hosp] + temp_obstimes_wastewater = obstimes_wastewater[1:index_ww] + temp_param_change_times = 1:1:max_week + temp_build_object = [ + temp_data_hosp, + temp_data_wastewater, + temp_obstimes_hosp, + temp_obstimes_wastewater, + temp_param_change_times + ] + + samples = uciwweihr_fit( + temp_data_hosp, + temp_data_wastewater, + temp_obstimes_hosp, + temp_obstimes_wastewater; + param_change_times = temp_param_change_times, + priors_only = false, + n_samples = n_samples, + params = params + ) + model_output = uciwweihr_gq_pp( + samples, + temp_data_hosp, + temp_data_wastewater, + temp_obstimes_hosp, + temp_obstimes_wastewater; + param_change_times = temp_param_change_times, + params = params, + forecast = true, + forecast_weeks = n_forecast_weeks + ) + push!(results, [temp_build_object, model_output]) + end + return(results) +end diff --git a/src/uciwweihr_gq_pp.jl b/src/uciwweihr_gq_pp.jl index 0c6dc0b..612cbce 100644 --- a/src/uciwweihr_gq_pp.jl +++ b/src/uciwweihr_gq_pp.jl @@ -85,7 +85,6 @@ function uciwweihr_gq_pp( indices_to_keep = .!isnothing.(generated_quantities(my_model, samples)) samples_randn = ChainsCustomIndex(samples, indices_to_keep) - Random.seed!(seed) predictive_randn = predict(my_model_forecast_missing, samples_randn)