diff --git a/EpiAware/Project.toml b/EpiAware/Project.toml index cb35f98bc..33006ea15 100644 --- a/EpiAware/Project.toml +++ b/EpiAware/Project.toml @@ -7,27 +7,33 @@ version = "0.1.0-DEV" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DataFramesMeta = "0.14" Distributions = "0.25" DocStringExtensions = "0.9" +DynamicPPL = "0.24" LinearAlgebra = "1.9" LogExpFunctions = "0.3" Optim = "1.9" Parameters = "0.12" +Pathfinder = "0.8" QuadGK = "2.9" Random = "1.9" ReverseDiff = "1.15" SparseArrays = "1.10" +Transducers = "0.4" Turing = "0.30" julia = "1.10" diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 900ecfe73..34efc468c 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -6,8 +6,10 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Transducers = 
"28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 49ced9764..d9e49bb6c 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -1,11 +1,10 @@ ### A Pluto.jl notebook ### -# v0.19.39 +# v0.19.40 using Markdown using InteractiveUtils -# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# hideall +# ╔═╡ 4680eed2-dbd8-11ee-17d6-2552317711c1 let docs_dir = dirname(dirname(@__DIR__)) pkg_dir = dirname(docs_dir) @@ -13,10 +12,11 @@ let using Pkg: Pkg Pkg.activate(docs_dir) Pkg.develop(; path = pkg_dir) + Pkg.resolve() Pkg.instantiate() -end; +end -# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╔═╡ 914b0cd7-86ca-4227-96cb-7cd887956833 begin using EpiAware using Turing @@ -27,14 +27,21 @@ begin using Statistics using DataFramesMeta using LinearAlgebra -end + using Pathfinder + using Transducers -# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 + Random.seed!(1) +end; + +# ╔═╡ 7c22d80b-2d52-4935-b62d-cbc64437195c md" # Getting stated with `EpiAware` This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. +" +# ╔═╡ 683743e6-9426-4b95-994b-4a579aa2564d +md" ## `EpiAware` models The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. @@ -67,7 +74,7 @@ An `EpiAware` model in code is created from three modular components: Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. 
" -# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╔═╡ 0b554dd9-79c7-44bc-9cbf-ea56439cb80d md" ## Random walk `LatentModel` @@ -93,11 +100,11 @@ Z_0 &\sim \mathcal{N}(0,1),\\ ``` " -# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 +# ╔═╡ 65540fb2-f97f-49dc-91f4-af26b803994e rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.02), 0.0, Inf)) + truncated(Normal(0.0, 0.02), 0.0, 0.5)) -# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╔═╡ 3ed6eb84-d0e1-4f09-9fa0-4021d0f79f88 md" ## Direct infection `EpiModel` @@ -110,34 +117,34 @@ As discussed above, we still ask for a defined generation interval, which can be " -# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╔═╡ 0a532e52-8305-470a-8462-2aa023b724a2 truth_GI = Gamma(2, 5) -# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╔═╡ ec24d355-0158-45e9-9584-ed89bbb17b31 md" The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. " -# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╔═╡ 8452c589-aceb-41b7-978e-918b83db58d3 model_data = EpiData(truth_GI, D_gen = 10.0) -# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╔═╡ 6028ccd2-428f-4737-9fa6-ab5bf17631bd md" We can supply a prior for the initial log_infections. " -# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 +# ╔═╡ bd6714a0-6e70-4602-993e-12238e1f37f2 log_I0_prior = Normal(log(100.0), 1.0) -# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec +# ╔═╡ af012cc5-02ea-47f4-8545-cf54f2c6f6cc md" And construct the `EpiModel`. 
" -# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╔═╡ 18e51238-4038-4022-9b22-0e51ed51ea0a epi_model = DirectInfections(model_data, log_I0_prior) -# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╔═╡ d75e7957-a020-46f6-8d40-e2ac1abb917f md" @@ -154,27 +161,27 @@ y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ ``` " -# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 +# ╔═╡ 74771683-48cb-456f-8089-65c2fe2fdef2 md" We also set up the inference to occur over 100 days. " -# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╔═╡ cba7072f-aaaa-40fe-8e5f-f22d98fbdb30 time_horizon = 100 -# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +# ╔═╡ 6c39de90-4791-42f0-b863-228753618c8a md" We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. " -# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b +# ╔═╡ b4a420af-71e6-4094-9795-9544dd5f34a2 obs_model = EpiAware.DelayObservations( fill(0.25, 4), time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) + truncated(Gamma(5, 0.05 / 5), 1e-3, 0.2) ) -# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b +# ╔═╡ 5b667f47-8c90-42ca-8253-998a3ad3878d md" ## Generate cases from the `EpiAware` model @@ -183,40 +190,40 @@ Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implem By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. " -# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d +# ╔═╡ 12632c1c-b233-4990-9e7e-9add1bb9a8ee full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, latent_model = rwp, observation_model = obs_model) -# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a +# ╔═╡ 031a55e8-617b-4da4-859c-94b660ec424d md" We choose some fixed parameters: - Initial incidence is 100. - In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. 
" -# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╔═╡ 775c7295-a4ea-484e-9ad9-291df3f6ffe8 fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) -# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 +# ╔═╡ 5ce1855e-9f9e-4feb-8df3-8ec6139ac943 md" We fix these parameters using `fix`, and generate a random epidemic. " -# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╔═╡ 5b86a63b-677c-4125-b0be-1527b73b91bd cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) -# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╔═╡ 4a1fc7bb-82a0-4643-a18b-d331a31c1390 random_epidemic = rand(cond_generative_model) -# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b +# ╔═╡ e571e7b6-0e26-4855-ae90-05a18be6ff38 true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t -# ╔═╡ a04f3c1b-7e11-4800-9c2a-9fc0021de6e7 +# ╔═╡ 62092c7f-ebe7-428e-baaa-65c34be52371 generated_obs = generated_quantities(cond_generative_model, random_epidemic).generated_y_t -# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╔═╡ 88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 let plot(true_infections, label = "I_t", @@ -226,7 +233,7 @@ let scatter!(generated_obs, lab = "generated cases") end -# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +# ╔═╡ 2f90bee6-067d-4267-beb9-356e4a4d714c md" ## Inference Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. @@ -236,42 +243,61 @@ However, we now treat the generated data as `truth_data` and make inference with We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. 
" -# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 +# ╔═╡ 7e48a4c5-cd30-4377-8a98-e0c23f2dc31e truth_data = generated_obs -# ╔═╡ 4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 +# ╔═╡ 272e6798-1151-486f-9667-924dbc63bd69 +inference_mdl = fix( + make_epi_aware(truth_data, time_horizon; + epi_model = epi_model, + latent_model = rwp, + observation_model = obs_model), + (rw_init = 0.0,) +) + +# ╔═╡ 4298f0ec-f6df-42ee-aa28-f7ed60f1e530 md" -The observation model supports partially complete data. To test this we set some of the generated observations to be `missing`. +### Initialising inference + +It is possible for the default warm-up process for NUTS to get stuck in low probability or otherwise degenerate regions of parameter space. + +To make NUTS more robust we provide `manypathfinder`, which is built on pathfinder variational inference from [Pathfinder.jl](https://mlcolab.github.io/Pathfinder.jl/stable/). `manypathfinder` runs `nruns` pathfinder processes on the inference problem and returns the pathfinder run with maximum estimated ELBO. + +`manypathfinder` differs from `Pathfinder.multipathfinder`; `multipathfinder` is aimed at sampling from a potentially non-Gaussian target distribution which is first approximated as a uniformly weighted collection of normal approximations from pathfinder runs. `manypathfinder` is aimed at moving rapidly to a 'good' part of parameter space, and is robust to runs that fail. " -# ╔═╡ 259a7042-e74f-43c7-aeb4-97a3beeb7776 -let - truth_data = Union{Int, Missing}[truth_data...] - truth_data[vcat([3, 5], 10:20)] .= missing -end +# ╔═╡ 40ebd47a-4a08-4a46-a727-26347d3fca51 +best_pf = manypathfinder(inference_mdl, 10; nruns = 20, executor = Transducers.ThreadedEx()); -# ╔═╡ 32638954-2c99-4d4e-8e03-52154030c657 +# ╔═╡ b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 md" -We now make the model but fixing the initial condition of the random walk to be 0. +We can use draws from the best pathfinder run to initialise NUTS. 
" -# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d +# ╔═╡ cdd805e2-b00c-4522-9261-1819c6a195eb +best_pf.draws_transformed -inference_mdl = fix( - make_epi_aware(truth_data, time_horizon; epi_model = epi_model, - latent_model = rwp, observation_model = obs_model), - (rw_init = 0.0,) -) +# ╔═╡ e847b0b6-9d70-46ba-bec6-1e3fa676a33c +init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1])) -# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -chn = sample(inference_mdl, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - drop_warmup = true) +# ╔═╡ 9734a535-e3d8-4481-9897-f537ad095d21 +md" +**NB: We are running this inference run for speed rather than accuracy as a demonstration. Use a higher target acceptance and more samples in a typical workflow.** +" + +# ╔═╡ 2fdb4ca6-47ba-4a16-95fa-14b2b32cef10 +begin + target_acc_rate = 0.8 + chn = sample(inference_mdl, + NUTS(target_acc_rate; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + init_params, + drop_warmup = true) +end -# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 +# ╔═╡ 2e42cb30-b087-4ae1-9b8f-95d103e1c290 md" ### Predictive plotting @@ -280,7 +306,7 @@ We can spaghetti plot generated case data from the version of the model _which h Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. 
" -# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╔═╡ e74fc652-cd5f-4764-a416-caa8bab0bf0c let post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen @@ -314,14 +340,12 @@ let size = (700, 400)) end -# ╔═╡ 2293b711-0bd0-44d5-8a30-94e56c5e4c65 - -# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +# ╔═╡ 96df9c68-b2e2-4669-b420-5ef23c77aee7 md" As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. " -# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╔═╡ 04d741d8-a2ff-48eb-90b1-e4da416eb582 let parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) @@ -339,7 +363,7 @@ let plot(plts..., layout = (2, 1)) end -# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b +# ╔═╡ 42763332-096d-40eb-a152-96e858992ed4 md" ## Reproductive number back-calculation @@ -348,7 +372,7 @@ As mentioned at the top, we _don't_ directly use the concept of reproductive num Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. 
" -# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 +# ╔═╡ 3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b let n = epi_model.data.len_gen_int Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) @@ -373,45 +397,48 @@ let end # ╔═╡ Cell order: -# ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 -# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f -# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 -# ╠═56ae496b-0094-460b-89cb-526627991717 -# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 -# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 -# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 -# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc -# ╠═6639e66f-7725-4976-81b2-6472419d1a62 -# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec -# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa -# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 -# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 -# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f -# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b -# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b -# ╠═abeff860-58c3-4644-9325-66ffd4446b6d -# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a -# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee -# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 -# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a -# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 -# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b -# ╠═a04f3c1b-7e11-4800-9c2a-9fc0021de6e7 -# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╠═b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 -# ╟─4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312 -# ╠═259a7042-e74f-43c7-aeb4-97a3beeb7776 -# ╟─32638954-2c99-4d4e-8e03-52154030c657 -# ╠═b4033728-b321-4100-8194-1fd9fe2d268d -# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 -# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -# ╠═2293b711-0bd0-44d5-8a30-94e56c5e4c65 -# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 -# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 -# ╟─81efe8ca-b753-4a12-bafc-a887a999377b -# 
╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 +# ╟─4680eed2-dbd8-11ee-17d6-2552317711c1 +# ╟─914b0cd7-86ca-4227-96cb-7cd887956833 +# ╟─7c22d80b-2d52-4935-b62d-cbc64437195c +# ╟─683743e6-9426-4b95-994b-4a579aa2564d +# ╟─0b554dd9-79c7-44bc-9cbf-ea56439cb80d +# ╠═65540fb2-f97f-49dc-91f4-af26b803994e +# ╟─3ed6eb84-d0e1-4f09-9fa0-4021d0f79f88 +# ╠═0a532e52-8305-470a-8462-2aa023b724a2 +# ╟─ec24d355-0158-45e9-9584-ed89bbb17b31 +# ╠═8452c589-aceb-41b7-978e-918b83db58d3 +# ╟─6028ccd2-428f-4737-9fa6-ab5bf17631bd +# ╠═bd6714a0-6e70-4602-993e-12238e1f37f2 +# ╟─af012cc5-02ea-47f4-8545-cf54f2c6f6cc +# ╠═18e51238-4038-4022-9b22-0e51ed51ea0a +# ╟─d75e7957-a020-46f6-8d40-e2ac1abb917f +# ╟─74771683-48cb-456f-8089-65c2fe2fdef2 +# ╠═cba7072f-aaaa-40fe-8e5f-f22d98fbdb30 +# ╟─6c39de90-4791-42f0-b863-228753618c8a +# ╠═b4a420af-71e6-4094-9795-9544dd5f34a2 +# ╟─5b667f47-8c90-42ca-8253-998a3ad3878d +# ╠═12632c1c-b233-4990-9e7e-9add1bb9a8ee +# ╟─031a55e8-617b-4da4-859c-94b660ec424d +# ╠═775c7295-a4ea-484e-9ad9-291df3f6ffe8 +# ╟─5ce1855e-9f9e-4feb-8df3-8ec6139ac943 +# ╠═5b86a63b-677c-4125-b0be-1527b73b91bd +# ╠═4a1fc7bb-82a0-4643-a18b-d331a31c1390 +# ╠═e571e7b6-0e26-4855-ae90-05a18be6ff38 +# ╠═62092c7f-ebe7-428e-baaa-65c34be52371 +# ╟─88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 +# ╟─2f90bee6-067d-4267-beb9-356e4a4d714c +# ╠═7e48a4c5-cd30-4377-8a98-e0c23f2dc31e +# ╠═272e6798-1151-486f-9667-924dbc63bd69 +# ╟─4298f0ec-f6df-42ee-aa28-f7ed60f1e530 +# ╠═40ebd47a-4a08-4a46-a727-26347d3fca51 +# ╟─b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 +# ╠═cdd805e2-b00c-4522-9261-1819c6a195eb +# ╠═e847b0b6-9d70-46ba-bec6-1e3fa676a33c +# ╟─9734a535-e3d8-4481-9897-f537ad095d21 +# ╠═2fdb4ca6-47ba-4a16-95fa-14b2b32cef10 +# ╟─2e42cb30-b087-4ae1-9b8f-95d103e1c290 +# ╠═e74fc652-cd5f-4764-a416-caa8bab0bf0c +# ╠═96df9c68-b2e2-4669-b420-5ef23c77aee7 +# ╠═04d741d8-a2ff-48eb-90b1-e4da416eb582 +# ╠═42763332-096d-40eb-a152-96e858992ed4 +# ╠═3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 
a33275686..0d971220a 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -34,7 +34,7 @@ module EpiAware using Distributions, Turing, LogExpFunctions, LinearAlgebra, SparseArrays, Random, ReverseDiff, Optim, Parameters, QuadGK, DataFramesMeta, - DocStringExtensions + DocStringExtensions, Pathfinder, DynamicPPL, Transducers # Exported abstract types export AbstractModel, AbstractEpiModel, AbstractLatentModel, @@ -54,6 +54,9 @@ export generate_latent, generate_latent_infs, generate_observations export create_discrete_pmf, spread_draws, scan, R_to_r, r_to_R, default_rw_priors, default_delay_obs_priors +# Exported inference methods +export manypathfinder + include("docstrings.jl") include("abstract-types.jl") include("epi-models.jl") @@ -61,5 +64,6 @@ include("utilities.jl") include("latent-models.jl") include("observation-models.jl") include("models.jl") +include("inference-methods.jl") end diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl new file mode 100644 index 000000000..e105390cf --- /dev/null +++ b/EpiAware/src/inference-methods.jl @@ -0,0 +1,103 @@ +""" + +Run pathfinder multiple times and store the results in an array. Fails safely. + +# Arguments +- `mdl::DynamicPPL.Model`: The `Turing` model to be used for inference. +- `nruns`: The number of times to run the `pathfinder` function. +- `kwargs...`: Additional keyword arguments passed to `pathfinder`. + +# Returns +A vector with a `PathfinderResult` for each successful run and `:fail` for each failed run. +""" +function _run_manypathfinder(mdl::DynamicPPL.Model; nruns, kwargs...) + @info "Running pathfinder $nruns times" + pfs = Vector{Union{Pathfinder.PathfinderResult, Symbol}}(undef, nruns) + Threads.@threads for i in 1:nruns + try + pfs[i] = pathfinder(mdl; kwargs...) 
+ catch + pfs[i] = :fail + end + end + return pfs +end + +""" +Continue running the pathfinder algorithm until a pathfinder succeeds or the maximum number +of tries is reached. 
+ +# Arguments +- `pfs`: An array of pathfinder objects. +- `mdl::DynamicPPL.Model`: The model to perform inference on. +- `max_tries`: The maximum number of extra pathfinder attempts made while every run so far + has failed (required keyword; no default). +- `nruns`: Accepted but not used by the retry loop. +- `kwargs...`: Additional keyword arguments passed to `pathfinder` (e.g. `ndraws`, `maxiters`). + +# Returns +- `pfs`: Results with retry outcomes appended (a new array; the input is not mutated). + +""" +function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, nruns, kwargs...) + tryiter = 1 + if all(pfs .== :fail) + @warn "All initial pathfinder runs failed, trying again for $max_tries tries." + end + while all(pfs .== :fail) && tryiter <= max_tries + new_pf = try + pathfinder(mdl; kwargs...) + catch + :fail + end + pfs = vcat(pfs, new_pf) + tryiter += 1 + end + if all(pfs .== :fail) + e = ErrorException("All pathfinder runs failed after $max_tries tries.") + throw(e) + end + return pfs +end + +""" +Selects the pathfinder with the highest ELBO estimate from a list of pathfinders. + +# Arguments +- `pfs`: A list of pathfinders results or `Symbol` values indicating failure. + +# Returns +The pathfinder with the highest ELBO estimate. +""" +function _get_best_elbo_pathfinder(pfs) + elbos = map(pfs) do pf_res + pf_res == :fail ? -Inf : pf_res.elbo_estimates[end].value + end + _, choice_of_pf = findmax(elbos) + return pfs[choice_of_pf] +end + +""" +Run multiple instances of the pathfinder algorithm and returns the pathfinder run with the +largest ELBO estimate. + +## Arguments +- `mdl::DynamicPPL.Model`: The model to perform inference on. +- `nruns::Int`: The number of pathfinder runs to perform. +- `ndraws::Int`: The number of draws per pathfinder run, readjusted to be at least as large + as the number of chains. +- `nchains::Int`: The number of chains that will be initialised by pathfinder draws. 
+- `maxiters::Int`: The maximum number of optimizer iterations per pathfinder run. 
+- `max_tries::Int`: Maximum number of extra tries (same `ndraws`/`maxiters`) if all runs fail. +- `kwargs...`: Additional keyword arguments passed to `pathfinder`. + +## Returns +- `best_pfs::PathfinderResult`: Best pathfinder result by estimated ELBO. +""" +function manypathfinder(mdl::DynamicPPL.Model, ndraws; nruns = 4, + nchains = 4, maxiters = 50, max_tries = 100, kwargs...) + ndraws = max(ndraws, nchains) + _run_manypathfinder(mdl; nruns, ndraws, maxiters, kwargs...) |> + pfs -> _continue_manypathfinder!(pfs, mdl; max_tries, nruns, ndraws, maxiters, kwargs...) |> + pfs -> _get_best_elbo_pathfinder(pfs) +end diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml index e927523a8..b882dd0ac 100644 --- a/EpiAware/test/Project.toml +++ b/EpiAware/test/Project.toml @@ -5,10 +5,12 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" +Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl new file mode 100644 index 000000000..c7f10e186 --- /dev/null +++ b/EpiAware/test/test_inference-methods.jl @@ -0,0 +1,157 @@ +@testitem "Testing _run_manypathfinder function" begin + using Turing, Pathfinder + + @testset "Test case: check runs" begin + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = 
EpiAware._run_manypathfinder( + mdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + + @test length(pfs) == nruns + @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + end + @testset "Test case: check fail mode for bad model" begin + @model function bad_model() + x ~ Normal(0, 1) + return sqrt(x) #<-fails + end + badmdl = bad_model() + nruns = 5 + ndraws = 50 + maxiters = 100 + + pfs = EpiAware._run_manypathfinder( + badmdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + + @test all(pfs .== :fail) + end +end +@testitem "Testing _continue_manypathfinder! function" begin + using Turing, Pathfinder + + @testset "Check that it only adds one more for easy model" begin + @model function easy_model() + x ~ Normal(0, 1) + end + + easymdl = easy_model() + + pfs = Vector{Union{PathfinderResult, Symbol}}([:fail, :fail, :fail]) + max_tries = 3 + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = EpiAware._continue_manypathfinder!( + pfs, easymdl; max_tries, nruns, ndraws, maxiters) + + @test pfs[end] isa PathfinderResult + end + + @testset "Check always fails for bad models and throws correct Exception" begin + @model function bad_model() + x ~ Normal(0, 1) + return sqrt(x) #<-fails + end + badmdl = bad_model() + + pfs = Vector{Union{PathfinderResult, Symbol}}([:fail, :fail, :fail]) + max_tries = 3 + nruns = 10 + ndraws = 100 + maxiters = 50 + + @test_throws "All pathfinder runs failed after $max_tries tries." 
begin + pfs = EpiAware._continue_manypathfinder!( + pfs, badmdl; max_tries, nruns, ndraws, maxiters) + end + end +end +@testitem "Testing _get_best_elbo_pathfinder function" begin + using Pathfinder, Turing + + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = EpiAware._run_manypathfinder( + mdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + + best_pf = EpiAware._get_best_elbo_pathfinder(pfs) + @test best_pf isa PathfinderResult +end +@testitem "Testing manypathfinder function" begin + using Turing, Pathfinder, HypothesisTests + @testset "Test model works" begin + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + + nruns = 4 + ndraws = 10 + nchains = 4 + maxiters = 50 + max_tries = 100 + + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, + maxiters = maxiters, max_tries = max_tries) + + @test best_pf isa PathfinderResult + end + + @testset "Does good job finding simple distribution" begin + @model function basic_normal() + x ~ Normal(0, 1) + end + mdl = basic_normal() + nruns = 4 + ndraws = 2000 + nchains = 4 + maxiters = 50 + max_tries = 10 + + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, + maxiters = maxiters, max_tries = max_tries) + + pathfinder_samples = best_pf.draws |> vec + ks_test_pval = ExactOneSampleKSTest(pathfinder_samples, Normal(0.0, 1)) |> pvalue + @test ks_test_pval > 1e-6 + end + + @testset "Check always fails for bad models and throws correct Exception" begin + @model function bad_model() + x ~ Normal(0, 1) + return sqrt(x) #<-fails + end + badmdl = bad_model() + + max_tries = 3 + nruns = 10 + ndraws = 100 + maxiters = 50 + nchains = 4 + + @test_throws "All pathfinder runs failed after $max_tries tries." begin + manypathfinder(badmdl, ndraws; nruns = nruns, nchains = nchains, + maxiters = maxiters, max_tries = max_tries) + end + end +end