Skip to content

Commit

Permalink
add forecasting to pipeline solve and save (#308)
Browse files Browse the repository at this point in the history
* include a default lookahead

* Include a forecast step in the `infer` function

* extend tests
  • Loading branch information
SamuelBrand1 authored Jun 27, 2024
1 parent af37402 commit d7f3b69
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 29 deletions.
4 changes: 3 additions & 1 deletion pipeline/src/constructors/make_default_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ function make_default_params(pipeline::AbstractEpiAwarePipeline)
I0 = 100.0
α_delay = 4.0
θ_delay = 5.0 / 4.0
lookahead = 21
return Dict(
"Rt" => Rt,
"logit_daily_ascertainment" => logit_daily_ascertainment,
"cluster_factor" => cluster_factor,
"I0" => I0,
"α_delay" => α_delay,
"θ_delay" => θ_delay)
"θ_delay" => θ_delay,
"lookahead" => lookahead)
end
4 changes: 3 additions & 1 deletion pipeline/src/constructors/make_inference_configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ function make_inference_configs(pipeline::AbstractEpiAwarePipeline)
igps = make_inf_generating_processes(pipeline)
obs = make_observation_model(pipeline)
priors = make_model_priors(pipeline)
default_params = make_default_params(pipeline)

inference_configs = Dict("igp" => igps, "latent_namemodels" => namemodel_vect,
"observation_model" => obs, "gi_mean" => gi_param_dict["gi_means"],
"gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"]) |>
"gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"],
"lookahead" => default_params["lookahead"]) |>
dict_list

selected_inference_configs = _selector(inference_configs, pipeline)
Expand Down
40 changes: 28 additions & 12 deletions pipeline/src/forecast/generate_forecasts.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
"""
Generate forecasts for `n` time steps above based on the given inference results.
Generate forecasts for `lookahead` time steps ahead based on the results of the
inference process.
# Arguments
- `inference_results`: The results of the inference process.
- `n`: The number of forecasts to generate.
# Returns
- `forecast_quantities`: The generated forecast quantities.
- `inference_chn`: The posterior chains of the inference process.
- `data`: The data used in the inference process.
- `epiprob`: The EpiProblem object used in the inference process.
- `lookahead`: The number of time steps to forecast ahead.
"""
function generate_forecasts(inference_results, n::Integer)
inference_chn = inference_results["inference_results"].samples
data = inference_results["inference_results"].data
epiprob = inference_results["epiprob"]
forecast_epiprob = define_forecast_epiprob(epiprob, n)
function generate_forecasts(inference_chn, data, epiprob, lookahead::Integer)
forecast_epiprob = define_forecast_epiprob(epiprob, lookahead)
forecast_mdl = generate_epiaware(forecast_epiprob, (y_t = missing,))

# Add forward generation of latent variables using `predict`
Expand All @@ -27,3 +23,23 @@ function generate_forecasts(inference_results, n::Integer)
forecast_quantities = generated_observables(forecast_mdl, data, pred_chn)
return forecast_quantities
end

