From b8d6f1f441aa4cc96e557c0c9748f46473790e6b Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Wed, 18 Dec 2024 23:55:44 +0000 Subject: [PATCH] beef up sampling to test Also removed unnecessary call to `fetch` --- .../constructors/make_observation_model.jl | 22 ++++++++++++++++++- .../test/pipeline/test_pipelinefunctions.jl | 8 +++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pipeline/src/constructors/make_observation_model.jl b/pipeline/src/constructors/make_observation_model.jl index 1fb1255aa..9ad4d6020 100644 --- a/pipeline/src/constructors/make_observation_model.jl +++ b/pipeline/src/constructors/make_observation_model.jl @@ -1,5 +1,5 @@ """ -Constructs an observation model for the given pipeline. This is the defualt method. +Constructs an observation model for the given pipeline. This is the default method. # Arguments - `pipeline::AbstractEpiAwarePipeline`: The pipeline for which the observation model is constructed. @@ -18,3 +18,23 @@ function make_observation_model(pipeline::AbstractEpiAwarePipeline) obs = LatentDelay(dayofweek_logit_ascert, delay_distribution) return obs end + +const negC = -1e15 +""" +Soft minimum function for a smooth transition from `x -> x` to a maximum value of 1e15. +""" +_softmin(x) = -logaddexp(negC, -x) + +function make_observation_model(pipeline::AbstractRtwithoutRenewalPipeline) + default_params = make_default_params(pipeline) + #Model for ascertainment based on day of the week + dayofweek_logit_ascert = ascertainment_dayofweek( + NegativeBinomialError(cluster_factor_prior = HalfNormal(default_params["cluster_factor"])); + transform = (x, y) -> _softmin.(x .* y)) + + #Default continuous-time model for latent delay in observations + delay_distribution = make_delay_distribution(pipeline) + #Model for latent delay in observations + obs = LatentDelay(dayofweek_logit_ascert, delay_distribution) + return obs +end diff --git a/pipeline/test/pipeline/test_pipelinefunctions.jl b/pipeline/test/pipeline/test_pipelinefunctions.jl index af96cb2cb..6a27c88b3 100644 --- a/pipeline/test/pipeline/test_pipelinefunctions.jl +++ b/pipeline/test/pipeline/test_pipelinefunctions.jl @@ -13,16 +13,14 @@ end @testset "do_inference tests" begin function make_inference(pipeline) - truthdata_dg_task = do_truthdata(pipeline) - truthdata = fetch.(truthdata_dg_task) + truthdata = do_truthdata(pipeline) do_inference(truthdata[1], pipeline) end for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline, SmoothEndemicPipeline, RoughEndemicPipeline] - pipeline = pipetype(; ndraws = 20, nchains = 1, testmode = true) - inference_results_tsk = make_inference(pipeline) - inference_results = fetch.(inference_results_tsk) + pipeline = pipetype(; ndraws = 1000, nchains = 1, testmode = true) + inference_results = make_inference(pipeline) @test length(inference_results) == 1 @test all([result["inference_results"] isa EpiAwareObservables for result in inference_results])