From 576e0ee08eef0c41184e5d02a20c56e883b47aa4 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Tue, 14 May 2024 16:31:25 +0100 Subject: [PATCH] splits simulate_or_infer into simulate and infer plus updated unit tests (#222) --- pipeline/src/AnalysisPipeline.jl | 2 +- pipeline/src/InferenceConfig.jl | 2 +- pipeline/src/TruthSimulationConfig.jl | 4 +--- pipeline/src/generate_inference_results.jl | 2 +- pipeline/src/generate_truthdata.jl | 4 ++-- pipeline/test/test_TruthSimulationConfig.jl | 6 +++--- 6 files changed, 9 insertions(+), 11 deletions(-) diff --git a/pipeline/src/AnalysisPipeline.jl b/pipeline/src/AnalysisPipeline.jl index e6f192a42..aa06c22b8 100644 --- a/pipeline/src/AnalysisPipeline.jl +++ b/pipeline/src/AnalysisPipeline.jl @@ -11,7 +11,7 @@ using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, Dr export TruthSimulationConfig, InferenceConfig # Exported functions -export simulate_or_infer, default_gi_params, default_Rt, default_tspan, +export simulate, infer, default_gi_params, default_Rt, default_tspan, default_latent_model_priors, default_epiaware_models, default_inference_method, default_latent_models_names, make_truth_data_configs, make_inference_configs, generate_truthdata_from_config, generate_inference_results, plot_truth_data, plot_Rt diff --git a/pipeline/src/InferenceConfig.jl b/pipeline/src/InferenceConfig.jl index db818db4d..b52109b51 100644 --- a/pipeline/src/InferenceConfig.jl +++ b/pipeline/src/InferenceConfig.jl @@ -64,7 +64,7 @@ to make inference on and model configuration. - `inference_results`: The results of the simulation or inference. """ -function simulate_or_infer(config::InferenceConfig) +function infer(config::InferenceConfig) #Define infection-generating model shape = (config.gi_mean / config.gi_std)^2 scale = config.gi_std^2 / config.gi_mean diff --git a/pipeline/src/TruthSimulationConfig.jl b/pipeline/src/TruthSimulationConfig.jl index ee78bf92d..4289764b8 100644 --- a/pipeline/src/TruthSimulationConfig.jl +++ b/pipeline/src/TruthSimulationConfig.jl @@ -29,8 +29,6 @@ mean and standard deviation. end """ - simulate_or_infer(config::TruthSimulationConfig) - Simulates or infers the truth process and observations based on the given configuration. # Arguments @@ -40,7 +38,7 @@ Simulates or infers the truth process and observations based on the given config A dictionary containing the sampled infections and observations, along with other relevant information. """ -function simulate_or_infer(config::TruthSimulationConfig) +function simulate(config::TruthSimulationConfig) #Define infection-generating model shape = (config.gi_mean / config.gi_std)^2 scale = config.gi_std^2 / config.gi_mean diff --git a/pipeline/src/generate_inference_results.jl b/pipeline/src/generate_inference_results.jl index 0d6dd1fc1..2de5697cf 100644 --- a/pipeline/src/generate_inference_results.jl +++ b/pipeline/src/generate_inference_results.jl @@ -29,6 +29,6 @@ function generate_inference_results(truthdata, inference_config; tspan, inferenc string(truth_data_config["gi_mean"]) inference_results, inferencefile = produce_or_load( - simulate_or_infer, config, datadir(datadir_name); prefix = prfx) + infer, config, datadir(datadir_name); prefix = prfx) return inference_results, inferencefile end diff --git a/pipeline/src/generate_truthdata.jl b/pipeline/src/generate_truthdata.jl index dedda0041..768042819 100644 --- a/pipeline/src/generate_truthdata.jl +++ b/pipeline/src/generate_truthdata.jl @@ -1,5 +1,5 @@ """ -Generate truth data from a configuration file. It does this by converting the configuration dictionary into a `TruthSimulationConfig` object and then calling the `simulate_or_infer` function to generate the truth data. +Generate truth data from a configuration file. It does this by converting the configuration dictionary into a `TruthSimulationConfig` object and then calling the `simulate` function to generate the truth data. # Arguments - `truth_data_config`: A dictionary containing the configuration parameters for generating truth data. @@ -18,7 +18,7 @@ function generate_truthdata_from_config( truth_process = true_Rt, gi_mean = truth_data_config["gi_mean"], gi_std = truth_data_config["gi_std"]) truthdata, truthfile = produce_or_load( - simulate_or_infer, config, datadir(datadir_str); prefix = prefix) + simulate, config, datadir(datadir_str); prefix = prefix) if plot plot_truth_data(truthdata, config) end diff --git a/pipeline/test/test_TruthSimulationConfig.jl b/pipeline/test/test_TruthSimulationConfig.jl index 9da52eb21..1e87afcc0 100644 --- a/pipeline/test/test_TruthSimulationConfig.jl +++ b/pipeline/test/test_TruthSimulationConfig.jl @@ -1,10 +1,10 @@ -@testset "simulate_or_infer: simulate runs" begin +@testset "simulate runs" begin using Distributions, .AnalysisPipeline, EpiAware # Define a mock TruthSimulationConfig object for testing config = TruthSimulationConfig( truth_process = fill(1.5, 10), gi_mean = 2.0, gi_std = 2.0) - # Test the simulate_or_infer function - result = simulate_or_infer(config) + # Test the simulate function + result = simulate(config) @test haskey(result, "I_t") @test haskey(result, "y_t")