From d034601aa2f67b32b7f9e1886c0dabf4bc55652d Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 17:18:12 +0000 Subject: [PATCH 1/6] Change prior for negative binomial observations --- EpiAware/docs/src/examples/getting_started.jl | 13 +++++-------- EpiAware/src/observation-models.jl | 5 +++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 49ced9764..ecd1b23f0 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.39 +# v0.19.40 using Markdown using InteractiveUtils @@ -149,7 +149,7 @@ as the action of a sparse kernel on the infections $I(t)$. ```math \begin{align} y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ -1 / r &\sim \text{Gamma}(3, 0.05/3). +1 / \sqrt{r} &\sim \text{HalfNormal}\Big(0.1 \sqrt{{\pi \over 2}}\Big). \end{align} ``` " @@ -171,7 +171,7 @@ We choose a simple observation model where infections are observed 0, 1, 2, 3 da obs_model = EpiAware.DelayObservations( fill(0.25, 4), time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) + truncated(Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf) ) # ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b @@ -314,8 +314,6 @@ let size = (700, 400)) end -# ╔═╡ 2293b711-0bd0-44d5-8a30-94e56c5e4c65 - # ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 md" As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. @@ -386,7 +384,7 @@ end # ╠═6639e66f-7725-4976-81b2-6472419d1a62 # ╟─df5e59f8-3185-4bed-9cca-7c266df17cec # ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa # ╟─e813d547-6100-4c43-b84c-8cebe306bda8 # ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 # ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f @@ -401,7 +399,7 @@ end # ╠═d073e63b-62da-4743-ace0-78ef7806bc0b # ╠═a04f3c1b-7e11-4800-9c2a-9fc0021de6e7 # ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╠═b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 # ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 # ╟─4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 # ╠═259a7042-e74f-43c7-aeb4-97a3beeb7776 @@ -410,7 +408,6 @@ end # ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c # ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 # ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -# ╠═2293b711-0bd0-44d5-8a30-94e56c5e4c65 # ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 # ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 # ╟─81efe8ca-b753-4a12-bafc-a887a999377b diff --git a/EpiAware/src/observation-models.jl b/EpiAware/src/observation-models.jl index 39b2b322d..d1599ef93 100644 --- a/EpiAware/src/observation-models.jl +++ b/EpiAware/src/observation-models.jl @@ -26,7 +26,8 @@ struct DelayObservations{T <: AbstractFloat, S <: Sampleable} <: AbstractObserva end function default_delay_obs_priors() - return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict + return (:neg_bin_cluster_factor_prior => truncated( + Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf),) |> Dict end function generate_observations(observation_model::AbstractObservationModel, @@ -54,7 +55,7 @@ end for i in eachindex(y_t) y_t[i] ~ NegativeBinomialMeanClust( - expected_obs[i], neg_bin_cluster_factor + expected_obs[i], neg_bin_cluster_factor^2 ) end From a048da5fa2e0aab079589815667b6f28fed47a86 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 17:19:40 +0000 Subject: [PATCH 2/6] Fix test to reflect that the `r` parameter is now the square of the cluster factor --- EpiAware/test/test_observation-models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/test/test_observation-models.jl b/EpiAware/test/test_observation-models.jl index 66b33531b..df6b4f7c9 100644 --- a/EpiAware/test/test_observation-models.jl +++ b/EpiAware/test/test_observation-models.jl @@ -25,7 +25,7 @@ chn -> generated_quantities(fix_mdl, chn) .|> (gen -> gen[1][1]) |> vec - direct_samples = EpiAware.NegativeBinomialMeanClust(I_t[1], neg_bin_cf) |> + direct_samples = EpiAware.NegativeBinomialMeanClust(I_t[1], neg_bin_cf^2) |> dist -> rand(dist, n_samples) #For discrete distributions, checking mean and variance is as expected From e669ff80b2e98666f6e2c430c601b8bddc29b647 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 8 Mar 2024 12:26:56 +0000 Subject: [PATCH 3/6] More information on choice of variance prior --- EpiAware/docs/src/examples/getting_started.jl | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index ecd1b23f0..1ce170fc0 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -143,14 +143,21 @@ md" ### Delayed Observations `ObservationModel` -The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. +The observation model is a negative binomial distribution parameterised with mean $\mu$ and 'successes' parameter $r$. The standard deviation _relative_ to the mean $\sigma_{\text{rel}} = \sigma / \mu$ for negative binomial observations is, ```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ -1 / \sqrt{r} &\sim \text{HalfNormal}\Big(0.1 \sqrt{{\pi \over 2}}\Big). -\end{align} +\sigma_{\text{rel}} =(1/\sqrt{\mu}) + (1 / \sqrt{r}). +``` +It is standard to use a half-t distribution for standard deviation priors (e.g. as argued in this [paper](http://www.stat.columbia.edu/~gelman/research/published/taumain.pdf)); we specialise this to a Half-Normal prior and use an _a priori_ assumption that a typical observation fluctuation around the mean (when the mean is $\sim\mathcal{O}(10^2)$) would be 10%. This implies a standard deviation prior, +```math +1 / \sqrt{r} \sim \text{HalfNormal}\Big(0.1 ~\sqrt{{\pi \over 2}}\Big). +``` +The $\sqrt{{\pi \over 2}}$ factor ensures the correct prior mean (see [here](https://en.wikipedia.org/wiki/Half-normal_distribution)). + +The expected observed cases are delayed infections. Delays are implemented as the action of a sparse kernel on the infections $I(t)$. + +```math +y_t \sim \text{NegBinomial}\Big(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r\Big). \\ ``` " @@ -384,7 +391,7 @@ end # ╠═6639e66f-7725-4976-81b2-6472419d1a62 # ╟─df5e59f8-3185-4bed-9cca-7c266df17cec # ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa # ╟─e813d547-6100-4c43-b84c-8cebe306bda8 # ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 # ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f From ee33080e7ddffc1d4ca88b24fee228529383c679 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 8 Mar 2024 12:56:39 +0000 Subject: [PATCH 4/6] Removed unnecessary `nchains` arg --- EpiAware/src/inference-methods.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index e105390cf..8218d8de7 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -95,8 +95,7 @@ largest ELBO estimate. - `best_pfs::PathfinderResult`: Best pathfinder result by estimated ELBO. """ function manypathfinder(mdl::DynamicPPL.Model, ndraws; nruns = 4, - nchains = 4, maxiters = 50, max_tries = 100, kwargs...) - ndraws = max(ndraws, nchains) + maxiters = 50, max_tries = 100, kwargs...) _run_manypathfinder(mdl; nruns, ndraws, maxiters, kwargs...) |> pfs -> _continue_manypathfinder!(pfs, mdl; max_tries, nruns, kwargs...) |> pfs -> _get_best_elbo_pathfinder(pfs) From 0b426186a4a046ea1d8a6a14c2de967ea959ac6e Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 8 Mar 2024 12:56:57 +0000 Subject: [PATCH 5/6] Combined new prior into getting started --- EpiAware/docs/src/examples/getting_started.jl | 46 +++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 1ce170fc0..555145e1e 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -27,6 +27,7 @@ begin using Statistics using DataFramesMeta using LinearAlgebra + using Transducers end # ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 @@ -243,14 +244,14 @@ However, we now treat the generated data as `truth_data` and make inference with We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. " -# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 -truth_data = generated_obs - # ╔═╡ 4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 md" The observation model supports partially complete data. To test this we set some of the generated observations to be `missing`. " +# ╔═╡ 525aa98c-d0e5-4ffa-b808-d90fc986204c +truth_data = generated_obs + # ╔═╡ 259a7042-e74f-43c7-aeb4-97a3beeb7776 let truth_data = Union{Int, Missing}[truth_data...] @@ -270,12 +271,43 @@ inference_mdl = fix( (rw_init = 0.0,) ) +# ╔═╡ 9222b436-9445-4039-abbf-25c8cddb7f63 +md" +### Initialising inference + +It is possible for the default warm-up process for NUTS to get stuck in low probability or otherwise degenerate regions of parameter space. + +To make NUTS more robust we provide `manypathfinder`, which is built on pathfinder variational inference from [Pathfinder.jl](https://mlcolab.github.io/Pathfinder.jl/stable/). `manypathfinder` runs `nruns` pathfinder processes on the inference problem and returns the pathfinder run with maximum estimated ELBO. + +`manypathfinder` differs from `Pathfinder.multipathfinder`; `multipathfinder` is aimed at sampling from a potentially non-Gaussian target distribution which is first approximated as a uniformly weighted collection of normal approximations from pathfinder runs. `manypathfinder` is aimed at moving rapidly to a 'good' part of parameter space, and is robust to runs that fail. +" + +# ╔═╡ 197a4fbb-b71a-475a-bb78-28ff613e3094 +best_pf = manypathfinder(inference_mdl, 10; nruns = 20, executor = Transducers.ThreadedEx()); + +# ╔═╡ 073a1d40-456a-450e-969f-11b23eb7fd1f +md" +We can use draws from the best pathfinder run to initialise NUTS. +" + +# ╔═╡ 0379b058-4c35-440a-bc01-aafa0178bdbf +best_pf.draws_transformed + +# ╔═╡ a7798f71-9bb5-4506-9476-0cc11553b9e2 +init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1])) + +# ╔═╡ 4deb3a51-781d-48c4-91f6-6adf2b1affcf +md" +**NB: We are running this inference run for speed rather than accuracy as a demonstration. Use a higher target acceptance and more samples in a typical workflow.** +" + # ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c chn = sample(inference_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), 250, 4; + init_params, drop_warmup = true) # ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 @@ -407,11 +439,17 @@ end # ╠═a04f3c1b-7e11-4800-9c2a-9fc0021de6e7 # ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 # ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 # ╟─4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 +# ╠═525aa98c-d0e5-4ffa-b808-d90fc986204c # ╠═259a7042-e74f-43c7-aeb4-97a3beeb7776 # ╟─32638954-2c99-4d4e-8e03-52154030c657 # ╠═b4033728-b321-4100-8194-1fd9fe2d268d +# ╟─9222b436-9445-4039-abbf-25c8cddb7f63 +# ╠═197a4fbb-b71a-475a-bb78-28ff613e3094 +# ╠═073a1d40-456a-450e-969f-11b23eb7fd1f +# ╠═0379b058-4c35-440a-bc01-aafa0178bdbf +# ╠═a7798f71-9bb5-4506-9476-0cc11553b9e2 +# ╟─4deb3a51-781d-48c4-91f6-6adf2b1affcf # ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c # ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 # ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 From cd91ae27d594dab3007620d6125edeb8aab9a373 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 8 Mar 2024 13:15:14 +0000 Subject: [PATCH 6/6] fix broken manypathfinder test --- EpiAware/test/test_inference-methods.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl index c7f10e186..4a36b7aba 100644 --- a/EpiAware/test/test_inference-methods.jl +++ b/EpiAware/test/test_inference-methods.jl @@ -111,7 +111,7 @@ end maxiters = 50 max_tries = 100 - best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, maxiters = maxiters, max_tries = max_tries) @test best_pf isa PathfinderResult @@ -128,7 +128,7 @@ end maxiters = 50 max_tries = 10 - best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, maxiters = maxiters, max_tries = max_tries) pathfinder_samples = best_pf.draws |> vec