Skip to content

Commit

Permalink
Example mode and unit tests for pipeline functions. (#295)
Browse files Browse the repository at this point in the history
* unit tests for pipeline functions.

* reformat

* Move dispatch on example mode to the config list constructors with additional unit testing

* reformat

* Add _selector methods to dispatch different pipeline number of scenarios
  • Loading branch information
SamuelBrand1 authored Jun 19, 2024
1 parent 2b6ef61 commit 18f8aec
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 9 deletions.
1 change: 1 addition & 0 deletions pipeline/src/constructors/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include("selector.jl")
include("make_gi_params.jl")
include("make_inf_generating_processes.jl")
include("make_model_priors.jl")
Expand Down
3 changes: 2 additions & 1 deletion pipeline/src/constructors/make_inference_configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ function make_inference_configs(pipeline::AbstractEpiAwarePipeline)
"gi_std" => gi_param_dict["gi_stds"], "log_I0_prior" => priors["log_I0_prior"]) |>
dict_list

return inference_configs
selected_inference_configs = _selector(inference_configs, pipeline)
return selected_inference_configs
end
9 changes: 5 additions & 4 deletions pipeline/src/constructors/make_truth_data_configs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
Create a dictionary of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`.
This is the default method.
Create a vector of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`.
# Returns
A vector of dictionaries containing the mean and standard deviation values for
Expand All @@ -9,7 +8,9 @@ A vector of dictionaries containing the mean and standard deviation values for
"""
function make_truth_data_configs(pipeline::AbstractEpiAwarePipeline)
gi_param_dict = make_gi_params(pipeline)
return Dict(
gi_param_dict_list = Dict(
"gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |>
dict_list
dict_list
selected_truth_data_configs = _selector(gi_param_dict_list, pipeline)
return selected_truth_data_configs
end
15 changes: 15 additions & 0 deletions pipeline/src/constructors/selector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Internal method for selecting from a list of items based on the pipeline type.
Default is to return the list as is.
"""
function _selector(list, pipeline::AbstractEpiAwarePipeline)
return list
end

"""
Internal method for selecting from a list of items based on the pipeline type.
Example/test mode is to return a randomly selected item from the list.
"""
function _selector(list, pipeline::EpiAwareExamplePipeline)
return [rand(list)]
end
33 changes: 29 additions & 4 deletions pipeline/test/constructors/test_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ end

@testset "make_truth_data_configs" begin
using EpiAwarePipeline
pipeline = EpiAwareExamplePipeline()
pipeline = RtwithoutRenewalPipeline()
example_pipeline = EpiAwareExamplePipeline()
@testset "make_truth_data_configs should return a dictionary" begin
config_dicts = make_truth_data_configs(pipeline)
@test eltype(config_dicts) <: Dict
Expand All @@ -96,14 +97,38 @@ end
@test all(config_dicts .|> config -> haskey(config, "gi_mean"))
@test all(config_dicts .|> config -> haskey(config, "gi_std"))
end

@testset "make_truth_data_configs should return a vector of length 1 for EpiAwareExamplePipeline" begin
config_dicts = make_truth_data_configs(example_pipeline)
@test length(config_dicts) == 1
end
end

@testset "default inference configurations" begin
using EpiAwarePipeline
pipeline = EpiAwareExamplePipeline()

inference_configs = make_inference_configs(pipeline)
@test eltype(inference_configs) <: Dict
pipeline = RtwithoutRenewalPipeline()
example_pipeline = EpiAwareExamplePipeline()

@testset "make_inference_configs should return a vector of dictionaries" begin
inference_configs = make_inference_configs(pipeline)
@test eltype(inference_configs) <: Dict
end

@testset "make_inference_configs should contain igp, latent_namemodels, observation_model, gi_mean, gi_std, and log_I0_prior keys" begin
inference_configs = make_inference_configs(pipeline)
@test inference_configs .|> (config -> haskey(config, "igp")) |> all
@test inference_configs .|> (config -> haskey(config, "latent_namemodels")) |> all
@test inference_configs .|> (config -> haskey(config, "observation_model")) |> all
@test inference_configs .|> (config -> haskey(config, "gi_mean")) |> all
@test inference_configs .|> (config -> haskey(config, "gi_std")) |> all
@test inference_configs .|> (config -> haskey(config, "log_I0_prior")) |> all
end

@testset "make_inference_configs should return a vector of length 1 for EpiAwareExamplePipeline" begin
inference_configs = make_inference_configs(example_pipeline)
@test length(inference_configs) == 1
end
end

@testset "make_default_params" begin
Expand Down
32 changes: 32 additions & 0 deletions pipeline/test/pipeline/test_pipelinefunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
@testset "do_truthdata tests" begin
using EpiAwarePipeline, Dagger
pipeline = EpiAwareExamplePipeline()
truthdata_dg_task = do_truthdata(pipeline)
truthdata = fetch.(truthdata_dg_task)

@test length(truthdata) == 1
@test all([data["y_t"] isa Vector{Union{Missing, Real}} for data in truthdata])
end

@testset "do_inference tests" begin
using EpiAwarePipeline
pipeline = EpiAwareExamplePipeline()

function make_inference()
truthdata = do_truthdata(pipeline)
do_inference(truthdata[1:1], pipeline)
end

inference_results_tsk = make_inference()
inference_results = fetch.(inference_results_tsk)
@test length(inference_results) == 1
@test all([result["inference_results"] isa EpiAwareObservables
for result in inference_results])
end

@testset "do_pipeline test: just run" begin
using EpiAwarePipeline
pipeline = EpiAwareExamplePipeline()
res = do_pipeline(pipeline)
@test isnothing(res)
end
1 change: 1 addition & 0 deletions pipeline/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ quickactivate(@__DIR__(), "EpiAwarePipeline")

# Run tests
include("pipeline/test_pipelinetypes.jl");
include("pipeline/test_pipelinefunctions.jl");
include("constructors/test_constructors.jl");
include("simulate/test_TruthSimulationConfig.jl");
include("simulate/test_SimulationConfig.jl");
Expand Down

0 comments on commit 18f8aec

Please sign in to comment.