diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index d9e49bb6c..555145e1e 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -4,7 +4,8 @@ using Markdown using InteractiveUtils -# ╔═╡ 4680eed2-dbd8-11ee-17d6-2552317711c1 +# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# hideall let docs_dir = dirname(dirname(@__DIR__)) pkg_dir = dirname(docs_dir) @@ -12,11 +13,10 @@ let using Pkg: Pkg Pkg.activate(docs_dir) Pkg.develop(; path = pkg_dir) - Pkg.resolve() Pkg.instantiate() -end +end; -# ╔═╡ 914b0cd7-86ca-4227-96cb-7cd887956833 +# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f begin using EpiAware using Turing @@ -27,21 +27,15 @@ begin using Statistics using DataFramesMeta using LinearAlgebra - using Pathfinder using Transducers +end - Random.seed!(1) -end; - -# ╔═╡ 7c22d80b-2d52-4935-b62d-cbc64437195c +# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 md" # Getting stated with `EpiAware` This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. -" -# ╔═╡ 683743e6-9426-4b95-994b-4a579aa2564d -md" ## `EpiAware` models The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. @@ -74,7 +68,7 @@ An `EpiAware` model in code is created from three modular components: Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. " -# ╔═╡ 0b554dd9-79c7-44bc-9cbf-ea56439cb80d +# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 md" ## Random walk `LatentModel` @@ -100,11 +94,11 @@ Z_0 &\sim \mathcal{N}(0,1),\\ ``` " -# ╔═╡ 65540fb2-f97f-49dc-91f4-af26b803994e +# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.02), 0.0, 0.5)) + truncated(Normal(0.0, 0.02), 0.0, Inf)) -# ╔═╡ 3ed6eb84-d0e1-4f09-9fa0-4021d0f79f88 +# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 md" ## Direct infection `EpiModel` @@ -117,71 +111,78 @@ As discussed above, we still ask for a defined generation interval, which can be " -# ╔═╡ 0a532e52-8305-470a-8462-2aa023b724a2 +# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 truth_GI = Gamma(2, 5) -# ╔═╡ ec24d355-0158-45e9-9584-ed89bbb17b31 +# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 md" The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. " -# ╔═╡ 8452c589-aceb-41b7-978e-918b83db58d3 +# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 model_data = EpiData(truth_GI, D_gen = 10.0) -# ╔═╡ 6028ccd2-428f-4737-9fa6-ab5bf17631bd +# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc md" We can supply a prior for the initial log_infections. " -# ╔═╡ bd6714a0-6e70-4602-993e-12238e1f37f2 +# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 log_I0_prior = Normal(log(100.0), 1.0) -# ╔═╡ af012cc5-02ea-47f4-8545-cf54f2c6f6cc +# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec md" And construct the `EpiModel`. " -# ╔═╡ 18e51238-4038-4022-9b22-0e51ed51ea0a +# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 epi_model = DirectInfections(model_data, log_I0_prior) -# ╔═╡ d75e7957-a020-46f6-8d40-e2ac1abb917f +# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa md" ### Delayed Observations `ObservationModel` -The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. +The observation model is a negative binomial distribution parameterised with mean $\mu$ and 'successes' parameter $r$. The standard deviation _relative_ to the mean $\sigma_{\text{rel}} = \sigma / \mu$ for negative binomial observations is, ```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ -1 / r &\sim \text{Gamma}(3, 0.05/3). -\end{align} +\sigma_{\text{rel}} =(1/\sqrt{\mu}) + (1 / \sqrt{r}). +``` +It is standard to use a half-t distribution for standard deviation priors (e.g. as argued in this [paper](http://www.stat.columbia.edu/~gelman/research/published/taumain.pdf)); we specialise this to a Half-Normal prior and use an _a priori_ assumption that a typical observation fluctuation around the mean (when the mean is $\sim\mathcal{O}(10^2)$) would be 10%. This implies a standard deviation prior, +```math +1 / \sqrt{r} \sim \text{HalfNormal}\Big(0.1 ~\sqrt{{\pi \over 2}}\Big). +``` +The $\sqrt{{\pi \over 2}}$ factor ensures the correct prior mean (see [here](https://en.wikipedia.org/wiki/Half-normal_distribution)). + +The expected observed cases are delayed infections. Delays are implemented as the action of a sparse kernel on the infections $I(t)$. + +```math +y_t \sim \text{NegBinomial}\Big(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r\Big). \\ ``` " -# ╔═╡ 74771683-48cb-456f-8089-65c2fe2fdef2 +# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 md" We also set up the inference to occur over 100 days. " -# ╔═╡ cba7072f-aaaa-40fe-8e5f-f22d98fbdb30 +# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 time_horizon = 100 -# ╔═╡ 6c39de90-4791-42f0-b863-228753618c8a +# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f md" We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. " -# ╔═╡ b4a420af-71e6-4094-9795-9544dd5f34a2 +# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b obs_model = EpiAware.DelayObservations( fill(0.25, 4), time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 0.2) + truncated(Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf) ) -# ╔═╡ 5b667f47-8c90-42ca-8253-998a3ad3878d +# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b md" ## Generate cases from the `EpiAware` model @@ -190,40 +191,40 @@ Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implem By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. " -# ╔═╡ 12632c1c-b233-4990-9e7e-9add1bb9a8ee +# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, latent_model = rwp, observation_model = obs_model) -# ╔═╡ 031a55e8-617b-4da4-859c-94b660ec424d +# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a md" We choose some fixed parameters: - Initial incidence is 100. - In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. " -# ╔═╡ 775c7295-a4ea-484e-9ad9-291df3f6ffe8 +# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) -# ╔═╡ 5ce1855e-9f9e-4feb-8df3-8ec6139ac943 +# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 md" We fix these parameters using `fix`, and generate a random epidemic. " -# ╔═╡ 5b86a63b-677c-4125-b0be-1527b73b91bd +# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) -# ╔═╡ 4a1fc7bb-82a0-4643-a18b-d331a31c1390 +# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 random_epidemic = rand(cond_generative_model) -# ╔═╡ e571e7b6-0e26-4855-ae90-05a18be6ff38 +# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t -# ╔═╡ 62092c7f-ebe7-428e-baaa-65c34be52371 +# ╔═╡ a04f3c1b-7e11-4800-9c2a-9fc0021de6e7 generated_obs = generated_quantities(cond_generative_model, random_epidemic).generated_y_t -# ╔═╡ 88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 +# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 let plot(true_infections, label = "I_t", @@ -233,7 +234,7 @@ let scatter!(generated_obs, lab = "generated cases") end -# ╔═╡ 2f90bee6-067d-4267-beb9-356e4a4d714c +# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 md" ## Inference Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. @@ -243,19 +244,34 @@ However, we now treat the generated data as `truth_data` and make inference with We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. " -# ╔═╡ 7e48a4c5-cd30-4377-8a98-e0c23f2dc31e +# ╔═╡ 4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 +md" +The observation model supports partially complete data. To test this we set some of the generated observations to be `missing`. +" + +# ╔═╡ 525aa98c-d0e5-4ffa-b808-d90fc986204c truth_data = generated_obs -# ╔═╡ 272e6798-1151-486f-9667-924dbc63bd69 +# ╔═╡ 259a7042-e74f-43c7-aeb4-97a3beeb7776 +let + truth_data = Union{Int, Missing}[truth_data...] + truth_data[vcat([3, 5], 10:20)] .= missing +end + +# ╔═╡ 32638954-2c99-4d4e-8e03-52154030c657 +md" +We now make the model but fixing the initial condition of the random walk to be 0. +" + +# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d + inference_mdl = fix( - make_epi_aware(truth_data, time_horizon; - epi_model = epi_model, - latent_model = rwp, - observation_model = obs_model), + make_epi_aware(truth_data, time_horizon; epi_model = epi_model, + latent_model = rwp, observation_model = obs_model), (rw_init = 0.0,) ) -# ╔═╡ 4298f0ec-f6df-42ee-aa28-f7ed60f1e530 +# ╔═╡ 9222b436-9445-4039-abbf-25c8cddb7f63 md" ### Initialising inference @@ -266,38 +282,35 @@ To make NUTS more robust we provide `manypathfinder`, which is built on pathfind `manypathfinder` differs from `Pathfinder.multipathfinder`; `multipathfinder` is aimed at sampling from a potentially non-Gaussian target distribution which is first approximated as a uniformly weighted collection of normal approximations from pathfinder runs. `manypathfinder` is aimed at moving rapidly to a 'good' part of parameter space, and is robust to runs that fail. " -# ╔═╡ 40ebd47a-4a08-4a46-a727-26347d3fca51 +# ╔═╡ 197a4fbb-b71a-475a-bb78-28ff613e3094 best_pf = manypathfinder(inference_mdl, 10; nruns = 20, executor = Transducers.ThreadedEx()); -# ╔═╡ b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 +# ╔═╡ 073a1d40-456a-450e-969f-11b23eb7fd1f md" We can use draws from the best pathfinder run to initialise NUTS. " -# ╔═╡ cdd805e2-b00c-4522-9261-1819c6a195eb +# ╔═╡ 0379b058-4c35-440a-bc01-aafa0178bdbf best_pf.draws_transformed -# ╔═╡ e847b0b6-9d70-46ba-bec6-1e3fa676a33c +# ╔═╡ a7798f71-9bb5-4506-9476-0cc11553b9e2 init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1])) -# ╔═╡ 9734a535-e3d8-4481-9897-f537ad095d21 +# ╔═╡ 4deb3a51-781d-48c4-91f6-6adf2b1affcf md" **NB: We are running this inference run for speed rather than accuracy as a demonstration. Use a higher target acceptance and more samples in a typical workflow.** " -# ╔═╡ 2fdb4ca6-47ba-4a16-95fa-14b2b32cef10 -begin - target_acc_rate = 0.8 - chn = sample(inference_mdl, - NUTS(target_acc_rate; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - init_params, - drop_warmup = true) -end +# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +chn = sample(inference_mdl, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + init_params, + drop_warmup = true) -# ╔═╡ 2e42cb30-b087-4ae1-9b8f-95d103e1c290 +# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 md" ### Predictive plotting @@ -306,7 +319,7 @@ We can spaghetti plot generated case data from the version of the model _which h Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. " -# ╔═╡ e74fc652-cd5f-4764-a416-caa8bab0bf0c +# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 let post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen @@ -340,12 +353,12 @@ let size = (700, 400)) end -# ╔═╡ 96df9c68-b2e2-4669-b420-5ef23c77aee7 +# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 md" As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. " -# ╔═╡ 04d741d8-a2ff-48eb-90b1-e4da416eb582 +# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 let parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) @@ -363,7 +376,7 @@ let plot(plts..., layout = (2, 1)) end -# ╔═╡ 42763332-096d-40eb-a152-96e858992ed4 +# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b md" ## Reproductive number back-calculation @@ -372,7 +385,7 @@ As mentioned at the top, we _don't_ directly use the concept of reproductive num Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. " -# ╔═╡ 3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b +# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 let n = epi_model.data.len_gen_int Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) @@ -397,48 +410,50 @@ let end # ╔═╡ Cell order: -# ╟─4680eed2-dbd8-11ee-17d6-2552317711c1 -# ╟─914b0cd7-86ca-4227-96cb-7cd887956833 -# ╟─7c22d80b-2d52-4935-b62d-cbc64437195c -# ╟─683743e6-9426-4b95-994b-4a579aa2564d -# ╟─0b554dd9-79c7-44bc-9cbf-ea56439cb80d -# ╠═65540fb2-f97f-49dc-91f4-af26b803994e -# ╟─3ed6eb84-d0e1-4f09-9fa0-4021d0f79f88 -# ╠═0a532e52-8305-470a-8462-2aa023b724a2 -# ╟─ec24d355-0158-45e9-9584-ed89bbb17b31 -# ╠═8452c589-aceb-41b7-978e-918b83db58d3 -# ╟─6028ccd2-428f-4737-9fa6-ab5bf17631bd -# ╠═bd6714a0-6e70-4602-993e-12238e1f37f2 -# ╟─af012cc5-02ea-47f4-8545-cf54f2c6f6cc -# ╠═18e51238-4038-4022-9b22-0e51ed51ea0a -# ╟─d75e7957-a020-46f6-8d40-e2ac1abb917f -# ╟─74771683-48cb-456f-8089-65c2fe2fdef2 -# ╠═cba7072f-aaaa-40fe-8e5f-f22d98fbdb30 -# ╟─6c39de90-4791-42f0-b863-228753618c8a -# ╠═b4a420af-71e6-4094-9795-9544dd5f34a2 -# ╟─5b667f47-8c90-42ca-8253-998a3ad3878d -# ╠═12632c1c-b233-4990-9e7e-9add1bb9a8ee -# ╟─031a55e8-617b-4da4-859c-94b660ec424d -# ╠═775c7295-a4ea-484e-9ad9-291df3f6ffe8 -# ╟─5ce1855e-9f9e-4feb-8df3-8ec6139ac943 -# ╠═5b86a63b-677c-4125-b0be-1527b73b91bd -# ╠═4a1fc7bb-82a0-4643-a18b-d331a31c1390 -# ╠═e571e7b6-0e26-4855-ae90-05a18be6ff38 -# ╠═62092c7f-ebe7-428e-baaa-65c34be52371 -# ╟─88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 -# ╟─2f90bee6-067d-4267-beb9-356e4a4d714c -# ╠═7e48a4c5-cd30-4377-8a98-e0c23f2dc31e -# ╠═272e6798-1151-486f-9667-924dbc63bd69 -# ╟─4298f0ec-f6df-42ee-aa28-f7ed60f1e530 -# ╠═40ebd47a-4a08-4a46-a727-26347d3fca51 -# ╟─b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 -# ╠═cdd805e2-b00c-4522-9261-1819c6a195eb -# ╠═e847b0b6-9d70-46ba-bec6-1e3fa676a33c -# ╟─9734a535-e3d8-4481-9897-f537ad095d21 -# ╠═2fdb4ca6-47ba-4a16-95fa-14b2b32cef10 -# ╟─2e42cb30-b087-4ae1-9b8f-95d103e1c290 -# ╠═e74fc652-cd5f-4764-a416-caa8bab0bf0c -# ╠═96df9c68-b2e2-4669-b420-5ef23c77aee7 -# ╠═04d741d8-a2ff-48eb-90b1-e4da416eb582 -# ╠═42763332-096d-40eb-a152-96e858992ed4 -# ╠═3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b +# ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╠═56ae496b-0094-460b-89cb-526627991717 +# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╠═6639e66f-7725-4976-81b2-6472419d1a62 +# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec +# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 +# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b +# ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a +# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 +# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b +# ╠═a04f3c1b-7e11-4800-9c2a-9fc0021de6e7 +# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +# ╟─4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 +# ╠═525aa98c-d0e5-4ffa-b808-d90fc986204c +# ╠═259a7042-e74f-43c7-aeb4-97a3beeb7776 +# ╟─32638954-2c99-4d4e-8e03-52154030c657 +# ╠═b4033728-b321-4100-8194-1fd9fe2d268d +# ╟─9222b436-9445-4039-abbf-25c8cddb7f63 +# ╠═197a4fbb-b71a-475a-bb78-28ff613e3094 +# ╠═073a1d40-456a-450e-969f-11b23eb7fd1f +# ╠═0379b058-4c35-440a-bc01-aafa0178bdbf +# ╠═a7798f71-9bb5-4506-9476-0cc11553b9e2 +# ╟─4deb3a51-781d-48c4-91f6-6adf2b1affcf +# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 +# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╟─81efe8ca-b753-4a12-bafc-a887a999377b +# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index e105390cf..8218d8de7 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -95,8 +95,7 @@ largest ELBO estimate. - `best_pfs::PathfinderResult`: Best pathfinder result by estimated ELBO. """ function manypathfinder(mdl::DynamicPPL.Model, ndraws; nruns = 4, - nchains = 4, maxiters = 50, max_tries = 100, kwargs...) - ndraws = max(ndraws, nchains) + maxiters = 50, max_tries = 100, kwargs...) _run_manypathfinder(mdl; nruns, ndraws, maxiters, kwargs...) |> pfs -> _continue_manypathfinder!(pfs, mdl; max_tries, nruns, kwargs...) |> pfs -> _get_best_elbo_pathfinder(pfs) diff --git a/EpiAware/src/observation-models.jl b/EpiAware/src/observation-models.jl index 39b2b322d..d1599ef93 100644 --- a/EpiAware/src/observation-models.jl +++ b/EpiAware/src/observation-models.jl @@ -26,7 +26,8 @@ struct DelayObservations{T <: AbstractFloat, S <: Sampleable} <: AbstractObserva end function default_delay_obs_priors() - return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict + return (:neg_bin_cluster_factor_prior => truncated( + Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf),) |> Dict end function generate_observations(observation_model::AbstractObservationModel, @@ -54,7 +55,7 @@ end for i in eachindex(y_t) y_t[i] ~ NegativeBinomialMeanClust( - expected_obs[i], neg_bin_cluster_factor + expected_obs[i], neg_bin_cluster_factor^2 ) end diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl index c7f10e186..4a36b7aba 100644 --- a/EpiAware/test/test_inference-methods.jl +++ b/EpiAware/test/test_inference-methods.jl @@ -111,7 +111,7 @@ end maxiters = 50 max_tries = 100 - best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, maxiters = maxiters, max_tries = max_tries) @test best_pf isa PathfinderResult @@ -128,7 +128,7 @@ end maxiters = 50 max_tries = 10 - best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, maxiters = maxiters, max_tries = max_tries) pathfinder_samples = best_pf.draws |> vec diff --git a/EpiAware/test/test_observation-models.jl b/EpiAware/test/test_observation-models.jl index 66b33531b..df6b4f7c9 100644 --- a/EpiAware/test/test_observation-models.jl +++ b/EpiAware/test/test_observation-models.jl @@ -25,7 +25,7 @@ chn -> generated_quantities(fix_mdl, chn) .|> (gen -> gen[1][1]) |> vec - direct_samples = EpiAware.NegativeBinomialMeanClust(I_t[1], neg_bin_cf) |> + direct_samples = EpiAware.NegativeBinomialMeanClust(I_t[1], neg_bin_cf^2) |> dist -> rand(dist, n_samples) #For discrete distributions, checking mean and variance is as expected