From c980e221acccab5ffa942bb033a7d869bdb6c052 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 17 May 2024 13:49:40 +0100 Subject: [PATCH] rolling out a pipeline type system for dispatching on analysis details (#224) * move default constructors into own folder * pipeline types and a new default constructor + tests * More default constructors * reduce to one run_pipeline * Update generate_inference_results.jl * pipeline functions * Update make_truthdata.jl * Constructor functions + tests refactored into folders and added dispatch on pipeline type * rename pipeline functions to `do_...`, refactor in folders * simulation of truth data functions/tests refactored into folders and now dispatch on pipeline type * inference functions/tests refactored into folders and now dispatch on pipeline type * Add a rmprocs at end to clear worker processes * Rewrite toy end-to-end test of inference for generated data * Update runtests.jl * Update AnalysisPipeline.jl * fix style of AnalysisPipeline to be consistent * fix using removed type * move defaults into make_ functions with a dispatch on pipeline type * fix use of default * plot function dispatch on pipeline type * Change generate inference so that name of latent model and model itself are passed as a Pair * refactor unit tests of default functions into unit tests on constructors * Update runtests.jl * Update test_full_inference.jl * remove default methods * abstract the map on inference results to another function * Update plot_functions.jl * catch failure to pass truthdata and reformat --- pipeline/scripts/analysis_pipeline.jl | 84 ---------------- pipeline/scripts/run_pipeline.jl | 26 +++++ pipeline/src/AnalysisPipeline.jl | 49 ++++++---- pipeline/src/constructors/constructors.jl | 10 ++ .../make_Rt.jl} | 5 +- .../make_epiaware_name_model_pairs.jl | 29 ++++++ pipeline/src/constructors/make_gi_params.jl | 16 +++ .../make_inf_generating_processes.jl | 14 +++ 
.../constructors/make_inference_configs.jl | 21 ++++ .../src/constructors/make_inference_method.jl | 19 ++++ .../constructors/make_latent_model_priors.jl | 26 +++++ .../constructors/make_truth_data_configs.jl | 15 +++ pipeline/src/constructors/make_tspan.jl | 16 +++ pipeline/src/default_epiaware_models.jl | 34 ------- pipeline/src/default_gi_params.jl | 12 --- pipeline/src/default_inference_method.jl | 29 ------ pipeline/src/default_latent_model_priors.jl | 20 ---- pipeline/src/default_latent_models_names.jl | 12 --- pipeline/src/default_tspan.jl | 15 --- pipeline/src/{ => infer}/InferenceConfig.jl | 0 .../{ => infer}/generate_inference_results.jl | 23 +++-- pipeline/src/infer/infer.jl | 3 + pipeline/src/infer/map_inference_results.jl | 23 +++++ pipeline/src/make_inference_configs.jl | 18 ---- pipeline/src/make_truth_data_configs.jl | 14 --- pipeline/src/pipeline/do_inference.jl | 19 ++++ pipeline/src/pipeline/do_pipeline.jl | 14 +++ pipeline/src/pipeline/do_truthdata.jl | 18 ++++ pipeline/src/pipeline/pipeline.jl | 4 + pipeline/src/pipeline/pipelinetypes.jl | 11 +++ pipeline/src/plot_functions.jl | 11 ++- .../{ => simulate}/TruthSimulationConfig.jl | 0 .../src/{ => simulate}/generate_truthdata.jl | 27 ++++-- pipeline/src/simulate/simulate.jl | 2 + .../test/constructors/test_constructors.jl | 97 +++++++++++++++++++ pipeline/test/default_returning_functions.jl | 70 ------------- .../test/{ => infer}/test_InferenceConfig.jl | 0 pipeline/test/pipeline/test_pipelinetypes.jl | 10 ++ pipeline/test/runtests.jl | 10 +- .../{ => simulate}/test_SimulationConfig.jl | 0 .../test_TruthSimulationConfig.jl | 6 +- pipeline/test/test_full_inference.jl | 24 ++--- pipeline/test/test_make_configs.jl | 24 ----- 43 files changed, 480 insertions(+), 400 deletions(-) delete mode 100644 pipeline/scripts/analysis_pipeline.jl create mode 100644 pipeline/scripts/run_pipeline.jl create mode 100644 pipeline/src/constructors/constructors.jl rename pipeline/src/{default_Rt.jl => 
constructors/make_Rt.jl} (77%) create mode 100644 pipeline/src/constructors/make_epiaware_name_model_pairs.jl create mode 100644 pipeline/src/constructors/make_gi_params.jl create mode 100644 pipeline/src/constructors/make_inf_generating_processes.jl create mode 100644 pipeline/src/constructors/make_inference_configs.jl create mode 100644 pipeline/src/constructors/make_inference_method.jl create mode 100644 pipeline/src/constructors/make_latent_model_priors.jl create mode 100644 pipeline/src/constructors/make_truth_data_configs.jl create mode 100644 pipeline/src/constructors/make_tspan.jl delete mode 100644 pipeline/src/default_epiaware_models.jl delete mode 100644 pipeline/src/default_gi_params.jl delete mode 100644 pipeline/src/default_inference_method.jl delete mode 100644 pipeline/src/default_latent_model_priors.jl delete mode 100644 pipeline/src/default_latent_models_names.jl delete mode 100644 pipeline/src/default_tspan.jl rename pipeline/src/{ => infer}/InferenceConfig.jl (100%) rename pipeline/src/{ => infer}/generate_inference_results.jl (57%) create mode 100644 pipeline/src/infer/infer.jl create mode 100644 pipeline/src/infer/map_inference_results.jl delete mode 100644 pipeline/src/make_inference_configs.jl delete mode 100644 pipeline/src/make_truth_data_configs.jl create mode 100644 pipeline/src/pipeline/do_inference.jl create mode 100644 pipeline/src/pipeline/do_pipeline.jl create mode 100644 pipeline/src/pipeline/do_truthdata.jl create mode 100644 pipeline/src/pipeline/pipeline.jl create mode 100644 pipeline/src/pipeline/pipelinetypes.jl rename pipeline/src/{ => simulate}/TruthSimulationConfig.jl (100%) rename pipeline/src/{ => simulate}/generate_truthdata.jl (53%) create mode 100644 pipeline/src/simulate/simulate.jl create mode 100644 pipeline/test/constructors/test_constructors.jl delete mode 100644 pipeline/test/default_returning_functions.jl rename pipeline/test/{ => infer}/test_InferenceConfig.jl (100%) create mode 100644 
pipeline/test/pipeline/test_pipelinetypes.jl rename pipeline/test/{ => simulate}/test_SimulationConfig.jl (100%) rename pipeline/test/{ => simulate}/test_TruthSimulationConfig.jl (85%) delete mode 100644 pipeline/test/test_make_configs.jl diff --git a/pipeline/scripts/analysis_pipeline.jl b/pipeline/scripts/analysis_pipeline.jl deleted file mode 100644 index cfc57b2e3..000000000 --- a/pipeline/scripts/analysis_pipeline.jl +++ /dev/null @@ -1,84 +0,0 @@ -# Analogy: library(targets) in R -using DrWatson - -# Activate the project environment -# Analogy: source(functions.R) in targets -quickactivate(@__DIR__(), "Analysis pipeline") -using Dagger - -@info(""" - Running the analysis pipeline. - --------------------------------------------- - Currently active project is: $(projectname()) - Path of active project: $(projectdir()) - """) - -## Other dependencies -# Analogy: tar_option_set(...) in targets -# Add processes for parallel computing and ensure all have same -# dependencies/environment -using Distributed -addprocs(1) - -@everywhere begin - using DrWatson - quickactivate(@__DIR__(), "Analysis pipeline") - include(srcdir("AnalysisPipeline.jl")) -end - -@everywhere using .AnalysisPipeline - -## Run the pipeline steps -# Analogy: list(tar_targets...) followed by tar_make(...) in targets -# This is an intermediate commit that runs but will be updated with a job -# scheduler for these tasks in the next commit. 
Probably Dagger.jl - -# Default parameter values and plot true Rt -# @everywhere begin -# default_gi_param_dict = default_gi_params() -# true_Rt = default_Rt() -# plt_Rt = plot_Rt(true_Rt) -# latent_models_dict = default_epiaware_models() -# latent_models_names = Dict(value => key for (key, value) in latent_models_dict) -# tspan = default_tspan() -# inference_method = default_inference_method() -# end - -default_gi_param_dict_thunk = Dagger.@spawn default_gi_params() -true_Rt_thunk = Dagger.@spawn default_Rt() -plt_Rt_thunk = Dagger.@spawn plot_Rt(true_Rt_thunk) -latent_models_dict_thunk = Dagger.@spawn default_epiaware_models() -latent_models_names_thunk = Dagger.@spawn default_latent_models_names() -tspan_thunk = Dagger.@spawn default_tspan() -inference_method_thunk = Dagger.@spawn default_inference_method() - -# fetch the default GI parameters and latent models from their `EagerThunk`s -default_gi_param_dict = fetch(default_gi_param_dict_thunk) -latent_models_dict = fetch(latent_models_dict_thunk) - -# truth data configurations (e.g. different GI means) - -truth_data_configs = make_truth_data_configs( - gi_means = default_gi_param_dict["gi_means"], gi_stds = default_gi_param_dict["gi_stds"]) - -# inference configurations -# (e.g. different infection generation processes and latent models etc.) 
-inference_configs = make_inference_configs( - latent_models = collect(values(latent_models_dict)), - gi_means = default_gi_param_dict["gi_means"], - gi_stds = default_gi_param_dict["gi_stds"]) - -# Produce and save the truth data -truthdata_from_configs = @sync map(truth_data_configs) do truth_data_config - # generate truth data - truth_thunk = Dagger.@spawn generate_truthdata_from_config( - truth_data_config; plot = true) - # # Run the inference scenarios - for inference_config in inference_configs - inference_thunks = Dagger.@spawn generate_inference_results( - truth_thunk, inference_config; tspan_thunk, inference_method_thunk, - truth_data_config, latent_models_names_thunk) - end - truthdata = fetch(truth_thunk) - return truthdata -end diff --git a/pipeline/scripts/run_pipeline.jl b/pipeline/scripts/run_pipeline.jl new file mode 100644 index 000000000..c9982e24a --- /dev/null +++ b/pipeline/scripts/run_pipeline.jl @@ -0,0 +1,26 @@ +# Local environment script to run the analysis pipeline +using Pkg +Pkg.activate(joinpath(@__DIR__(), "..")) +using Dagger + +@info(""" + Running the analysis pipeline. + -------------------------------------------- + """) + +# Define the backend resources to use for the pipeline +# in this case we are using distributed local workers with loaded modules +using Distributed +pids = addprocs() + +@everywhere include("../src/AnalysisPipeline.jl") +@everywhere using .AnalysisPipeline + +# Create an instance of the pipeline behaviour +pipeline = RtwithoutRenewalPipeline() + +# Run the pipeline +do_pipeline(pipeline) + +# Remove the workers +rmprocs(pids) diff --git a/pipeline/src/AnalysisPipeline.jl b/pipeline/src/AnalysisPipeline.jl index aa06c22b8..87aee2fa7 100644 --- a/pipeline/src/AnalysisPipeline.jl +++ b/pipeline/src/AnalysisPipeline.jl @@ -1,34 +1,43 @@ """ This module contains the analysis pipeline for the `Rt-without-renewal` project. 
+ +# Pipeline Components + +In this module the meaning of a _pipeline component_ is a directed-acyclic-graph +(DAG) of tasks defined using `Dagger.jl` via dispatch on an `AbstractEpiAwarePipeline` +sub-type from a function with prefix `do_`. A full pipeline is a sequence of DAGs, +with execution determined by available computational resources. """ module AnalysisPipeline -using Dates: default using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson, EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2 # Exported struct types -export TruthSimulationConfig, InferenceConfig +export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline, + TruthSimulationConfig, InferenceConfig + +# Exported functions: constructors +export make_gi_params, make_inf_generating_processes, make_latent_model_priors, + make_epiaware_name_model_pairs, make_Rt, make_truth_data_configs, + make_inference_configs, make_tspan, make_inference_method + +# Exported functions: pipeline components +export do_truthdata, do_inference, do_pipeline + +# Exported functions: simulate functions +export simulate, generate_truthdata + +# Exported functions: infer functions +export infer, generate_inference_results, map_inference_results -# Exported functions -export simulate, infer, default_gi_params, default_Rt, default_tspan, - default_latent_model_priors, default_epiaware_models, default_inference_method, - default_latent_models_names, make_truth_data_configs, make_inference_configs, - generate_truthdata_from_config, generate_inference_results, plot_truth_data, plot_Rt +# Exported functions: plot functions +export plot_truth_data, plot_Rt include("docstrings.jl") -include("default_gi_params.jl") -include("default_Rt.jl") -include("default_tspan.jl") -include("default_latent_model_priors.jl") -include("default_epiaware_models.jl") -include("default_inference_method.jl") -include("make_truth_data_configs.jl") -include("make_inference_configs.jl")
-include("default_latent_models_names.jl") -include("TruthSimulationConfig.jl") -include("InferenceConfig.jl") -include("generate_truthdata.jl") -include("generate_inference_results.jl") +include("pipeline/pipeline.jl") +include("constructors/constructors.jl") +include("simulate/simulate.jl") +include("infer/infer.jl") include("plot_functions.jl") end diff --git a/pipeline/src/constructors/constructors.jl b/pipeline/src/constructors/constructors.jl new file mode 100644 index 000000000..2c42487db --- /dev/null +++ b/pipeline/src/constructors/constructors.jl @@ -0,0 +1,10 @@ +include("make_gi_params.jl") +include("make_inf_generating_processes.jl") +include("make_latent_model_priors.jl") +include("make_epiaware_name_model_pairs.jl") +include("make_inference_method.jl") +include("make_truth_data_configs.jl") +include("make_inference_configs.jl") +include("make_Rt.jl") +include("make_tspan.jl") +include("make_inference_method.jl") diff --git a/pipeline/src/default_Rt.jl b/pipeline/src/constructors/make_Rt.jl similarity index 77% rename from pipeline/src/default_Rt.jl rename to pipeline/src/constructors/make_Rt.jl index 127f5b874..fd1a7fba1 100644 --- a/pipeline/src/default_Rt.jl +++ b/pipeline/src/constructors/make_Rt.jl @@ -1,5 +1,6 @@ """ -Compute the default Rt values over time. +Compute the default Rt values over time for generating truth data. This is the +default method. ## keyword Arguments - `A`: Amplitude of the sinusoidal variation in Rt. Default is 0.3. @@ -9,7 +10,7 @@ Compute the default Rt values over time. - `true_Rt`: Array of default Rt values over time. 
""" -function default_Rt(; A = 0.3, P = 30.0) +function make_Rt(pipeline::AbstractEpiAwarePipeline; A = 0.3, P = 30.0) ϕ = asin(-0.1 / 0.3) * P / (2 * π) N = 160 true_Rt = vcat(fill(1.1, 2 * 7), fill(2.0, 2 * 7), fill(0.5, 2 * 7), diff --git a/pipeline/src/constructors/make_epiaware_name_model_pairs.jl b/pipeline/src/constructors/make_epiaware_name_model_pairs.jl new file mode 100644 index 000000000..c2150be01 --- /dev/null +++ b/pipeline/src/constructors/make_epiaware_name_model_pairs.jl @@ -0,0 +1,29 @@ +""" +Constructs a vector of name-model pairs for the EpiAware pipeline. This is +the default method. + +# Arguments +- `pipeline::AbstractEpiAwarePipeline`: The EpiAware pipeline object. + +# Returns +A vector containing the name-model pairs. + +""" +function make_epiaware_name_model_pairs(pipeline::AbstractEpiAwarePipeline) + prior_dict = make_latent_model_priors(pipeline) + + ar = AR(damp_priors = [prior_dict["damp_param_prior"]], + std_prior = prior_dict["std_prior"], + init_priors = [prior_dict["transformed_process_init_prior"]]) + + rw = RandomWalk( + std_prior = prior_dict["std_prior"], init_prior = prior_dict["transformed_process_init_prior"]) + + diff_ar = DiffLatentModel(; + model = ar, init_priors = [prior_dict["transformed_process_init_prior"]]) + + wkly_ar, wkly_rw, wkly_diff_ar = [ar, rw, diff_ar] .|> + model -> BroadcastLatentModel(model, 7, RepeatBlock()) + + return ["wkly_ar" => wkly_ar, "wkly_rw" => wkly_rw, "wkly_diff_ar" => wkly_diff_ar] +end diff --git a/pipeline/src/constructors/make_gi_params.jl b/pipeline/src/constructors/make_gi_params.jl new file mode 100644 index 000000000..9083fe2ff --- /dev/null +++ b/pipeline/src/constructors/make_gi_params.jl @@ -0,0 +1,16 @@ +""" +Constructs a dictionary of GI (Generation Interval) parameters. This is the +default method. + +# Arguments +- `pipeline`: An instance of the `AbstractEpiAwarePipeline` type. + +# Returns +A dictionary containing the GI means and GI standard deviations.
+ +""" +function make_gi_params(pipeline::AbstractEpiAwarePipeline) + gi_means = [2.0, 10.0, 20.0] + gi_stds = [2.0] + return Dict("gi_means" => gi_means, "gi_stds" => gi_stds) +end diff --git a/pipeline/src/constructors/make_inf_generating_processes.jl b/pipeline/src/constructors/make_inf_generating_processes.jl new file mode 100644 index 000000000..6b23811e9 --- /dev/null +++ b/pipeline/src/constructors/make_inf_generating_processes.jl @@ -0,0 +1,14 @@ +""" +Constructs and returns a vector of infection-generating process types for the given +pipeline. This is the default method. + +# Arguments +- `pipeline`: An instance of `AbstractEpiAwarePipeline`. + +# Returns +An array of infection-generating process types. + +""" +function make_inf_generating_processes(pipeline::AbstractEpiAwarePipeline) + return [DirectInfections, ExpGrowthRate, Renewal] +end diff --git a/pipeline/src/constructors/make_inference_configs.jl b/pipeline/src/constructors/make_inference_configs.jl new file mode 100644 index 000000000..4f0c4f08f --- /dev/null +++ b/pipeline/src/constructors/make_inference_configs.jl @@ -0,0 +1,21 @@ +""" +Create inference configurations for the given pipeline. This is the default method. + +# Arguments +- `pipeline`: An instance of `AbstractEpiAwarePipeline`. + +# Returns +- An object representing the inference configurations. 
+ +""" +function make_inference_configs(pipeline::AbstractEpiAwarePipeline) + gi_param_dict = make_gi_params(pipeline) + namemodel_vect = make_epiaware_name_model_pairs(pipeline) + igps = make_inf_generating_processes(pipeline) + + inference_configs = Dict("igp" => igps, "latent_namemodels" => namemodel_vect, + "gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |> + dict_list + + return inference_configs +end diff --git a/pipeline/src/constructors/make_inference_method.jl b/pipeline/src/constructors/make_inference_method.jl new file mode 100644 index 000000000..2d4d4b18d --- /dev/null +++ b/pipeline/src/constructors/make_inference_method.jl @@ -0,0 +1,19 @@ +""" +Constructs an inference method for the given pipeline. This is a default method. + +# Arguments +- `pipeline`: An instance of `AbstractEpiAwarePipeline`. + +# Returns +- An inference method. + +""" +function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integer = 2000, + mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(), + nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4) + return EpiMethod( + pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)], + sampler = NUTSampler(adtype = AutoForwardDiff(), ndraws = ndraws, + nchains = nchains, mcmc_parallel = mcmc_ensemble) + ) +end diff --git a/pipeline/src/constructors/make_latent_model_priors.jl b/pipeline/src/constructors/make_latent_model_priors.jl new file mode 100644 index 000000000..b23764b93 --- /dev/null +++ b/pipeline/src/constructors/make_latent_model_priors.jl @@ -0,0 +1,26 @@ +""" +Constructs and returns a dictionary of prior distributions for the latent model +parameters. This is the default method. + +# Arguments +- `pipeline`: An instance of the `AbstractEpiAwarePipeline` type. 
+ +# Returns +A dictionary containing the following prior distributions: +- `"transformed_process_init_prior"`: A normal distribution with mean 0.0 and +standard deviation 0.25. +- `"std_prior"`: A half-normal distribution with standard deviation 0.25. +- `"damp_param_prior"`: A beta distribution with shape parameters 0.5 and 0.5. + +""" +function make_latent_model_priors(pipeline::AbstractEpiAwarePipeline) + transformed_process_init_prior = Normal(0.0, 0.25) + std_prior = HalfNormal(0.25) + damp_param_prior = Beta(0.5, 0.5) + + return Dict( + "transformed_process_init_prior" => transformed_process_init_prior, + "std_prior" => std_prior, + "damp_param_prior" => damp_param_prior + ) +end diff --git a/pipeline/src/constructors/make_truth_data_configs.jl b/pipeline/src/constructors/make_truth_data_configs.jl new file mode 100644 index 000000000..f122c1179 --- /dev/null +++ b/pipeline/src/constructors/make_truth_data_configs.jl @@ -0,0 +1,15 @@ +""" +Create a dictionary of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`. + This is the default method. + +# Returns +A vector of dictionaries containing the mean and standard deviation values for + the generation interval. + +""" +function make_truth_data_configs(pipeline::AbstractEpiAwarePipeline) + gi_param_dict = make_gi_params(pipeline) + return Dict( + "gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |> + dict_list +end diff --git a/pipeline/src/constructors/make_tspan.jl b/pipeline/src/constructors/make_tspan.jl new file mode 100644 index 000000000..417d08263 --- /dev/null +++ b/pipeline/src/constructors/make_tspan.jl @@ -0,0 +1,16 @@ +""" +Constructs the time span for the given `pipeline` object. + +# Arguments +- `pipeline::AbstractEpiAwarePipeline`: The pipeline object for which the time + span is constructed. This is the default method. + +# Returns +- `tspan::Tuple{Float64, Float64}`: The time span as a tuple of start and end times. 
+ +""" +function make_tspan(pipeline::AbstractEpiAwarePipeline; backhorizon = 21) + N = size(make_Rt(pipeline), 1) + @assert backhorizon - model -> BroadcastLatentModel(model, 7, RepeatBlock()) - - return Dict("wkly_ar" => wkly_ar, "wkly_rw" => wkly_rw, "wkly_diff_ar" => wkly_diff_ar) -end diff --git a/pipeline/src/default_gi_params.jl b/pipeline/src/default_gi_params.jl deleted file mode 100644 index e3288efcf..000000000 --- a/pipeline/src/default_gi_params.jl +++ /dev/null @@ -1,12 +0,0 @@ -""" -Constructs a dictionary containing default values for the parameters `gi_means` and `gi_stds`. - -# Returns -- `Dict`: A dictionary with keys `"gi_means"` and `"gi_stds"`, and corresponding default values. - -""" -function default_gi_params() - gi_means = [2.0, 10.0, 20.0] - gi_stds = [2.0] - return Dict("gi_means" => gi_means, "gi_stds" => gi_stds) -end diff --git a/pipeline/src/default_inference_method.jl b/pipeline/src/default_inference_method.jl deleted file mode 100644 index 877f0b4ba..000000000 --- a/pipeline/src/default_inference_method.jl +++ /dev/null @@ -1,29 +0,0 @@ -""" -Constructs and returns an `EpiMethod` object with default settings for inference. - -# Arguments -- `max_threads::Integer`: The maximum number of threads to use for parallelization. - Default is 10. -- `ndraws::Integer`: The number of MCMC samples to draw. Default is 2000. -- `mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble`: The MCMC ensemble to use - for parallelization. Default is `MCMCSerial()`; that is no parallelization. -- `nruns_pthf::Integer`: The number of runs for the pre-sampler steps. - Default is 4. -- `maxiters_pthf::Integer`: The maximum number of iterations for the pre-sampler - steps. Default is 100. -- `nchains::Integer`: The number of MCMC chains to run. Default is 2. - -# Returns -An `EpiMethod` object with the specified settings. 
-""" -function default_inference_method(; max_threads::Integer = 10, ndraws::Integer = 2000, - mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(), - nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 2) - return EpiMethod( - pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)], - sampler = NUTSampler(adtype = AutoForwardDiff(), - ndraws = ndraws, - nchains = nchains, - mcmc_parallel = mcmc_ensemble) - ) -end diff --git a/pipeline/src/default_latent_model_priors.jl b/pipeline/src/default_latent_model_priors.jl deleted file mode 100644 index 30511ce91..000000000 --- a/pipeline/src/default_latent_model_priors.jl +++ /dev/null @@ -1,20 +0,0 @@ -""" -Constructs a dictionary of default prior distributions for the parameters used - in `EpiAware` models. - -# Returns -- `Dict{String, Distribution}`: A dictionary containing the default prior - distributions. - -""" -function default_latent_model_priors() - transformed_process_init_prior = Normal(0.0, 0.25) - std_prior = HalfNormal(0.25) - damp_param_prior = Beta(0.5, 0.5) - - return Dict( - "transformed_process_init_prior" => transformed_process_init_prior, - "std_prior" => std_prior, - "damp_param_prior" => damp_param_prior - ) -end diff --git a/pipeline/src/default_latent_models_names.jl b/pipeline/src/default_latent_models_names.jl deleted file mode 100644 index 0d2169691..000000000 --- a/pipeline/src/default_latent_models_names.jl +++ /dev/null @@ -1,12 +0,0 @@ -""" -Returns a dictionary mapping the default latent models to their corresponding - names. - -# Returns -- `Dict{Any, Any}`: A dictionary mapping the default latent models to their - corresponding names. 
-""" -function default_latent_models_names() - latent_models_dict = default_epiaware_models() - return Dict(value => key for (key, value) in latent_models_dict) -end diff --git a/pipeline/src/default_tspan.jl b/pipeline/src/default_tspan.jl deleted file mode 100644 index 1a6f93c8c..000000000 --- a/pipeline/src/default_tspan.jl +++ /dev/null @@ -1,15 +0,0 @@ -""" -Compute the default time span for the Rt calculation. - -# Arguments -- `backhorizon::Int`: The number of days to look back for the time span calculation. Default is 21. - -# Returns -- A tuple `(start, stop)` representing the default time span for the Rt calculation. - -""" -function default_tspan(; backhorizon = 21) - N = length(default_Rt()) - @assert backhorizon igps, "latent_model" => latent_models, - "gi_mean" => gi_means, "gi_std" => gi_stds) |> dict_list -end diff --git a/pipeline/src/make_truth_data_configs.jl b/pipeline/src/make_truth_data_configs.jl deleted file mode 100644 index 96c831fcb..000000000 --- a/pipeline/src/make_truth_data_configs.jl +++ /dev/null @@ -1,14 +0,0 @@ -""" -Create a dictionary of truth data configurations. - -# Arguments -- `gi_means`: The mean values for gi. -- `gi_stds`: The standard deviations for gi. - -# Returns -A dictionary containing the mean and standard deviation values for gi. - -""" -function make_truth_data_configs(; gi_means, gi_stds) - Dict("gi_mean" => gi_means, "gi_std" => gi_stds) |> dict_list -end diff --git a/pipeline/src/pipeline/do_inference.jl b/pipeline/src/pipeline/do_inference.jl new file mode 100644 index 000000000..57f47254c --- /dev/null +++ b/pipeline/src/pipeline/do_inference.jl @@ -0,0 +1,19 @@ +""" +Generate inference results using the specified truth data and pipeline. + +# Arguments +- `truthdata`: The truth data used for generating inference results. +- `pipeline`: An instance of the `AbstractEpiAwarePipeline` sub-type. + +# Returns +An array of inference results.
+ +""" +function do_inference(truthdata, pipeline::AbstractEpiAwarePipeline) + inference_configs = make_inference_configs(pipeline) + tspan = make_tspan(pipeline) + inference_method = make_inference_method(pipeline) + inference_results = map_inference_results( + truthdata, inference_configs, pipeline; tspan, inference_method) + return inference_results +end diff --git a/pipeline/src/pipeline/do_pipeline.jl b/pipeline/src/pipeline/do_pipeline.jl new file mode 100644 index 000000000..acc7f79ef --- /dev/null +++ b/pipeline/src/pipeline/do_pipeline.jl @@ -0,0 +1,14 @@ +""" +Create a pipeline by generating truth data and making inferences. + +# Arguments +- `pipeline::AbstractEpiAwarePipeline`: The pipeline object which sets pipeline behavior. + +""" +function do_pipeline(pipeline::AbstractEpiAwarePipeline) + truthdatas = do_truthdata(pipeline) + for truthdata in truthdatas + do_inference(truthdata, pipeline) + end + return nothing +end diff --git a/pipeline/src/pipeline/do_truthdata.jl b/pipeline/src/pipeline/do_truthdata.jl new file mode 100644 index 000000000..3781bc3f5 --- /dev/null +++ b/pipeline/src/pipeline/do_truthdata.jl @@ -0,0 +1,18 @@ +""" +Generate truth data for the EpiAwarePipeline. + +# Arguments +- `pipeline::AbstractEpiAwarePipeline`: The EpiAwarePipeline object. + +# Returns +An array of truth data generated from the given pipeline.
+ +""" +function do_truthdata(pipeline::AbstractEpiAwarePipeline) + truth_data_configs = make_truth_data_configs(pipeline) + truthdata_from_configs = map(truth_data_configs) do truth_data_config + return Dagger.@spawn cache=true generate_truthdata( + truth_data_config, pipeline; plot = false) + end + return truthdata_from_configs +end diff --git a/pipeline/src/pipeline/pipeline.jl b/pipeline/src/pipeline/pipeline.jl new file mode 100644 index 000000000..401d04df5 --- /dev/null +++ b/pipeline/src/pipeline/pipeline.jl @@ -0,0 +1,4 @@ +include("pipelinetypes.jl") +include("do_truthdata.jl") +include("do_inference.jl") +include("do_pipeline.jl") diff --git a/pipeline/src/pipeline/pipelinetypes.jl b/pipeline/src/pipeline/pipelinetypes.jl new file mode 100644 index 000000000..d5555ea17 --- /dev/null +++ b/pipeline/src/pipeline/pipelinetypes.jl @@ -0,0 +1,11 @@ +""" +The abstract root type for all pipeline types using `EpiAware`. +""" +abstract type AbstractEpiAwarePipeline end + +""" +The pipeline type for the Rt pipeline with renewal including specific options + for plotting and saving. +""" +struct RtwithoutRenewalPipeline <: AbstractEpiAwarePipeline +end diff --git a/pipeline/src/plot_functions.jl b/pipeline/src/plot_functions.jl index 6af3b347d..989923476 100644 --- a/pipeline/src/plot_functions.jl +++ b/pipeline/src/plot_functions.jl @@ -1,14 +1,17 @@ """ -Plot the true cases and latent infections. +Plot the true cases and latent infections. This is the default method for plotting. # Arguments - `data`: A dictionary containing the data for plotting. - `config`: The configuration for the truth data scenario. +- `pipeline::AbstractEpiAwarePipeline`: The pipeline object which sets pipeline + behavior. # Returns - `plt_cases`: The plot object representing the cases and latent infections. 
""" -function plot_truth_data(data, config; plotsname = "truth_data") +function plot_truth_data( + data, config, pipeline::AbstractEpiAwarePipeline; plotsname = "truth_data") plt_cases = scatter( data["y_t"], label = "Cases", xlabel = "Time", ylabel = "Daily cases", title = "Cases and latent infections", legend = :bottomright) @@ -27,12 +30,14 @@ Plot and save the plot of the true Rt values over time. # Arguments - `true_Rt`: An array of true Rt values. +- `pipeline::AbstractEpiAwarePipeline`: The pipeline object which sets pipeline + behavior. # Returns - `plt_Rt`: The plot object. """ -function plot_Rt(true_Rt) +function plot_Rt(true_Rt, pipeline::AbstractEpiAwarePipeline) plt_Rt = plot(true_Rt, label = "True Rt", xlabel = "Time", ylabel = "Rt", title = "True Rt", legend = :topright) diff --git a/pipeline/src/TruthSimulationConfig.jl b/pipeline/src/simulate/TruthSimulationConfig.jl similarity index 100% rename from pipeline/src/TruthSimulationConfig.jl rename to pipeline/src/simulate/TruthSimulationConfig.jl diff --git a/pipeline/src/generate_truthdata.jl b/pipeline/src/simulate/generate_truthdata.jl similarity index 53% rename from pipeline/src/generate_truthdata.jl rename to pipeline/src/simulate/generate_truthdata.jl index 768042819..febfea262 100644 --- a/pipeline/src/generate_truthdata.jl +++ b/pipeline/src/simulate/generate_truthdata.jl @@ -1,26 +1,35 @@ """ -Generate truth data from a configuration file. It does this by converting the configuration dictionary into a `TruthSimulationConfig` object and then calling the `simulate` function to generate the truth data. +Generate truth data from a configuration file. It does this by converting the +configuration dictionary into a `TruthSimulationConfig` object and then calling +the `simulate` function to generate the truth data. This is the default method. # Arguments -- `truth_data_config`: A dictionary containing the configuration parameters for generating truth data. 
-- `plot`: A boolean indicating whether to plot the generated truth data. Default is `true`. -- `datadir_str`: A string specifying the directory to save the truth data. Default is `"truth_data"`. -- `prefix`: A string specifying the prefix for the truth data file name. Default is `"truth_data"`. +- `truth_data_config`: A dictionary containing the configuration parameters for + generating truth data. +- `pipeline::AbstractEpiAwarePipeline`: The pipeline object which sets pipeline + behavior. +- `plot`: A boolean indicating whether to plot the generated truth data. Default + is `true`. +- `datadir_str`: A string specifying the directory to save the truth data. + Default is `"truth_data"`. +- `prefix`: A string specifying the prefix for the truth data file name. + Default is `"truth_data"`. # Returns - `truthdata`: The generated truth data. - `truthfile`: The file path where the truth data is saved. """ -function generate_truthdata_from_config( - truth_data_config; plot = true, datadir_str = "truth_data", prefix = "truth_data") - true_Rt = default_Rt() +function generate_truthdata( + truth_data_config, pipeline::AbstractEpiAwarePipeline; plot = true, + datadir_str = "truth_data", prefix = "truth_data") + true_Rt = make_Rt(pipeline) config = TruthSimulationConfig( truth_process = true_Rt, gi_mean = truth_data_config["gi_mean"], gi_std = truth_data_config["gi_std"]) truthdata, truthfile = produce_or_load( simulate, config, datadir(datadir_str); prefix = prefix) if plot - plot_truth_data(truthdata, config) + plot_truth_data(truthdata, config, pipeline) end return truthdata end diff --git a/pipeline/src/simulate/simulate.jl b/pipeline/src/simulate/simulate.jl new file mode 100644 index 000000000..b3e843f70 --- /dev/null +++ b/pipeline/src/simulate/simulate.jl @@ -0,0 +1,2 @@ +include("TruthSimulationConfig.jl") +include("generate_truthdata.jl") diff --git a/pipeline/test/constructors/test_constructors.jl b/pipeline/test/constructors/test_constructors.jl new file mode 
100644 index 000000000..c9fae07cc --- /dev/null +++ b/pipeline/test/constructors/test_constructors.jl @@ -0,0 +1,97 @@ +@testset "make_gi_params: returns a dictionary with correct keys" begin + using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() + params = make_gi_params(pipeline) + + @test params isa Dict + @test haskey(params, "gi_means") + @test haskey(params, "gi_stds") +end + +@testset "make_inf_generating_processes" begin + using .AnalysisPipeline, EpiAware + pipeline = RtwithoutRenewalPipeline() + igps = make_inf_generating_processes(pipeline) + @test igps == [DirectInfections, ExpGrowthRate, Renewal] +end + +@testset "make_Rt: returns an array" begin + using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() + + Rt = make_Rt(pipeline) + @test Rt isa Array +end + +@testset "default_tspan: returns an Tuple{Integer, Integer}" begin + using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() + + tspan = make_tspan(pipeline) + @test tspan isa Tuple{Integer, Integer} +end + +@testset "make_latent_model_priors: generates a dict with correct keys and distributions" begin + using .AnalysisPipeline, Distributions + pipeline = RtwithoutRenewalPipeline() + + priors_dict = make_latent_model_priors(pipeline) + + # Check if the priors dictionary is constructed correctly + @test haskey(priors_dict, "transformed_process_init_prior") + @test haskey(priors_dict, "std_prior") + @test haskey(priors_dict, "damp_param_prior") + + # Check if the values are all distributions + @test valtype(priors_dict) <: Distribution +end + +@testset "make_epiaware_name_model_pairs: generates a vector of Pairs with correct keys and latent models" begin + using .AnalysisPipeline, EpiAware + pipeline = RtwithoutRenewalPipeline() + + namemodel_vect = make_epiaware_name_model_pairs(pipeline) + + @test first.(namemodel_vect) == ["wkly_ar", "wkly_rw", "wkly_diff_ar"] + @test all([model isa BroadcastLatentModel for model in last.(namemodel_vect)]) +end + +@testset 
"make_inference_method: constructor and defaults" begin + using .AnalysisPipeline, EpiAware, ADTypes, AbstractMCMC + pipeline = RtwithoutRenewalPipeline() + + method = make_inference_method(pipeline) + + @test length(method.pre_sampler_steps) == 1 + @test method.pre_sampler_steps[1] isa ManyPathfinder + @test method.pre_sampler_steps[1].nruns == 4 + @test method.pre_sampler_steps[1].maxiters == 100 + @test method.sampler isa NUTSampler + @test method.sampler.adtype == AutoForwardDiff() + @test method.sampler.ndraws == 2000 + @test method.sampler.nchains == 4 + @test method.sampler.mcmc_parallel == MCMCSerial() +end + +@testset "make_truth_data_configs" begin + using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() + @testset "make_truth_data_configs should return a dictionary" begin + config_dicts = make_truth_data_configs(pipeline) + @test eltype(config_dicts) <: Dict + end + + @testset "make_truth_data_configs should contain gi_mean and gi_std keys" begin + config_dicts = make_truth_data_configs(pipeline) + @test all(config_dicts .|> config -> haskey(config, "gi_mean")) + @test all(config_dicts .|> config -> haskey(config, "gi_std")) + end +end + +@testset "default inference configurations" begin + using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() + + inference_configs = make_inference_configs(pipeline) + @test eltype(inference_configs) <: Dict +end diff --git a/pipeline/test/default_returning_functions.jl b/pipeline/test/default_returning_functions.jl deleted file mode 100644 index b9f41784f..000000000 --- a/pipeline/test/default_returning_functions.jl +++ /dev/null @@ -1,70 +0,0 @@ -@testset "default_gi_params: returns a dictionary with correct keys" begin - using .AnalysisPipeline - - params = default_gi_params() - @test params isa Dict - @test haskey(params, "gi_means") - @test haskey(params, "gi_stds") -end - -@testset "default_Rt: returns an array" begin - using .AnalysisPipeline - - Rt = default_Rt() - @test Rt isa Array -end - 
-@testset "default_tspan: returns an Tuple{Integer, Integer}" begin - using .AnalysisPipeline - - tspan = default_tspan() - @test tspan isa Tuple{Integer, Integer} -end - -@testset "default_priors: generates a dict with correct keys and distributions" begin - using .AnalysisPipeline, Distributions - # Call the default_priors function - priors_dict = default_latent_model_priors() - - # Check if the priors dictionary is constructed correctly - @test haskey(priors_dict, "transformed_process_init_prior") - @test haskey(priors_dict, "std_prior") - @test haskey(priors_dict, "damp_param_prior") - - # Check if the values are all distributions - @test valtype(priors_dict) <: Distribution -end - -@testset "default_epiaware_models: generates a dict with correct keys and latent models" begin - using .AnalysisPipeline, EpiAware - - models_dict = default_epiaware_models() - - @test haskey(models_dict, "wkly_ar") - @test haskey(models_dict, "wkly_rw") - @test haskey(models_dict, "wkly_diff_ar") - @test valtype(models_dict) <: BroadcastLatentModel -end - -@testset "default_inference_method: constructor and defaults" begin - using .AnalysisPipeline, EpiAware, ADTypes, AbstractMCMC - - method = default_inference_method() - @test length(method.pre_sampler_steps) == 1 - @test method.pre_sampler_steps[1] isa ManyPathfinder - @test method.pre_sampler_steps[1].nruns == 4 - @test method.pre_sampler_steps[1].maxiters == 100 - @test method.sampler isa NUTSampler - @test method.sampler.adtype == AutoForwardDiff() - @test method.sampler.ndraws == 2000 - @test method.sampler.nchains == 2 - @test method.sampler.mcmc_parallel == MCMCSerial() -end - -@testset "Test default_latent_models_names" begin - using .AnalysisPipeline - - modelnames = default_latent_models_names() - @test length(modelnames) == 3 - @test modelnames isa Dict -end diff --git a/pipeline/test/test_InferenceConfig.jl b/pipeline/test/infer/test_InferenceConfig.jl similarity index 100% rename from 
pipeline/test/test_InferenceConfig.jl rename to pipeline/test/infer/test_InferenceConfig.jl diff --git a/pipeline/test/pipeline/test_pipelinetypes.jl b/pipeline/test/pipeline/test_pipelinetypes.jl new file mode 100644 index 000000000..d4ebaff1e --- /dev/null +++ b/pipeline/test/pipeline/test_pipelinetypes.jl @@ -0,0 +1,10 @@ +@testset "EpiAwarePipeline Tests" begin + using .AnalysisPipeline + @testset "AbstractEpiAwarePipeline" begin + @test_throws MethodError AbstractEpiAwarePipeline() + end + @testset "RtwithoutRenewalPipeline" begin + @test isa(RtwithoutRenewalPipeline(), RtwithoutRenewalPipeline) + @test RtwithoutRenewalPipeline <: AbstractEpiAwarePipeline + end +end diff --git a/pipeline/test/runtests.jl b/pipeline/test/runtests.jl index 77a19dbd6..b70dea1ff 100644 --- a/pipeline/test/runtests.jl +++ b/pipeline/test/runtests.jl @@ -5,8 +5,8 @@ quickactivate(@__DIR__(), "Analysis pipeline") include(srcdir("AnalysisPipeline.jl")); #run tests -include("default_returning_functions.jl"); -include("test_make_configs.jl"); -include("test_SimulationConfig.jl"); -include("test_TruthSimulationConfig.jl"); -include("test_InferenceConfig.jl"); +include("pipeline/test_pipelinetypes.jl"); +include("constructors/test_constructors.jl"); +include("simulate/test_TruthSimulationConfig.jl"); +include("simulate/test_SimulationConfig.jl"); +include("infer/test_InferenceConfig.jl"); diff --git a/pipeline/test/test_SimulationConfig.jl b/pipeline/test/simulate/test_SimulationConfig.jl similarity index 100% rename from pipeline/test/test_SimulationConfig.jl rename to pipeline/test/simulate/test_SimulationConfig.jl diff --git a/pipeline/test/test_TruthSimulationConfig.jl b/pipeline/test/simulate/test_TruthSimulationConfig.jl similarity index 85% rename from pipeline/test/test_TruthSimulationConfig.jl rename to pipeline/test/simulate/test_TruthSimulationConfig.jl index 1e87afcc0..d82a13ff2 100644 --- a/pipeline/test/test_TruthSimulationConfig.jl +++ 
b/pipeline/test/simulate/test_TruthSimulationConfig.jl @@ -18,9 +18,11 @@ @test all(growth_up) end -@testset "generate_truthdata_from_config" begin +@testset "generate_truthdata" begin + using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() truth_data_config = Dict("gi_mean" => 0.5, "gi_std" => 0.1) - truthdata = generate_truthdata_from_config(truth_data_config) + truthdata = generate_truthdata(truth_data_config, pipeline) @test haskey(truthdata, "I_t") @test haskey(truthdata, "y_t") diff --git a/pipeline/test/test_full_inference.jl b/pipeline/test/test_full_inference.jl index 61bbd6d62..7e151a559 100644 --- a/pipeline/test/test_full_inference.jl +++ b/pipeline/test/test_full_inference.jl @@ -5,26 +5,16 @@ using Test include(srcdir("AnalysisPipeline.jl")) using .AnalysisPipeline + pipeline = RtwithoutRenewalPipeline() - default_gi_param_dict = default_gi_params() - true_Rt = default_Rt() - latent_models_dict = default_epiaware_models() - latent_models_names = Dict(value => key for (key, value) in latent_models_dict) tspan = (1, 28) - inference_method = default_inference_method() - - truth_data_config = make_truth_data_configs( - gi_means = default_gi_param_dict["gi_means"], - gi_stds = default_gi_param_dict["gi_stds"])[1] - inference_configs = make_inference_configs( - latent_models = collect(values(latent_models_dict)), - gi_means = default_gi_param_dict["gi_means"], - gi_stds = default_gi_param_dict["gi_stds"]) + inference_method = make_inference_method(pipeline) + truth_data_config = make_truth_data_configs(pipeline)[1] + inference_configs = make_inference_configs(pipeline) inference_config = rand(inference_configs) - truthdata = Dict("y_t" => fill(100, 28)) + truthdata = Dict("y_t" => fill(100, 28), "truth_gi_mean" => 1.5) - inference_results, inferencefile = generate_inference_results( - truthdata, inference_config; tspan, inference_method, - truth_data_config, latent_models_names) + inference_results = generate_inference_results( + truthdata, 
inference_config, pipeline; tspan, inference_method) @test inference_results["inference_results"] isa EpiAwareObservables end diff --git a/pipeline/test/test_make_configs.jl b/pipeline/test/test_make_configs.jl deleted file mode 100644 index c87feb8af..000000000 --- a/pipeline/test/test_make_configs.jl +++ /dev/null @@ -1,24 +0,0 @@ -@testset "make_truth_data_configs: I/O" begin - using .AnalysisPipeline, DrWatson - - gi_means = [1.0, 2.0, 3.0] - gi_stds = [0.5, 0.8, 1.2] - expected_output = Dict("gi_mean" => gi_means, "gi_std" => gi_stds) - @test make_truth_data_configs(gi_means = gi_means, gi_stds = gi_stds) == - dict_list(expected_output) -end - -@testset "make_inference_configs: I/O" begin - using .AnalysisPipeline, DrWatson, EpiAware - struct TestLatentModel <: AbstractLatentModel end - - latent_models = [TestLatentModel()] - gi_means = [2.0, 3.0, 4.0] - gi_stds = [1.0, 2.0, 3.0] - igps = [DirectInfections, Renewal] - expected_result = Dict( - "igp" => [DirectInfections, Renewal], "latent_model" => latent_models, - "gi_mean" => gi_means, "gi_std" => gi_stds) |> dict_list - @test make_inference_configs(latent_models = latent_models, gi_means = gi_means, - gi_stds = gi_stds, igps = igps) == expected_result -end