Skip to content

Commit

Permalink
Rolling window for inference (#335)
Browse files Browse the repository at this point in the history
* adding make_Rt methods

* more make Rt methods

* Make endemic scenarios geometric mean 1

* Update test_constructors.jl

* prefix methods

* format fix

* Update infer.jl

* add prefix for simulate

* define a `lookback` parameter

* Update make_tspan.jl

* Default params for stride of rolling window

* new make_tspan and modify other areas of code that assume a fixed tspan
  • Loading branch information
SamuelBrand1 authored Jul 5, 2024
1 parent d43f033 commit 15d8aff
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 20 deletions.
6 changes: 5 additions & 1 deletion pipeline/src/constructors/make_default_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ function make_default_params(pipeline::AbstractEpiAwarePipeline)
α_delay = 4.0
θ_delay = 5.0 / 4.0
lookahead = 21
lookback = 35
stride = 7
return Dict(
"Rt" => Rt,
"logit_daily_ascertainment" => logit_daily_ascertainment,
"cluster_factor" => cluster_factor,
"I0" => I0,
"α_delay" => α_delay,
"θ_delay" => θ_delay,
"lookahead" => lookahead)
"lookahead" => lookahead,
"lookback" => lookback,
"stride" => stride)
end
6 changes: 4 additions & 2 deletions pipeline/src/constructors/make_inference_configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@ Create inference configurations for the given pipeline. This is the default meth
- An object representing the inference configurations.
"""
function make_inference_configs(pipeline::AbstractEpiAwarePipeline)
function make_inference_configs(pipeline::AbstractEpiAwarePipeline; start = 21)
gi_param_dict = make_gi_params(pipeline)
namemodel_vect = make_epiaware_name_latentmodel_pairs(pipeline)
igps = make_inf_generating_processes(pipeline)
obs = make_observation_model(pipeline)
priors = make_model_priors(pipeline)
default_params = make_default_params(pipeline)
N = size(make_Rt(pipeline), 1)
Ts = start:default_params["stride"]:N |> collect

inference_configs = Dict("igp" => igps, "latent_namemodels" => namemodel_vect,
"observation_model" => obs, "gi_mean" => gi_param_dict["gi_means"],
"gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"],
"lookahead" => default_params["lookahead"]) |>
"lookahead" => default_params["lookahead"], "lookback" => default_params["lookback"], "T" => Ts) |>
dict_list

selected_inference_configs = _selector(inference_configs, pipeline)
Expand Down
21 changes: 14 additions & 7 deletions pipeline/src/constructors/make_tspan.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
"""
Constructs the time span for the given `pipeline` object.
Constructs a time span for performing inference on a case data time series. This
is the default method.
# Arguments
- `pipeline::AbstractEpiAwarePipeline`: The pipeline object for which the time
span is constructed. This is the default method.
- `pipeline::AbstractEpiAwarePipeline`: The pipeline object used for analysis.
- `T::Union{Integer,Nothing} = nothing`: The `stop` point at which to construct
the time span. If `nothing`, the time span will be constructed using the
length of the Rt vector for `pipeline`.
- `lookback = 35`: The number of days to look back from the specified time point.
# Returns
- `tspan::Tuple{Float64, Float64}`: The time span as a tuple of start and end times.
A tuple `(start, stop)` representing the start and stop indices of the time span.
# Examples
"""
function make_tspan(pipeline::AbstractEpiAwarePipeline; backhorizon = 21)
function make_tspan(pipeline::AbstractEpiAwarePipeline;
T::Union{Integer, Nothing} = nothing, lookback = 35)
N = size(make_Rt(pipeline), 1)
@assert backhorizon<N "Backhorizon must be less than the length of the default Rt."
return (1, N - backhorizon)
_T = isnothing(T) ? N : T
return (max(1, _T - lookback), min(N, _T))
end
14 changes: 9 additions & 5 deletions pipeline/src/infer/generate_inference_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ Generate inference results based on the given configuration of inference model o
"""
function generate_inference_results(
truthdata, inference_config, pipeline::AbstractEpiAwarePipeline;
tspan, inference_method, datadir_name = "epiaware_observables")
inference_method, datadir_name = "epiaware_observables")
tspan = make_tspan(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
config = InferenceConfig(
inference_config; case_data = truthdata["y_t"], tspan, epimethod = inference_method)

Expand Down Expand Up @@ -46,8 +48,9 @@ which is deleted after the function call.
- `inference_results`: The generated inference results.
"""
function generate_inference_results(
truthdata, inference_config, pipeline::EpiAwareExamplePipeline;
tspan, inference_method)
truthdata, inference_config, pipeline::EpiAwareExamplePipeline; inference_method)
tspan = make_tspan(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
config = InferenceConfig(inference_config; case_data = truthdata["y_t"],
tspan = tspan, epimethod = inference_method)

Expand All @@ -65,8 +68,9 @@ end
Method for prior predictive modelling.
"""
function generate_inference_results(
inference_config, pipeline::RtwithoutRenewalPriorPipeline;
tspan)
inference_config, pipeline::RtwithoutRenewalPriorPipeline)
tspan = make_tspan(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
config = InferenceConfig(
inference_config; case_data = missing, tspan, epimethod = DirectSample())

Expand Down
4 changes: 2 additions & 2 deletions pipeline/src/infer/map_inference_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ tasks from `Dagger.@spawn`.
"""
function map_inference_results(
truthdata, inference_configs, pipeline::AbstractEpiAwarePipeline; tspan, inference_method)
truthdata, inference_configs, pipeline::AbstractEpiAwarePipeline; inference_method)
map(inference_configs) do inference_config
Dagger.@spawn generate_inference_results(
truthdata, inference_config, pipeline; tspan, inference_method)
truthdata, inference_config, pipeline; inference_method)
end
end
3 changes: 1 addition & 2 deletions pipeline/src/pipeline/do_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ An array of inference results.
"""
function do_inference(truthdata, pipeline::AbstractEpiAwarePipeline)
inference_configs = make_inference_configs(pipeline)
tspan = make_tspan(pipeline)
inference_method = make_inference_method(pipeline)
inference_results = map_inference_results(
truthdata, inference_configs, pipeline; tspan, inference_method)
truthdata, inference_configs, pipeline; inference_method)
return inference_results
end
4 changes: 3 additions & 1 deletion pipeline/test/constructors/test_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ end
"I0" => 100.0,
"α_delay" => 4.0,
"θ_delay" => 5.0 / 4.0,
"lookahead" => 21
"lookahead" => 21,
"lookback" => 35,
"stride" => 7
)

# Test the make_default_params function
Expand Down

0 comments on commit 15d8aff

Please sign in to comment.