"""
Generate forecasts for `lookahead` time steps ahead based on the given inference results
in dictionary form.
# Arguments
- `inference_results_dict`: A dictionary of results of the inference process.
- `lookahead`: The number of time steps to forecast ahead.
# Returns
- `forecast_quantities`: The generated forecast quantities.
"""
function generate_forecasts(inference_results_dict::Dict, lookahead::Integer)
@assert haskey(inference_results_dict, "inference_results") "Results dictionary must contain `inference_results` key"
inference_chn = inference_results["inference_results"].samples
data = inference_results["inference_results"].data
epiprob = inference_results["epiprob"]
return generate_forecasts(inference_chn, data, epiprob, lookahead)
end
26 changes: 17 additions & 9 deletions pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Inference configuration struct for specifying the parameters and models used in
- `InferenceConfig(inference_config::Dict; case_data, tspan, epimethod)`: Constructs an `InferenceConfig` object from a dictionary of configuration values.
"""
struct InferenceConfig{T, F, I, L, O, E}
struct InferenceConfig{T, F, I, L, O, E, D <: Distribution, X <: Integer}
gi_mean::T
gi_std::T
igp::I
Expand All @@ -28,14 +28,16 @@ struct InferenceConfig{T, F, I, L, O, E}
tspan::Tuple{Integer, Integer}
epimethod::E
transformation::F
log_I0_prior::Distribution
log_I0_prior::D
lookahead::X

function InferenceConfig(igp, latent_model, observation_model; gi_mean, gi_std,
case_data, tspan, epimethod, transformation = exp, log_I0_prior)
case_data, tspan, epimethod, transformation = exp, log_I0_prior, lookahead)
new{typeof(gi_mean), typeof(transformation),
typeof(igp), typeof(latent_model), typeof(observation_model), typeof(epimethod)}(
typeof(igp), typeof(latent_model), typeof(observation_model),
typeof(epimethod), typeof(log_I0_prior), typeof(lookahead)}(
gi_mean, gi_std, igp, latent_model, observation_model,
case_data, tspan, epimethod, transformation, log_I0_prior)
case_data, tspan, epimethod, transformation, log_I0_prior, lookahead)
end

function InferenceConfig(
Expand All @@ -49,7 +51,8 @@ struct InferenceConfig{T, F, I, L, O, E}
case_data = case_data,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"]
log_I0_prior = inference_config["log_I0_prior"],
lookahead = inference_config["lookahead"]
)
end
end
Expand All @@ -68,14 +71,19 @@ to make inference on and model configuration.
"""
function infer(config::InferenceConfig)
#Define the EpiProblem
epi_prob = define_epiprob(config)
epiprob = define_epiprob(config)
idxs = config.tspan[1]:config.tspan[2]

#Return the sampled infections and observations
y_t = ismissing(config.case_data) ? missing : config.case_data[idxs]
inference_results = apply_method(epi_prob,
inference_results = apply_method(epiprob,
config.epimethod,
(y_t = y_t,);
)
return Dict("inference_results" => inference_results, "epiprob" => epi_prob)

forecast_results = generate_forecasts(
inference_results.samples, inference_results.data, epiprob, config.lookahead)

return Dict("inference_results" => inference_results, "epiprob" => epiprob,
"inference_config" => config, "forecast_results" => forecast_results)
end
3 changes: 2 additions & 1 deletion pipeline/test/constructors/test_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ end
"cluster_factor" => 0.05,
"I0" => 100.0,
"α_delay" => 4.0,
"θ_delay" => 5.0 / 4.0
"θ_delay" => 5.0 / 4.0,
"lookahead" => 21
)

# Test the make_default_params function
Expand Down
5 changes: 3 additions & 2 deletions pipeline/test/end-to-end/test_full_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ using Test
inference_config = rand(inference_configs)
truthdata = Dict("y_t" => fill(100, 28), "truth_gi_mean" => 1.5)

inference_results = generate_inference_results(
results = generate_inference_results(
truthdata, inference_config, pipeline; tspan, inference_method)
@test inference_results["inference_results"] isa EpiAwareObservables
@test results["inference_results"] isa EpiAwareObservables
@test results["forecast_results"] isa EpiAwareObservables
end
4 changes: 3 additions & 1 deletion pipeline/test/infer/test_InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
epimethod = TestMethod()
case_data = [10, 20, 30, 40, 50]
tspan = (1, 5)
lookahead = 10
@testset "config_parameters back from constructor" begin
config = InferenceConfig(igp, latent_model, observation_model;
gi_mean = gi_mean,
gi_std = gi_std,
case_data = case_data,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = Normal(log(100.0), 1e-5)
log_I0_prior = Normal(log(100.0), 1e-5),
lookahead = lookahead
)

@test config.gi_mean == gi_mean
Expand Down
4 changes: 2 additions & 2 deletions pipeline/test/pipeline/test_pipelinefunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
end

@testset "do_inference tests" begin
using EpiAwarePipeline
using EpiAwarePipeline, Dagger
pipeline = EpiAwareExamplePipeline()

function make_inference()
truthdata = do_truthdata(pipeline)
do_inference(truthdata[1:1], pipeline)
do_inference(truthdata[1], pipeline)
end

inference_results_tsk = make_inference()
Expand Down

0 comments on commit d7f3b69

Please sign in to comment.