From d7f3b6915d8c67f0d8316164f2d5456fd2251db2 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 27 Jun 2024 12:53:03 -0400 Subject: [PATCH] add forecasting to pipeline solve and save (#308) * include a default lookahead * Include a forecast step in the `infer` function * extend tests --- .../src/constructors/make_default_params.jl | 4 +- .../constructors/make_inference_configs.jl | 4 +- pipeline/src/forecast/generate_forecasts.jl | 40 +++++++++++++------ pipeline/src/infer/InferenceConfig.jl | 26 +++++++----- .../test/constructors/test_constructors.jl | 3 +- .../test/end-to-end/test_full_inference.jl | 5 ++- pipeline/test/infer/test_InferenceConfig.jl | 4 +- .../test/pipeline/test_pipelinefunctions.jl | 4 +- 8 files changed, 61 insertions(+), 29 deletions(-) diff --git a/pipeline/src/constructors/make_default_params.jl b/pipeline/src/constructors/make_default_params.jl index 4b85035a3..1f73b3403 100644 --- a/pipeline/src/constructors/make_default_params.jl +++ b/pipeline/src/constructors/make_default_params.jl @@ -21,11 +21,13 @@ function make_default_params(pipeline::AbstractEpiAwarePipeline) I0 = 100.0 α_delay = 4.0 θ_delay = 5.0 / 4.0 + lookahead = 21 return Dict( "Rt" => Rt, "logit_daily_ascertainment" => logit_daily_ascertainment, "cluster_factor" => cluster_factor, "I0" => I0, "α_delay" => α_delay, - "θ_delay" => θ_delay) + "θ_delay" => θ_delay, + "lookahead" => lookahead) end diff --git a/pipeline/src/constructors/make_inference_configs.jl b/pipeline/src/constructors/make_inference_configs.jl index ab0d3815e..332b24c90 100644 --- a/pipeline/src/constructors/make_inference_configs.jl +++ b/pipeline/src/constructors/make_inference_configs.jl @@ -14,10 +14,12 @@ function make_inference_configs(pipeline::AbstractEpiAwarePipeline) igps = make_inf_generating_processes(pipeline) obs = make_observation_model(pipeline) priors = make_model_priors(pipeline) + default_params = make_default_params(pipeline) inference_configs = Dict("igp" => igps, "latent_namemodels" => namemodel_vect, "observation_model" => obs, "gi_mean" => gi_param_dict["gi_means"], - "gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"]) |> + "gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"], + "lookahead" => default_params["lookahead"]) |> dict_list selected_inference_configs = _selector(inference_configs, pipeline) diff --git a/pipeline/src/forecast/generate_forecasts.jl b/pipeline/src/forecast/generate_forecasts.jl index 52842bd75..752b926ea 100644 --- a/pipeline/src/forecast/generate_forecasts.jl +++ b/pipeline/src/forecast/generate_forecasts.jl @@ -1,19 +1,15 @@ """ -Generate forecasts for `n` time steps above based on the given inference results. +Generate forecasts for `lookahead` time steps ahead based on the results of the +inference process. # Arguments -- `inference_results`: The results of the inference process. -- `n`: The number of forecasts to generate. - -# Returns -- `forecast_quantities`: The generated forecast quantities. - +- `inference_chn`: The posterior chains of the inference process. +- `data`: The data used in the inference process. +- `epiprob`: The EpiProblem object used in the inference process. +- `lookahead`: The number of time steps to forecast ahead. """ -function generate_forecasts(inference_results, n::Integer) - inference_chn = inference_results["inference_results"].samples - data = inference_results["inference_results"].data - epiprob = inference_results["epiprob"] - forecast_epiprob = define_forecast_epiprob(epiprob, n) +function generate_forecasts(inference_chn, data, epiprob, lookahead::Integer) + forecast_epiprob = define_forecast_epiprob(epiprob, lookahead) forecast_mdl = generate_epiaware(forecast_epiprob, (y_t = missing,)) # Add forward generation of latent variables using `predict` @@ -27,3 +23,23 @@ function generate_forecasts(inference_results, n::Integer) forecast_quantities = generated_observables(forecast_mdl, data, pred_chn) return forecast_quantities end + +""" +Generate forecasts for `lookahead` time steps ahead based on the given inference results +in dictionary form. + +# Arguments +- `inference_results_dict`: A dictionary of results of the inference process. +- `lookahead`: The number of time steps to forecast ahead. + +# Returns +- `forecast_quantities`: The generated forecast quantities. + +""" +function generate_forecasts(inference_results_dict::Dict, lookahead::Integer) + @assert haskey(inference_results_dict, "inference_results") "Results dictionary must contain `inference_results` key" + inference_chn = inference_results["inference_results"].samples + data = inference_results["inference_results"].data + epiprob = inference_results["epiprob"] + return generate_forecasts(inference_chn, data, epiprob, lookahead) +end diff --git a/pipeline/src/infer/InferenceConfig.jl b/pipeline/src/infer/InferenceConfig.jl index 65236dd90..dcb516cbf 100644 --- a/pipeline/src/infer/InferenceConfig.jl +++ b/pipeline/src/infer/InferenceConfig.jl @@ -18,7 +18,7 @@ Inference configuration struct for specifying the parameters and models used in - `InferenceConfig(inference_config::Dict; case_data, tspan, epimethod)`: Constructs an `InferenceConfig` object from a dictionary of configuration values. """ -struct InferenceConfig{T, F, I, L, O, E} +struct InferenceConfig{T, F, I, L, O, E, D <: Distribution, X <: Integer} gi_mean::T gi_std::T igp::I @@ -28,14 +28,16 @@ struct InferenceConfig{T, F, I, L, O, E} tspan::Tuple{Integer, Integer} epimethod::E transformation::F - log_I0_prior::Distribution + log_I0_prior::D + lookahead::X function InferenceConfig(igp, latent_model, observation_model; gi_mean, gi_std, - case_data, tspan, epimethod, transformation = exp, log_I0_prior) + case_data, tspan, epimethod, transformation = exp, log_I0_prior, lookahead) new{typeof(gi_mean), typeof(transformation), - typeof(igp), typeof(latent_model), typeof(observation_model), typeof(epimethod)}( + typeof(igp), typeof(latent_model), typeof(observation_model), + typeof(epimethod), typeof(log_I0_prior), typeof(lookahead)}( gi_mean, gi_std, igp, latent_model, observation_model, - case_data, tspan, epimethod, transformation, log_I0_prior) + case_data, tspan, epimethod, transformation, log_I0_prior, lookahead) end function InferenceConfig( @@ -49,7 +51,8 @@ struct InferenceConfig{T, F, I, L, O, E} case_data = case_data, tspan = tspan, epimethod = epimethod, - log_I0_prior = inference_config["log_I0_prior"] + log_I0_prior = inference_config["log_I0_prior"], + lookahead = inference_config["lookahead"] ) end end @@ -68,14 +71,19 @@ to make inference on and model configuration. """ function infer(config::InferenceConfig) #Define the EpiProblem - epi_prob = define_epiprob(config) + epiprob = define_epiprob(config) idxs = config.tspan[1]:config.tspan[2] #Return the sampled infections and observations y_t = ismissing(config.case_data) ? missing : config.case_data[idxs] - inference_results = apply_method(epi_prob, + inference_results = apply_method(epiprob, config.epimethod, (y_t = y_t,); ) - return Dict("inference_results" => inference_results, "epiprob" => epi_prob) + + forecast_results = generate_forecasts( + inference_results.samples, inference_results.data, epiprob, config.lookahead) + + return Dict("inference_results" => inference_results, "epiprob" => epiprob, + "inference_config" => config, "forecast_results" => forecast_results) end diff --git a/pipeline/test/constructors/test_constructors.jl b/pipeline/test/constructors/test_constructors.jl index 5f2e07aba..91fa28dc6 100644 --- a/pipeline/test/constructors/test_constructors.jl +++ b/pipeline/test/constructors/test_constructors.jl @@ -142,7 +142,8 @@ end "cluster_factor" => 0.05, "I0" => 100.0, "α_delay" => 4.0, - "θ_delay" => 5.0 / 4.0 + "θ_delay" => 5.0 / 4.0, + "lookahead" => 21 ) # Test the make_default_params function diff --git a/pipeline/test/end-to-end/test_full_inference.jl b/pipeline/test/end-to-end/test_full_inference.jl index c3852fd77..ed05b7271 100644 --- a/pipeline/test/end-to-end/test_full_inference.jl +++ b/pipeline/test/end-to-end/test_full_inference.jl @@ -13,7 +13,8 @@ using Test inference_config = rand(inference_configs) truthdata = Dict("y_t" => fill(100, 28), "truth_gi_mean" => 1.5) - inference_results = generate_inference_results( + results = generate_inference_results( truthdata, inference_config, pipeline; tspan, inference_method) - @test inference_results["inference_results"] isa EpiAwareObservables + @test results["inference_results"] isa EpiAwareObservables + @test results["forecast_results"] isa EpiAwareObservables end diff --git a/pipeline/test/infer/test_InferenceConfig.jl b/pipeline/test/infer/test_InferenceConfig.jl index f157b0a40..bdb2bd7ef 100644 --- a/pipeline/test/infer/test_InferenceConfig.jl +++ b/pipeline/test/infer/test_InferenceConfig.jl @@ -18,6 +18,7 @@ epimethod = TestMethod() case_data = [10, 20, 30, 40, 50] tspan = (1, 5) + lookahead = 10 @testset "config_parameters back from constructor" begin config = InferenceConfig(igp, latent_model, observation_model; gi_mean = gi_mean, @@ -25,7 +26,8 @@ case_data = case_data, tspan = tspan, epimethod = epimethod, - log_I0_prior = Normal(log(100.0), 1e-5) + log_I0_prior = Normal(log(100.0), 1e-5), + lookahead = lookahead ) @test config.gi_mean == gi_mean diff --git a/pipeline/test/pipeline/test_pipelinefunctions.jl b/pipeline/test/pipeline/test_pipelinefunctions.jl index 45e8a6123..704d69b6f 100644 --- a/pipeline/test/pipeline/test_pipelinefunctions.jl +++ b/pipeline/test/pipeline/test_pipelinefunctions.jl @@ -9,12 +9,12 @@ end @testset "do_inference tests" begin - using EpiAwarePipeline + using EpiAwarePipeline, Dagger pipeline = EpiAwareExamplePipeline() function make_inference() truthdata = do_truthdata(pipeline) - do_inference(truthdata[1:1], pipeline) + do_inference(truthdata[1], pipeline) end inference_results_tsk = make_inference()