Skip to content

Commit

Permalink
beef up sampling to test
Browse files Browse the repository at this point in the history
Also removed unnecessary call to `fetch`
  • Loading branch information
SamuelBrand1 committed Dec 18, 2024
1 parent 3d624dc commit b8d6f1f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
22 changes: 21 additions & 1 deletion pipeline/src/constructors/make_observation_model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Constructs an observation model for the given pipeline. This is the defualt method.
Constructs an observation model for the given pipeline. This is the default method.
# Arguments
- `pipeline::AbstractEpiAwarePipeline`: The pipeline for which the observation model is constructed.
Expand All @@ -18,3 +18,23 @@ function make_observation_model(pipeline::AbstractEpiAwarePipeline)
obs = LatentDelay(dayofweek_logit_ascert, delay_distribution)
return obs
end

const negC = -1e15
"""
Soft minimum function for a smooth transition from `x -> x` to a maximum value of 1e15.
"""
_softmin(x) = -logaddexp(negC, -x)

function make_observation_model(pipeline::AbstractRtwithoutRenewalPipeline)
default_params = make_default_params(pipeline)
#Model for ascertainment based on day of the week
dayofweek_logit_ascert = ascertainment_dayofweek(
NegativeBinomialError(cluster_factor_prior = HalfNormal(default_params["cluster_factor"]));
transform = (x, y) -> _softmin.(x .* y))

#Default continuous-time model for latent delay in observations
delay_distribution = make_delay_distribution(pipeline)
#Model for latent delay in observations
obs = LatentDelay(dayofweek_logit_ascert, delay_distribution)
return obs
end
8 changes: 3 additions & 5 deletions pipeline/test/pipeline/test_pipelinefunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@ end

@testset "do_inference tests" begin
function make_inference(pipeline)
truthdata_dg_task = do_truthdata(pipeline)
truthdata = fetch.(truthdata_dg_task)
truthdata = do_truthdata(pipeline)
do_inference(truthdata[1], pipeline)
end

for pipetype in [SmoothOutbreakPipeline, MeasuresOutbreakPipeline,
SmoothEndemicPipeline, RoughEndemicPipeline]
pipeline = pipetype(; ndraws = 20, nchains = 1, testmode = true)
inference_results_tsk = make_inference(pipeline)
inference_results = fetch.(inference_results_tsk)
pipeline = pipetype(; ndraws = 1000, nchains = 1, testmode = true)
inference_results = make_inference(pipeline)
@test length(inference_results) == 1
@test all([result["inference_results"] isa EpiAwareObservables
for result in inference_results])
Expand Down

0 comments on commit b8d6f1f

Please sign in to comment.