diff --git a/pipeline/src/constructors/constructors.jl b/pipeline/src/constructors/constructors.jl index 68e02819d..630502c4f 100644 --- a/pipeline/src/constructors/constructors.jl +++ b/pipeline/src/constructors/constructors.jl @@ -1,3 +1,4 @@ +include("selector.jl") include("make_gi_params.jl") include("make_inf_generating_processes.jl") include("make_model_priors.jl") diff --git a/pipeline/src/constructors/make_inference_configs.jl b/pipeline/src/constructors/make_inference_configs.jl index 53f38be85..ab0d3815e 100644 --- a/pipeline/src/constructors/make_inference_configs.jl +++ b/pipeline/src/constructors/make_inference_configs.jl @@ -20,5 +20,6 @@ function make_inference_configs(pipeline::AbstractEpiAwarePipeline) "gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"]) |> dict_list - return inference_configs + selected_inference_configs = _selector(inference_configs, pipeline) + return selected_inference_configs end diff --git a/pipeline/src/constructors/make_truth_data_configs.jl b/pipeline/src/constructors/make_truth_data_configs.jl index f122c1179..59db1f000 100644 --- a/pipeline/src/constructors/make_truth_data_configs.jl +++ b/pipeline/src/constructors/make_truth_data_configs.jl @@ -1,6 +1,5 @@ """ -Create a dictionary of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`. - This is the default method. +Create a vector of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`. # Returns A vector of dictionaries containing the mean and standard deviation values for @@ -9,7 +8,9 @@ A vector of dictionaries containing the mean and standard deviation values for """ function make_truth_data_configs(pipeline::AbstractEpiAwarePipeline) gi_param_dict = make_gi_params(pipeline) - return Dict( + gi_param_dict_list = Dict( "gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |> - dict_list + dict_list + selected_truth_data_configs = _selector(gi_param_dict_list, pipeline) + return selected_truth_data_configs end diff --git a/pipeline/src/constructors/selector.jl b/pipeline/src/constructors/selector.jl new file mode 100644 index 000000000..65cf4d7b8 --- /dev/null +++ b/pipeline/src/constructors/selector.jl @@ -0,0 +1,15 @@ +""" +Internal method for selecting from a list of items based on the pipeline type. +Default is to return the list as is. +""" +function _selector(list, pipeline::AbstractEpiAwarePipeline) + return list +end + +""" +Internal method for selecting from a list of items based on the pipeline type. +Example/test mode is to return a randomly selected item from the list. +""" +function _selector(list, pipeline::EpiAwareExamplePipeline) + return [rand(list)] +end diff --git a/pipeline/test/constructors/test_constructors.jl b/pipeline/test/constructors/test_constructors.jl index eaf7ab948..5f2e07aba 100644 --- a/pipeline/test/constructors/test_constructors.jl +++ b/pipeline/test/constructors/test_constructors.jl @@ -85,7 +85,8 @@ end @testset "make_truth_data_configs" begin using EpiAwarePipeline - pipeline = EpiAwareExamplePipeline() + pipeline = RtwithoutRenewalPipeline() + example_pipeline = EpiAwareExamplePipeline() @testset "make_truth_data_configs should return a dictionary" begin config_dicts = make_truth_data_configs(pipeline) @test eltype(config_dicts) <: Dict @@ -96,14 +97,38 @@ end @test all(config_dicts .|> config -> haskey(config, "gi_mean")) @test all(config_dicts .|> config -> haskey(config, "gi_std")) end + + @testset "make_truth_data_configs should return a vector of length 1 for EpiAwareExamplePipeline" begin + config_dicts = make_truth_data_configs(example_pipeline) + @test length(config_dicts) == 1 + end end @testset "default inference configurations" begin using EpiAwarePipeline - pipeline = EpiAwareExamplePipeline() - inference_configs = make_inference_configs(pipeline) - @test eltype(inference_configs) <: Dict + pipeline = RtwithoutRenewalPipeline() + example_pipeline = EpiAwareExamplePipeline() + + @testset "make_inference_configs should return a vector of dictionaries" begin + inference_configs = make_inference_configs(pipeline) + @test eltype(inference_configs) <: Dict + end + + @testset "make_inference_configs should contain igp, latent_namemodels, observation_model, gi_mean, gi_std, and log_I0_prior keys" begin + inference_configs = make_inference_configs(pipeline) + @test inference_configs .|> (config -> haskey(config, "igp")) |> all + @test inference_configs .|> (config -> haskey(config, "latent_namemodels")) |> all + @test inference_configs .|> (config -> haskey(config, "observation_model")) |> all + @test inference_configs .|> (config -> haskey(config, "gi_mean")) |> all + @test inference_configs .|> (config -> haskey(config, "gi_std")) |> all + @test inference_configs .|> (config -> haskey(config, "log_I0_prior")) |> all + end + + @testset "make_inference_configs should return a vector of length 1 for EpiAwareExamplePipeline" begin + inference_configs = make_inference_configs(example_pipeline) + @test length(inference_configs) == 1 + end end @testset "make_default_params" begin diff --git a/pipeline/test/pipeline/test_pipelinefunctions.jl b/pipeline/test/pipeline/test_pipelinefunctions.jl new file mode 100644 index 000000000..45e8a6123 --- /dev/null +++ b/pipeline/test/pipeline/test_pipelinefunctions.jl @@ -0,0 +1,32 @@ +@testset "do_truthdata tests" begin + using EpiAwarePipeline, Dagger + pipeline = EpiAwareExamplePipeline() + truthdata_dg_task = do_truthdata(pipeline) + truthdata = fetch.(truthdata_dg_task) + + @test length(truthdata) == 1 + @test all([data["y_t"] isa Vector{Union{Missing, Real}} for data in truthdata]) +end + +@testset "do_inference tests" begin + using EpiAwarePipeline + pipeline = EpiAwareExamplePipeline() + + function make_inference() + truthdata = do_truthdata(pipeline) + do_inference(truthdata[1:1], pipeline) + end + + inference_results_tsk = make_inference() + inference_results = fetch.(inference_results_tsk) + @test length(inference_results) == 1 + @test all([result["inference_results"] isa EpiAwareObservables + for result in inference_results]) +end + +@testset "do_pipeline test: just run" begin + using EpiAwarePipeline + pipeline = EpiAwareExamplePipeline() + res = do_pipeline(pipeline) + @test isnothing(res) +end diff --git a/pipeline/test/runtests.jl b/pipeline/test/runtests.jl index e349f6441..ce97f3d9d 100644 --- a/pipeline/test/runtests.jl +++ b/pipeline/test/runtests.jl @@ -3,6 +3,7 @@ quickactivate(@__DIR__(), "EpiAwarePipeline") # Run tests include("pipeline/test_pipelinetypes.jl"); +include("pipeline/test_pipelinefunctions.jl"); include("constructors/test_constructors.jl"); include("simulate/test_TruthSimulationConfig.jl"); include("simulate/test_SimulationConfig.jl");