From 0b7c8e18d8a793d1f4ac734956c93a67884d6e7b Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 1 Mar 2024 18:21:09 +0000 Subject: [PATCH 01/40] update --- EpiAware/docs/src/examples/getting_started.jl | 72 +++++++ EpiAware/docs/src/man/getting-started.md | 193 ++++++++++++++++++ .../toy_model_log_infs_RW.jl | 2 +- 3 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 EpiAware/docs/src/examples/getting_started.jl diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl new file mode 100644 index 000000000..938e2cfb2 --- /dev/null +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -0,0 +1,72 @@ +### A Pluto.jl notebook ### +# v0.19.40 + +using Markdown +using InteractiveUtils + +# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# hideall +let + docs_dir = dirname(dirname(@__DIR__)) + pkg_dir = dirname(docs_dir) + + using Pkg: Pkg + Pkg.activate(docs_dir) + Pkg.develop(; path = pkg_dir) + Pkg.instantiate() +end; + +# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 +md" +# Getting stated with `EpiAware` + +This is a toy model for demonstrating current functionality of EpiAware package. + +## Generative Model without data + +### Latent Process + +The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. + +_Unfixed parameters_: +- `σ²_RW`: The variance of the random walk process. Current defauly prior is +- `init_rw_value`: The initial value of the random walk process. +- `ϵ_t`: The random noise vector. + +```math +\begin{align} +X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ +X(0) &\sim \mathcal{N}(0, 1) \\ +\epsilon_t &\sim \mathcal{N}(0, 1) \\ +\sigma_{RW} &\sim \text{HalfNormal}(0.05). +\end{align} +``` + +### Log-Infections Model + +The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), +an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. + +It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. + +```math +\log I_t = \exp(X(t)). +``` + +### Observation model + +The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. + +```math +\begin{align} +y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), +r &\sim \text{Gamma}(3, 0.05/3). +\end{align} +``` + +" + +# ╔═╡ Cell order: +# ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# ╠═3ebc8384-f73d-4597-83a7-07a3744fed61 diff --git a/EpiAware/docs/src/man/getting-started.md b/EpiAware/docs/src/man/getting-started.md index e69de29bb..65b6b7a99 100644 --- a/EpiAware/docs/src/man/getting-started.md +++ b/EpiAware/docs/src/man/getting-started.md @@ -0,0 +1,193 @@ +```@meta +EditURL = "../../../test/predictive_checking/toy_model_log_infs_RW.jl" +``` + +# Getting started + +This is a toy model for demonstrating current functionality of EpiAware package. + +## Generative Model without data + +### Latent Process + +The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. + +_Unfixed parameters_: +- `σ²_RW`: The variance of the random walk process. Current defauly prior is +- `init_rw_value`: The initial value of the random walk process. +- `ϵ_t`: The random noise vector. + +```math +\begin{align} +X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ +X(0) &\sim \mathcal{N}(0, 1) \\ +\epsilon_t &\sim \mathcal{N}(0, 1) \\ +\sigma_{RW} &\sim \text{HalfNormal}(0.05). +\end{align} +``` + +### Log-Infections Model + +The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), +an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. + +It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. + +```math +\log I_t = \exp(X(t)). +``` + +### Observation model + +The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. + +```math +\begin{align} +y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), +r &\sim \text{Gamma}(3, 0.05/3). +\end{align} +``` + +## Load dependencies + +This script should be run from Test environment mode. If not, run the following command: + + +````@example toy_model_log_infs_RW +using EpiAware +using Turing +using Distributions +using StatsPlots +using Random +using DynamicPPL +using Statistics +using DataFramesMeta +using CSV # For outputting the MCMC chain + +Random.seed!(0) +```` + +## Create an `EpiModel` struct + +- Medium length generation interval distribution. +- Median 2 day, std 4.3 day delay distribution. + +````@example toy_model_log_infs_RW +truth_GI = Gamma(2, 5) +model_data = EpiData(truth_GI, + D_gen = 10.0) + +log_I0_prior = Normal(0.0, 1.0) +epi_model = DirectInfections(model_data, log_I0_prior) +```` + +## Define the data generating process + +In this case we use the `DirectInfections` model. + +````@example toy_model_log_infs_RW +rwp = EpiAware.RandomWalk(Normal(), + truncated(Normal(0.0, 0.01), 0.0, 0.5)) + +#Define the observation model - no delay model +time_horizon = 100 +obs_model = EpiAware.DelayObservations([1.0], + time_horizon, + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) +```` + +## Generate a `Turing` `Model` +We don't have observed data, so we use `missing` value for `y_t`. + +````@example toy_model_log_infs_RW +log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) +```` + +## Sample from the model +I define a fixed version of the model with initial infections set to 1 and variance of the random walk process set to 0.1. +We can sample from the model using the `rand` function, and plot the generated infections against generated cases. + +We can get the generated infections using `generated_quantities` function. Because the observed +cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled +process. + +````@example toy_model_log_infs_RW +cond_toy = fix(log_infs_model, (init = log(1.0), σ²_RW = 0.1)) +random_epidemic = rand(cond_toy) +gen = generated_quantities(cond_toy, random_epidemic) + +plot(gen.I_t, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") +scatter!(random_epidemic.y_t, lab = "generated cases") +```` + +## Inference + +We treat the generated data as observed data and attempt to infer underlying infections. + +````@example toy_model_log_infs_RW +truth_data = random_epidemic.y_t + +model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) +@time chn = sample(model, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + drop_warmup = true) +```` + +## Postior predictive checking + +We check the posterior predictive checking by plotting the predicted cases against the observed cases. + +````@example toy_model_log_infs_RW +predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen + gen.generated_y_t +end + +plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") +scatter!(truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Posterior Predictive Checking", + ylims = (-0.5, maximum(truth_data) * 2.5)) +```` + +## Underlying inferred infections + +````@example toy_model_log_infs_RW +predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen + gen.I_t +end + +plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") +scatter!(gen.I_t, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Unobserved Infections", + title = "Posterior Predictive Checking", + ylims = (-0.5, maximum(gen.I_t) * 1.5)) +```` + +## Outputing the MCMC chain +We can use `spread_draws` to convert the MCMC chain into a tidybayes format. + +````@example toy_model_log_infs_RW +df_chn = spread_draws(chn) +save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv") +CSV.write(save_path, df_chn) +```` + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index e686bf41a..cce09e19a 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -83,7 +83,7 @@ truth_GI = Gamma(2, 5) model_data = EpiData(truth_GI, D_gen = 10.0) -log_I0_prior = Normal(0.0, 1.0) +log_I0_prior = Normal(0.0, 10.0) epi_model = DirectInfections(model_data, log_I0_prior) #= From d86b12d69e33d822974ae7af30e76ea98de82b23 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 4 Mar 2024 07:12:21 +0000 Subject: [PATCH 02/40] added doc deps --- EpiAware/docs/Project.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 3d7eeda93..4ec707df0 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -1,4 +1,11 @@ [deps] Changelog = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e" +DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" From dfeaf541f631163196cd9f514fb4804f01115019 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 4 Mar 2024 07:12:53 +0000 Subject: [PATCH 03/40] started `getting started` notebook --- EpiAware/docs/src/examples/getting_started.jl | 196 ++++++++++++++++-- 1 file changed, 177 insertions(+), 19 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 938e2cfb2..113990f3c 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -16,43 +16,131 @@ let Pkg.instantiate() end; +# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f +begin + using EpiAware + using Turing + using Distributions + using StatsPlots + using Random + using DynamicPPL + using Statistics + using DataFramesMeta +end + # ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 md" # Getting stated with `EpiAware` -This is a toy model for demonstrating current functionality of EpiAware package. +This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. + +## `EpiAware` models +The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. + +#### Mathematical definition +```math +\begin{align} +Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\ +I_0 &\sim f_0(\theta_I), \\ +I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\ +y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}). +\end{align} +``` +Where, $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$. $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections. $g_I$ is distribution of new infections conditional on infections and latent process in the past. Note that we assume that new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step. -## Generative Model without data +#### Code structure outline -### Latent Process +An `EpiAware` model in code is created from three modular components: -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. +- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$. +- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$. +- An `ObservationModel`: This defines the observation model. This chooses $f_O({I_s}_{s \leq t}, \theta_{O})$ -_Unfixed parameters_: -- `σ²_RW`: The variance of the random walk process. Current defauly prior is -- `init_rw_value`: The initial value of the random walk process. -- `ϵ_t`: The random noise vector. +#### Reproductive number +`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$, however, this is often a quantity of interest. When not directly used we will typically back-calculate $\mathcal{R}_t$ from the generated infections: +```math +\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }. +``` + +Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. +" + +# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 +md" +## Random walk `LatentModel` + +As an example, we choose the latent process as a random walk with parameters $\theta_Z$: + +- ``Z_0``: Initial position. +- ``\sigma^2_{Z}``: The step-size variance. + +Conditional on the parameters the random walk is then generated by white noise: ```math \begin{align} -X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ -X(0) &\sim \mathcal{N}(0, 1) \\ -\epsilon_t &\sim \mathcal{N}(0, 1) \\ -\sigma_{RW} &\sim \text{HalfNormal}(0.05). +Z_t &= Z_0 + \sigma_{RW} \sum_{t= 1}^T \epsilon_t, \\ +\epsilon_t &\sim \mathcal{N}(0,1). \end{align} ``` -### Log-Infections Model +In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors, +```math +\begin{align} +Z_0 &\sim \mathcal{N}(0,1),\\ +\sigma^2_Z &\sim \text{HalfNormal}(0.01). +\end{align} +``` +" -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. +# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 +rwp = EpiAware.RandomWalk(Normal(), + truncated(Normal(0.0, 0.01), 0.0, 0.5)) -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. +# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +md" +## Direct infection `EpiModel` +This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$. ```math -\log I_t = \exp(X(t)). +\log I_t = \log I_0 + Z_t. ``` +As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$. + +" + +# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 +truth_GI = Gamma(2, 5) + +# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 +md" +The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. +" + +# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +model_data = EpiData(truth_GI, + D_gen = 10.0) + +# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc +md" +We can supply a prior for the initial log_infections. +" + +# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 +log_I0_prior = Normal(0.0, 10.0) + +# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec +md" +And construct the `EpiModel`. +" + +# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 +epi_model = DirectInfections(model_data, log_I0_prior) + +# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa +md" + + ### Observation model The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented @@ -64,9 +152,79 @@ y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), r &\sim \text{Gamma}(3, 0.05/3). \end{align} ``` - " +# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +time_horizon = 100 + +# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b +obs_model = EpiAware.DelayObservations([1.0], + time_horizon, + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) + +# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d +log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) + +# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a +cond_toy = fix(log_infs_model, (init_incidence = log(10.0), σ²_RW = 0.1, init_rw = 0.0)) + +# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 +random_epidemic = rand(cond_toy) + +# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b +gen = generated_quantities(cond_toy, random_epidemic) + +# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +plot(gen.I_t, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") + +# ╔═╡ 31764672-3073-4280-8ab2-d42544be1629 +scatter!(random_epidemic.y_t, lab = "generated cases") + +# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 +truth_data = random_epidemic.y_t + +# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d +model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) | (init_rw = 0.0,) + +# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +chn = sample(mdl, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + drop_warmup = true) + # ╔═╡ Cell order: # ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# ╠═3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╠═56ae496b-0094-460b-89cb-526627991717 +# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╠═6639e66f-7725-4976-81b2-6472419d1a62 +# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec +# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b +# ╠═f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╠═31764672-3073-4280-8ab2-d42544be1629 +# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 +# ╠═b4033728-b321-4100-8194-1fd9fe2d268d +# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c From 2ce57b42d39b2071031155e655f29d683ab45c69 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 4 Mar 2024 13:39:22 +0000 Subject: [PATCH 04/40] Update getting_started.jl --- EpiAware/docs/src/examples/getting_started.jl | 96 +++++++++++++++---- 1 file changed, 78 insertions(+), 18 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 113990f3c..4706f36f1 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -127,7 +127,7 @@ We can supply a prior for the initial log_infections. " # ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 -log_I0_prior = Normal(0.0, 10.0) +log_I0_prior = Normal(log(100.0), 1.0) # ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec md" @@ -162,13 +162,21 @@ obs_model = EpiAware.DelayObservations([1.0], time_horizon, truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) +# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b +md" +## Generation +" + # ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d -log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, +full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, latent_model_model = rwp, observation_model = obs_model, pos_shift = 1e-6) +# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee +fixed_parameters = (init_incidence = log(100.0), σ²_RW = 0.1^2, init_rw = 0.0) + # ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a -cond_toy = fix(log_infs_model, (init_incidence = log(10.0), σ²_RW = 0.1, init_rw = 0.0)) +cond_toy = fix(full_epi_aware_mdl, fixed_parameters) # ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 random_epidemic = rand(cond_toy) @@ -177,31 +185,79 @@ random_epidemic = rand(cond_toy) gen = generated_quantities(cond_toy, random_epidemic) # ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -plot(gen.I_t, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") +let + plot(gen.I_t, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") + scatter!(random_epidemic.y_t, lab = "generated cases") +end -# ╔═╡ 31764672-3073-4280-8ab2-d42544be1629 -scatter!(random_epidemic.y_t, lab = "generated cases") +# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +md" +## Inference +" # ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 truth_data = random_epidemic.y_t # ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) | (init_rw = 0.0,) +inference_mdl = fix( + make_epi_aware(truth_data, time_horizon; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6), + (init_rw = 0.0,)) # ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -chn = sample(mdl, +chn = sample(inference_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), - 250, - 4; + 500, + 2; + n_warmup = 100, drop_warmup = true) +# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +let + post_check_mdl = fix(full_epi_aware_mdl, (init_rw = 0.0,)) + post_check_y_t = mapreduce(hcat, generated_quantities(full_epi_aware_mdl, chn)) do gen + gen.generated_y_t + end + + predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + gen.I_t + end + + p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p1, truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Post. predictive checking: cases", + ylims = (-0.5, maximum(truth_data) * 2.5), + c = :green) + + p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p2, gen.I_t, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Unobserved Infections", + title = "Post. predictions: infections", + ylims = (-0.5, maximum(gen.I_t) * 1.5), + c = :red) + + plot(p1, p2, layout = (2, 1)) +end + +# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 +let + var_samples = chn[:σ²_RW] |> vec + histogram(var_samples, bins = 50, norm = :pdf) + vline!([fixed_parameters.:σ²_RW]) + plot!(rwp.var_prior) +end + # ╔═╡ Cell order: # ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 # ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 @@ -219,12 +275,16 @@ chn = sample(mdl, # ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa # ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 # ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b # ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee # ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a # ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 # ╠═d073e63b-62da-4743-ace0-78ef7806bc0b -# ╠═f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╠═31764672-3073-4280-8ab2-d42544be1629 +# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 # ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 # ╠═b4033728-b321-4100-8194-1fd9fe2d268d # ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 From 9089270155d5c814e88f6f1d81c3e266496ca5a2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 12:43:58 +0000 Subject: [PATCH 05/40] Getting started example --- EpiAware/docs/src/examples/getting_started.jl | 172 ++++++++++++++---- 1 file changed, 136 insertions(+), 36 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 4706f36f1..200c20c27 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -26,6 +26,7 @@ begin using DynamicPPL using Statistics using DataFramesMeta + using LinearAlgebra end # ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 @@ -94,7 +95,7 @@ Z_0 &\sim \mathcal{N}(0,1),\\ # ╔═╡ 56ae496b-0094-460b-89cb-526627991717 rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.01), 0.0, 0.5)) + truncated(Normal(0.0, 0.02), 0.0, Inf)) # ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 md" @@ -118,8 +119,7 @@ The `EpiData` constructor performs double interval censoring to convert our _con " # ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -model_data = EpiData(truth_GI, - D_gen = 10.0) +model_data = EpiData(truth_GI, D_gen = 10.0) # ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc md" @@ -141,52 +141,81 @@ epi_model = DirectInfections(model_data, log_I0_prior) md" -### Observation model +### Delayed Observations `ObservationModel` -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. +The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. ```math \begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), -r &\sim \text{Gamma}(3, 0.05/3). +y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ +1 / r &\sim \text{Gamma}(3, 0.05/3). \end{align} ``` " +# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 +md" +We also set up the inference to occur over 100 days. +" + # ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 time_horizon = 100 +# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +md" +We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. +" + # ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b -obs_model = EpiAware.DelayObservations([1.0], +obs_model = EpiAware.DelayObservations( + fill(0.25, 4), time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) +) # ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b md" -## Generation +## Generate cases from the `EpiAware` model + +Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`. + +By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. " # ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d -full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) +full_epi_aware_mdl = make_epi_aware(missing, time_horizon; + epi_model = epi_model, + latent_model_model = rwp, + observation_model = obs_model) + +# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a +md" +We choose some fixed parameters: +- Initial incidence is 100. +- In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. +" # ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee -fixed_parameters = (init_incidence = log(100.0), σ²_RW = 0.1^2, init_rw = 0.0) +fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) + +# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 +md" +We fix these parameters using `fix`, and generate a random epidemic. +" # ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a -cond_toy = fix(full_epi_aware_mdl, fixed_parameters) +cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) # ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 -random_epidemic = rand(cond_toy) +random_epidemic = rand(cond_generative_model) # ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b -gen = generated_quantities(cond_toy, random_epidemic) +true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t # ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 let - plot(gen.I_t, + plot(true_infections, label = "I_t", xlabel = "Time", ylabel = "Infections", @@ -197,6 +226,11 @@ end # ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 md" ## Inference +Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. + +However, we now treat the generated data as `truth_data` and make inference without fixing any other parameters. + +We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. " # ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 @@ -205,23 +239,31 @@ truth_data = random_epidemic.y_t # ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d inference_mdl = fix( make_epi_aware(truth_data, time_horizon; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6), - (init_rw = 0.0,)) + latent_model_model = rwp, observation_model = obs_model), + (rw_init = 0.0,) +) # ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c chn = sample(inference_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), - 500, - 2; - n_warmup = 100, + 250, + 4; drop_warmup = true) +# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 +md" +### Predictive plotting + +We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). + +Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. +" + # ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 let - post_check_mdl = fix(full_epi_aware_mdl, (init_rw = 0.0,)) - post_check_y_t = mapreduce(hcat, generated_quantities(full_epi_aware_mdl, chn)) do gen + post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) + post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen gen.generated_y_t end @@ -235,27 +277,77 @@ let xlabel = "Time", ylabel = "Cases", title = "Post. predictive checking: cases", - ylims = (-0.5, maximum(truth_data) * 2.5), + ylims = (-0.5, maximum(truth_data) * 1.5), c = :green) p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p2, gen.I_t, + scatter!(p2, true_infections, lab = "Actual infections", xlabel = "Time", ylabel = "Unobserved Infections", title = "Post. predictions: infections", - ylims = (-0.5, maximum(gen.I_t) * 1.5), + ylims = (-0.5, maximum(true_infections) * 1.5), c = :red) - plot(p1, p2, layout = (2, 1)) + plot(p1, p2, + layout = (1, 2), + size = (700, 400)) end +# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +md" +As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. +" + # ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 let - var_samples = chn[:σ²_RW] |> vec - histogram(var_samples, bins = 50, norm = :pdf) - vline!([fixed_parameters.:σ²_RW]) - plot!(rwp.var_prior) + parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) + + plts = map(parameters_to_plot) do name + var_samples = chn[name] |> vec + histogram(var_samples, + bins = 50, + norm = :pdf, + lw = 0, + fillalpha = 0.5, + lab = "MCMC") + vline!([getfield(random_epidemic, name)], lab = "True value") + title!(string(name)) + end + plot(plts..., layout = (2, 1)) +end + +# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b +md" +## Reproductive number back-calculation + +As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this note. However, we can back-calculate the implied $\mathcal{R}(t)$ values, conditional on the specified generation interval being correct. + +Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. +" + +# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 +let + n = epi_model.data.len_gen_int + Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) + for t in (n + 1):length(true_infections)] + true_Rt = true_infections[(n + 1):end] ./ Rt_denom + + predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + _It = gen.I_t + _Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)]) + for t in (n + 1):length(_It)] + Rt = _It[(n + 1):end] ./ _Rt_denom + end + + plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "") + plot!(plt, (n + 1):time_horizon, true_Rt, + lab = "true Rt", + xlabel = "Time", + ylabel = "Rt", + title = "Post. predictions: reproductive number", + c = :red, + lw = 2) end # ╔═╡ Cell order: @@ -272,12 +364,16 @@ end # ╠═6639e66f-7725-4976-81b2-6472419d1a62 # ╟─df5e59f8-3185-4bed-9cca-7c266df17cec # ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 # ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f # ╠═448669bc-99f4-4823-b15e-fcc9040ba31b # ╟─e49713e8-4840-4083-8e3f-fc52d791be7b # ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a # ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 # ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a # ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 # ╠═d073e63b-62da-4743-ace0-78ef7806bc0b @@ -286,5 +382,9 @@ end # ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 # ╠═b4033728-b321-4100-8194-1fd9fe2d268d # ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 # ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 # ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╟─81efe8ca-b753-4a12-bafc-a887a999377b +# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 From e163b153abfa0215ab999537f601366a823a780f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:31:17 +0000 Subject: [PATCH 06/40] Add generated Pluto notebooks to gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ffa29724b..95595a137 100644 --- a/.gitignore +++ b/.gitignore @@ -381,3 +381,6 @@ docs/site/ /Manifest.toml .DS_Store .vscode/settings.json + +#Ignore generated Pluto notebooks +EpiAware/docs/src/examples/*.md From 5774a217e44d01b78f35ebcfbe733d2a41bb1337 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:33:33 +0000 Subject: [PATCH 07/40] remove old version of example --- EpiAware/docs/src/man/getting-started.md | 193 ----------------------- 1 file changed, 193 deletions(-) delete mode 100644 EpiAware/docs/src/man/getting-started.md diff --git a/EpiAware/docs/src/man/getting-started.md b/EpiAware/docs/src/man/getting-started.md deleted file mode 100644 index 65b6b7a99..000000000 --- a/EpiAware/docs/src/man/getting-started.md +++ /dev/null @@ -1,193 +0,0 @@ -```@meta -EditURL = "../../../test/predictive_checking/toy_model_log_infs_RW.jl" -``` - -# Getting started - -This is a toy model for demonstrating current functionality of EpiAware package. - -## Generative Model without data - -### Latent Process - -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. - -_Unfixed parameters_: -- `σ²_RW`: The variance of the random walk process. Current defauly prior is -- `init_rw_value`: The initial value of the random walk process. -- `ϵ_t`: The random noise vector. - -```math -\begin{align} -X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ -X(0) &\sim \mathcal{N}(0, 1) \\ -\epsilon_t &\sim \mathcal{N}(0, 1) \\ -\sigma_{RW} &\sim \text{HalfNormal}(0.05). -\end{align} -``` - -### Log-Infections Model - -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. - -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. - -```math -\log I_t = \exp(X(t)). -``` - -### Observation model - -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. - -```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), -r &\sim \text{Gamma}(3, 0.05/3). -\end{align} -``` - -## Load dependencies - -This script should be run from Test environment mode. If not, run the following command: - - -````@example toy_model_log_infs_RW -using EpiAware -using Turing -using Distributions -using StatsPlots -using Random -using DynamicPPL -using Statistics -using DataFramesMeta -using CSV # For outputting the MCMC chain - -Random.seed!(0) -```` - -## Create an `EpiModel` struct - -- Medium length generation interval distribution. -- Median 2 day, std 4.3 day delay distribution. - -````@example toy_model_log_infs_RW -truth_GI = Gamma(2, 5) -model_data = EpiData(truth_GI, - D_gen = 10.0) - -log_I0_prior = Normal(0.0, 1.0) -epi_model = DirectInfections(model_data, log_I0_prior) -```` - -## Define the data generating process - -In this case we use the `DirectInfections` model. - -````@example toy_model_log_infs_RW -rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.01), 0.0, 0.5)) - -#Define the observation model - no delay model -time_horizon = 100 -obs_model = EpiAware.DelayObservations([1.0], - time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) -```` - -## Generate a `Turing` `Model` -We don't have observed data, so we use `missing` value for `y_t`. - -````@example toy_model_log_infs_RW -log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) -```` - -## Sample from the model -I define a fixed version of the model with initial infections set to 1 and variance of the random walk process set to 0.1. -We can sample from the model using the `rand` function, and plot the generated infections against generated cases. - -We can get the generated infections using `generated_quantities` function. Because the observed -cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled -process. - -````@example toy_model_log_infs_RW -cond_toy = fix(log_infs_model, (init = log(1.0), σ²_RW = 0.1)) -random_epidemic = rand(cond_toy) -gen = generated_quantities(cond_toy, random_epidemic) - -plot(gen.I_t, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") -scatter!(random_epidemic.y_t, lab = "generated cases") -```` - -## Inference - -We treat the generated data as observed data and attempt to infer underlying infections. - -````@example toy_model_log_infs_RW -truth_data = random_epidemic.y_t - -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) -@time chn = sample(model, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - drop_warmup = true) -```` - -## Postior predictive checking - -We check the posterior predictive checking by plotting the predicted cases against the observed cases. - -````@example toy_model_log_infs_RW -predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.generated_y_t -end - -plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") -scatter!(truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(truth_data) * 2.5)) -```` - -## Underlying inferred infections - -````@example toy_model_log_infs_RW -predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.I_t -end - -plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") -scatter!(gen.I_t, - lab = "Actual infections", - xlabel = "Time", - ylabel = "Unobserved Infections", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(gen.I_t) * 1.5)) -```` - -## Outputing the MCMC chain -We can use `spread_draws` to convert the MCMC chain into a tidybayes format. - -````@example toy_model_log_infs_RW -df_chn = spread_draws(chn) -save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv") -CSV.write(save_path, df_chn) -```` - ---- - -*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* From e7c60ce22632cee1fd8982f7a1a972702dbb6757 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:33:51 +0000 Subject: [PATCH 08/40] add Pluto and PlutoStaticHTML as deps --- EpiAware/docs/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 4ec707df0..900ecfe73 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -6,6 +6,8 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" +PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" From ee9e1ca06c3ea796c0fab558fb2e5e9e8ca95038 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:35:03 +0000 Subject: [PATCH 09/40] Adapt make and build files to generate Pluto notebooks statically and then render into docs --- EpiAware/docs/build.jl | 24 ++++++++++++++++++++++++ EpiAware/docs/make.jl | 10 +++++++++- EpiAware/docs/pages.jl | 2 +- 3 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 EpiAware/docs/build.jl diff --git a/EpiAware/docs/build.jl b/EpiAware/docs/build.jl new file mode 100644 index 000000000..21d1ddcb4 --- /dev/null +++ b/EpiAware/docs/build.jl @@ -0,0 +1,24 @@ + +"""Run all Pluto notebooks (".jl" files) in `tutorials_dir` and write outputs to HTML files.""" +function build(target_subdir; _module = EpiAware) + target_dir = joinpath(pkgdir(_module), "docs", "src", target_subdir) + + @info "Building notebooks in $target_subdir" + # Evaluate notebooks in the same process to avoid having to recompile from scratch each time. + # This is similar to how Documenter and Franklin evaluate code. + # Note that things like method overrides and other global changes may leak between notebooks! + use_distributed = false + output_format = documenter_output + bopts = BuildOptions(target_dir; use_distributed, output_format) + build_notebooks(bopts) + return nothing +end + +"Return Markdown file links which can be passed to Documenter.jl." +function markdown_files(notebook_titles, target_subdir) + md_files = map(notebook_titles) do title + file = lowercase(replace(title, " " => '_')) + return joinpath(target_subdir, "$file.md") + end + return md_files +end diff --git a/EpiAware/docs/make.jl b/EpiAware/docs/make.jl index edc0449fc..012a779d6 100644 --- a/EpiAware/docs/make.jl +++ b/EpiAware/docs/make.jl @@ -1,8 +1,13 @@ using Documenter using EpiAware +using Pluto: Configuration.CompilerOptions +using PlutoStaticHTML include("changelog.jl") include("pages.jl") +include("build.jl") + +# build("examples") makedocs(; sitename = "EpiAware.jl", authors = "Samuel Brand, Zachary Susswein, Sam Abbott, and contributors", @@ -11,7 +16,10 @@ makedocs(; sitename = "EpiAware.jl", modules = [EpiAware], pages = pages, format = Documenter.HTML( - prettyurls = get(ENV, "CI", nothing) == "true" + prettyurls = get(ENV, "CI", nothing) == "true", + mathengine = Documenter.MathJax3(), + size_threshold = 600 * 2^10, + size_threshold_warn = 200 * 2^10 ) ) diff --git a/EpiAware/docs/pages.jl b/EpiAware/docs/pages.jl index db333fa73..9d379ceb6 100644 --- a/EpiAware/docs/pages.jl +++ b/EpiAware/docs/pages.jl @@ -3,7 +3,7 @@ pages = [ "Manual" => Any[ "Guide" => "man/guide.md", "Examples" => [ - "Getting started" => "man/getting-started.md" + "Getting started" => "examples/getting_started.md" ] ], "Reference" => Any[ From 359a3c29c33c4ef1e7c27c3c471aa791f465e0d2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:51:56 +0000 Subject: [PATCH 10/40] include build step for rendering Pluto notebooks --- EpiAware/docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/make.jl b/EpiAware/docs/make.jl index 012a779d6..fe3c1f58c 100644 --- a/EpiAware/docs/make.jl +++ b/EpiAware/docs/make.jl @@ -7,7 +7,7 @@ include("changelog.jl") include("pages.jl") include("build.jl") -# build("examples") +build("examples") makedocs(; sitename = "EpiAware.jl", authors = "Samuel Brand, Zachary Susswein, Sam Abbott, and contributors", From ca2707e0a4a4f3e3da9969f9ae5e4cf41572fc83 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:52:11 +0000 Subject: [PATCH 11/40] remove old version of the getting started example --- .../toy_model_log_infs_RW.jl | 193 ------------------ 1 file changed, 193 deletions(-) delete mode 100644 EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl deleted file mode 100644 index cce09e19a..000000000 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ /dev/null @@ -1,193 +0,0 @@ -#= -# Toy model for running analysis: - -This is a toy model for demonstrating current functionality of EpiAware package. - -## Generative Model without data - -### Latent Process - -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. - -_Unfixed parameters_: -- `σ²_RW`: The variance of the random walk process. Current defauly prior is -- `init_rw_value`: The initial value of the random walk process. -- `ϵ_t`: The random noise vector. - -```math -\begin{align} -X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ -X(0) &\sim \mathcal{N}(0, 1) \\ -\epsilon_t &\sim \mathcal{N}(0, 1) \\ -\sigma_{RW} &\sim \text{HalfNormal}(0.05). -\end{align} -``` - -### Log-Infections Model - -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. - -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. - -```math -\log I_t = \exp(X(t)). -``` - -### Observation model - -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. - -```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), -r &\sim \text{Gamma}(3, 0.05/3). -\end{align} -``` - -## Load dependencies - -This script should be run from Test environment mode. If not, run the following command: - -```julia -using TestEnv # Run in Test environment mode -TestEnv.activate() -``` - -=# - -# using TestEnv # Run in Test environment mode -# TestEnv.activate() - -using EpiAware -using Turing -using Distributions -using StatsPlots -using Random -using DynamicPPL -using Statistics -using DataFramesMeta -using CSV # For outputting the MCMC chain - -Random.seed!(0) - -#= -## Create an `EpiModel` struct - -- Medium length generation interval distribution. -- Median 2 day, std 4.3 day delay distribution. -=# - -truth_GI = Gamma(2, 5) -model_data = EpiData(truth_GI, - D_gen = 10.0) - -log_I0_prior = Normal(0.0, 10.0) -epi_model = DirectInfections(model_data, log_I0_prior) - -#= -## Define the data generating process - -In this case we use the `DirectInfections` model. -=# - -rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.01), 0.0, 0.5)) - -#Define the observation model - no delay model -time_horizon = 100 -obs_model = EpiAware.DelayObservations([1.0], - time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) - -#= -## Generate a `Turing` `Model` -We don't have observed data, so we use `missing` value for `y_t`. -=# - -log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) - -#= -## Sample from the model -I define a fixed version of the model with initial infections set to 1 and variance of the random walk process set to 0.1. -We can sample from the model using the `rand` function, and plot the generated infections against generated cases. - -We can get the generated infections using `generated_quantities` function. Because the observed -cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled -process. -=# - -cond_toy = fix(log_infs_model, (init = log(1.0), σ²_RW = 0.1)) -random_epidemic = rand(cond_toy) -gen = generated_quantities(cond_toy, random_epidemic) - -plot(gen.I_t, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") -scatter!(random_epidemic.y_t, lab = "generated cases") - -#= -## Inference - -We treat the generated data as observed data and attempt to infer underlying infections. -=# - -truth_data = random_epidemic.y_t - -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) -@time chn = sample(model, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - drop_warmup = true) - -#= -## Postior predictive checking - -We check the posterior predictive checking by plotting the predicted cases against the observed cases. -=# - -predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.generated_y_t -end - -plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") -scatter!(truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(truth_data) * 2.5)) - -#= -## Underlying inferred infections -=# - -predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.I_t -end - -plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") -scatter!(gen.I_t, - lab = "Actual infections", - xlabel = "Time", - ylabel = "Unobserved Infections", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(gen.I_t) * 1.5)) - -#= -## Outputing the MCMC chain -We can use `spread_draws` to convert the MCMC chain into a tidybayes format. -=# - -df_chn = spread_draws(chn) -save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv") -CSV.write(save_path, df_chn) From f7816de4a227c027d550e9b658576abeb52056fb Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 1 Mar 2024 18:21:09 +0000 Subject: [PATCH 12/40] update --- EpiAware/docs/src/examples/getting_started.jl | 72 +++++++ EpiAware/docs/src/man/getting-started.md | 193 ++++++++++++++++++ .../toy_model_log_infs_RW.jl | 2 +- 3 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 EpiAware/docs/src/examples/getting_started.jl diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl new file mode 100644 index 000000000..938e2cfb2 --- /dev/null +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -0,0 +1,72 @@ +### A Pluto.jl notebook ### +# v0.19.40 + +using Markdown +using InteractiveUtils + +# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# hideall +let + docs_dir = dirname(dirname(@__DIR__)) + pkg_dir = dirname(docs_dir) + + using Pkg: Pkg + Pkg.activate(docs_dir) + Pkg.develop(; path = pkg_dir) + Pkg.instantiate() +end; + +# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 +md" +# Getting stated with `EpiAware` + +This is a toy model for demonstrating current functionality of EpiAware package. + +## Generative Model without data + +### Latent Process + +The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. + +_Unfixed parameters_: +- `σ²_RW`: The variance of the random walk process. Current defauly prior is +- `init_rw_value`: The initial value of the random walk process. +- `ϵ_t`: The random noise vector. + +```math +\begin{align} +X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ +X(0) &\sim \mathcal{N}(0, 1) \\ +\epsilon_t &\sim \mathcal{N}(0, 1) \\ +\sigma_{RW} &\sim \text{HalfNormal}(0.05). +\end{align} +``` + +### Log-Infections Model + +The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), +an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. + +It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. + +```math +\log I_t = \exp(X(t)). +``` + +### Observation model + +The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. + +```math +\begin{align} +y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), +r &\sim \text{Gamma}(3, 0.05/3). +\end{align} +``` + +" + +# ╔═╡ Cell order: +# ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# ╠═3ebc8384-f73d-4597-83a7-07a3744fed61 diff --git a/EpiAware/docs/src/man/getting-started.md b/EpiAware/docs/src/man/getting-started.md index e69de29bb..65b6b7a99 100644 --- a/EpiAware/docs/src/man/getting-started.md +++ b/EpiAware/docs/src/man/getting-started.md @@ -0,0 +1,193 @@ +```@meta +EditURL = "../../../test/predictive_checking/toy_model_log_infs_RW.jl" +``` + +# Getting started + +This is a toy model for demonstrating current functionality of EpiAware package. + +## Generative Model without data + +### Latent Process + +The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. + +_Unfixed parameters_: +- `σ²_RW`: The variance of the random walk process. Current defauly prior is +- `init_rw_value`: The initial value of the random walk process. +- `ϵ_t`: The random noise vector. + +```math +\begin{align} +X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ +X(0) &\sim \mathcal{N}(0, 1) \\ +\epsilon_t &\sim \mathcal{N}(0, 1) \\ +\sigma_{RW} &\sim \text{HalfNormal}(0.05). +\end{align} +``` + +### Log-Infections Model + +The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), +an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. + +It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. + +```math +\log I_t = \exp(X(t)). +``` + +### Observation model + +The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. + +```math +\begin{align} +y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), +r &\sim \text{Gamma}(3, 0.05/3). +\end{align} +``` + +## Load dependencies + +This script should be run from Test environment mode. If not, run the following command: + + +````@example toy_model_log_infs_RW +using EpiAware +using Turing +using Distributions +using StatsPlots +using Random +using DynamicPPL +using Statistics +using DataFramesMeta +using CSV # For outputting the MCMC chain + +Random.seed!(0) +```` + +## Create an `EpiModel` struct + +- Medium length generation interval distribution. +- Median 2 day, std 4.3 day delay distribution. + +````@example toy_model_log_infs_RW +truth_GI = Gamma(2, 5) +model_data = EpiData(truth_GI, + D_gen = 10.0) + +log_I0_prior = Normal(0.0, 1.0) +epi_model = DirectInfections(model_data, log_I0_prior) +```` + +## Define the data generating process + +In this case we use the `DirectInfections` model. + +````@example toy_model_log_infs_RW +rwp = EpiAware.RandomWalk(Normal(), + truncated(Normal(0.0, 0.01), 0.0, 0.5)) + +#Define the observation model - no delay model +time_horizon = 100 +obs_model = EpiAware.DelayObservations([1.0], + time_horizon, + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) +```` + +## Generate a `Turing` `Model` +We don't have observed data, so we use `missing` value for `y_t`. + +````@example toy_model_log_infs_RW +log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) +```` + +## Sample from the model +I define a fixed version of the model with initial infections set to 1 and variance of the random walk process set to 0.1. +We can sample from the model using the `rand` function, and plot the generated infections against generated cases. + +We can get the generated infections using `generated_quantities` function. Because the observed +cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled +process. + +````@example toy_model_log_infs_RW +cond_toy = fix(log_infs_model, (init = log(1.0), σ²_RW = 0.1)) +random_epidemic = rand(cond_toy) +gen = generated_quantities(cond_toy, random_epidemic) + +plot(gen.I_t, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") +scatter!(random_epidemic.y_t, lab = "generated cases") +```` + +## Inference + +We treat the generated data as observed data and attempt to infer underlying infections. + +````@example toy_model_log_infs_RW +truth_data = random_epidemic.y_t + +model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) +@time chn = sample(model, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + drop_warmup = true) +```` + +## Postior predictive checking + +We check the posterior predictive checking by plotting the predicted cases against the observed cases. + +````@example toy_model_log_infs_RW +predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen + gen.generated_y_t +end + +plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") +scatter!(truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Posterior Predictive Checking", + ylims = (-0.5, maximum(truth_data) * 2.5)) +```` + +## Underlying inferred infections + +````@example toy_model_log_infs_RW +predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen + gen.I_t +end + +plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") +scatter!(gen.I_t, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Unobserved Infections", + title = "Posterior Predictive Checking", + ylims = (-0.5, maximum(gen.I_t) * 1.5)) +```` + +## Outputing the MCMC chain +We can use `spread_draws` to convert the MCMC chain into a tidybayes format. + +````@example toy_model_log_infs_RW +df_chn = spread_draws(chn) +save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv") +CSV.write(save_path, df_chn) +```` + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl index e686bf41a..cce09e19a 100644 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl @@ -83,7 +83,7 @@ truth_GI = Gamma(2, 5) model_data = EpiData(truth_GI, D_gen = 10.0) -log_I0_prior = Normal(0.0, 1.0) +log_I0_prior = Normal(0.0, 10.0) epi_model = DirectInfections(model_data, log_I0_prior) #= From 0b4bfe248b2376b9cc4ead9cfc31ecf5794e7e66 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 4 Mar 2024 07:12:21 +0000 Subject: [PATCH 13/40] added doc deps --- EpiAware/docs/Project.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 3d7eeda93..4ec707df0 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -1,4 +1,11 @@ [deps] Changelog = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e" +DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" From 36fd0ea7cde20670f005529ae7954eb89b202664 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 4 Mar 2024 07:12:53 +0000 Subject: [PATCH 14/40] started `getting started` notebook --- EpiAware/docs/src/examples/getting_started.jl | 196 ++++++++++++++++-- 1 file changed, 177 insertions(+), 19 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 938e2cfb2..113990f3c 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -16,43 +16,131 @@ let Pkg.instantiate() end; +# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f +begin + using EpiAware + using Turing + using Distributions + using StatsPlots + using Random + using DynamicPPL + using Statistics + using DataFramesMeta +end + # ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 md" # Getting stated with `EpiAware` -This is a toy model for demonstrating current functionality of EpiAware package. +This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. + +## `EpiAware` models +The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. + +#### Mathematical definition +```math +\begin{align} +Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\ +I_0 &\sim f_0(\theta_I), \\ +I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\ +y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}). +\end{align} +``` +Where, $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$. $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections. $g_I$ is distribution of new infections conditional on infections and latent process in the past. Note that we assume that new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step. -## Generative Model without data +#### Code structure outline -### Latent Process +An `EpiAware` model in code is created from three modular components: -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. +- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$. +- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$. +- An `ObservationModel`: This defines the observation model. This chooses $f_O({I_s}_{s \leq t}, \theta_{O})$ -_Unfixed parameters_: -- `σ²_RW`: The variance of the random walk process. Current defauly prior is -- `init_rw_value`: The initial value of the random walk process. -- `ϵ_t`: The random noise vector. +#### Reproductive number +`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$, however, this is often a quantity of interest. When not directly used we will typically back-calculate $\mathcal{R}_t$ from the generated infections: +```math +\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }. +``` + +Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. +" + +# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 +md" +## Random walk `LatentModel` + +As an example, we choose the latent process as a random walk with parameters $\theta_Z$: + +- ``Z_0``: Initial position. +- ``\sigma^2_{Z}``: The step-size variance. + +Conditional on the parameters the random walk is then generated by white noise: ```math \begin{align} -X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ -X(0) &\sim \mathcal{N}(0, 1) \\ -\epsilon_t &\sim \mathcal{N}(0, 1) \\ -\sigma_{RW} &\sim \text{HalfNormal}(0.05). +Z_t &= Z_0 + \sigma_{RW} \sum_{t= 1}^T \epsilon_t, \\ +\epsilon_t &\sim \mathcal{N}(0,1). \end{align} ``` -### Log-Infections Model +In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors, +```math +\begin{align} +Z_0 &\sim \mathcal{N}(0,1),\\ +\sigma^2_Z &\sim \text{HalfNormal}(0.01). +\end{align} +``` +" -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. +# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 +rwp = EpiAware.RandomWalk(Normal(), + truncated(Normal(0.0, 0.01), 0.0, 0.5)) -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. +# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +md" +## Direct infection `EpiModel` +This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$. ```math -\log I_t = \exp(X(t)). +\log I_t = \log I_0 + Z_t. ``` +As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$. + +" + +# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 +truth_GI = Gamma(2, 5) + +# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 +md" +The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. +" + +# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +model_data = EpiData(truth_GI, + D_gen = 10.0) + +# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc +md" +We can supply a prior for the initial log_infections. +" + +# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 +log_I0_prior = Normal(0.0, 10.0) + +# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec +md" +And construct the `EpiModel`. +" + +# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 +epi_model = DirectInfections(model_data, log_I0_prior) + +# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa +md" + + ### Observation model The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented @@ -64,9 +152,79 @@ y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), r &\sim \text{Gamma}(3, 0.05/3). \end{align} ``` - " +# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +time_horizon = 100 + +# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b +obs_model = EpiAware.DelayObservations([1.0], + time_horizon, + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) + +# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d +log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) + +# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a +cond_toy = fix(log_infs_model, (init_incidence = log(10.0), σ²_RW = 0.1, init_rw = 0.0)) + +# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 +random_epidemic = rand(cond_toy) + +# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b +gen = generated_quantities(cond_toy, random_epidemic) + +# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +plot(gen.I_t, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") + +# ╔═╡ 31764672-3073-4280-8ab2-d42544be1629 +scatter!(random_epidemic.y_t, lab = "generated cases") + +# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 +truth_data = random_epidemic.y_t + +# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d +model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6) | (init_rw = 0.0,) + +# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +chn = sample(mdl, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + drop_warmup = true) + # ╔═╡ Cell order: # ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# ╠═3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╠═56ae496b-0094-460b-89cb-526627991717 +# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╠═6639e66f-7725-4976-81b2-6472419d1a62 +# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec +# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b +# ╠═f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╠═31764672-3073-4280-8ab2-d42544be1629 +# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 +# ╠═b4033728-b321-4100-8194-1fd9fe2d268d +# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c From 58e9744b1016e693575eb85792e2fadab03d29ad Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 4 Mar 2024 13:39:22 +0000 Subject: [PATCH 15/40] Update getting_started.jl --- EpiAware/docs/src/examples/getting_started.jl | 96 +++++++++++++++---- 1 file changed, 78 insertions(+), 18 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 113990f3c..4706f36f1 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -127,7 +127,7 @@ We can supply a prior for the initial log_infections. " # ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 -log_I0_prior = Normal(0.0, 10.0) +log_I0_prior = Normal(log(100.0), 1.0) # ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec md" @@ -162,13 +162,21 @@ obs_model = EpiAware.DelayObservations([1.0], time_horizon, truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) +# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b +md" +## Generation +" + # ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d -log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, +full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, latent_model_model = rwp, observation_model = obs_model, pos_shift = 1e-6) +# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee +fixed_parameters = (init_incidence = log(100.0), σ²_RW = 0.1^2, init_rw = 0.0) + # ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a -cond_toy = fix(log_infs_model, (init_incidence = log(10.0), σ²_RW = 0.1, init_rw = 0.0)) +cond_toy = fix(full_epi_aware_mdl, fixed_parameters) # ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 random_epidemic = rand(cond_toy) @@ -177,31 +185,79 @@ random_epidemic = rand(cond_toy) gen = generated_quantities(cond_toy, random_epidemic) # ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -plot(gen.I_t, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") +let + plot(gen.I_t, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") + scatter!(random_epidemic.y_t, lab = "generated cases") +end -# ╔═╡ 31764672-3073-4280-8ab2-d42544be1629 -scatter!(random_epidemic.y_t, lab = "generated cases") +# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +md" +## Inference +" # ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 truth_data = random_epidemic.y_t # ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) | (init_rw = 0.0,) +inference_mdl = fix( + make_epi_aware(truth_data, time_horizon; epi_model = epi_model, + latent_model_model = rwp, observation_model = obs_model, + pos_shift = 1e-6), + (init_rw = 0.0,)) # ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -chn = sample(mdl, +chn = sample(inference_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), - 250, - 4; + 500, + 2; + n_warmup = 100, drop_warmup = true) +# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +let + post_check_mdl = fix(full_epi_aware_mdl, (init_rw = 0.0,)) + post_check_y_t = mapreduce(hcat, generated_quantities(full_epi_aware_mdl, chn)) do gen + gen.generated_y_t + end + + predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + gen.I_t + end + + p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p1, truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Post. predictive checking: cases", + ylims = (-0.5, maximum(truth_data) * 2.5), + c = :green) + + p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p2, gen.I_t, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Unobserved Infections", + title = "Post. predictions: infections", + ylims = (-0.5, maximum(gen.I_t) * 1.5), + c = :red) + + plot(p1, p2, layout = (2, 1)) +end + +# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 +let + var_samples = chn[:σ²_RW] |> vec + histogram(var_samples, bins = 50, norm = :pdf) + vline!([fixed_parameters.:σ²_RW]) + plot!(rwp.var_prior) +end + # ╔═╡ Cell order: # ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 # ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 @@ -219,12 +275,16 @@ chn = sample(mdl, # ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa # ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 # ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b # ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee # ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a # ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 # ╠═d073e63b-62da-4743-ace0-78ef7806bc0b -# ╠═f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╠═31764672-3073-4280-8ab2-d42544be1629 +# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 # ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 # ╠═b4033728-b321-4100-8194-1fd9fe2d268d # ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 From 6513caf1e835ba52e3cf37fe42b20d9b8ef01bfc Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 12:43:58 +0000 Subject: [PATCH 16/40] Getting started example --- EpiAware/docs/src/examples/getting_started.jl | 172 ++++++++++++++---- 1 file changed, 136 insertions(+), 36 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 4706f36f1..200c20c27 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -26,6 +26,7 @@ begin using DynamicPPL using Statistics using DataFramesMeta + using LinearAlgebra end # ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 @@ -94,7 +95,7 @@ Z_0 &\sim \mathcal{N}(0,1),\\ # ╔═╡ 56ae496b-0094-460b-89cb-526627991717 rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.01), 0.0, 0.5)) + truncated(Normal(0.0, 0.02), 0.0, Inf)) # ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 md" @@ -118,8 +119,7 @@ The `EpiData` constructor performs double interval censoring to convert our _con " # ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -model_data = EpiData(truth_GI, - D_gen = 10.0) +model_data = EpiData(truth_GI, D_gen = 10.0) # ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc md" @@ -141,52 +141,81 @@ epi_model = DirectInfections(model_data, log_I0_prior) md" -### Observation model +### Delayed Observations `ObservationModel` -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. +The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. ```math \begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), -r &\sim \text{Gamma}(3, 0.05/3). +y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ +1 / r &\sim \text{Gamma}(3, 0.05/3). \end{align} ``` " +# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 +md" +We also set up the inference to occur over 100 days. +" + # ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 time_horizon = 100 +# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +md" +We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. +" + # ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b -obs_model = EpiAware.DelayObservations([1.0], +obs_model = EpiAware.DelayObservations( + fill(0.25, 4), time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) +) # ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b md" -## Generation +## Generate cases from the `EpiAware` model + +Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`. + +By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. " # ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d -full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) +full_epi_aware_mdl = make_epi_aware(missing, time_horizon; + epi_model = epi_model, + latent_model_model = rwp, + observation_model = obs_model) + +# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a +md" +We choose some fixed parameters: +- Initial incidence is 100. +- In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. +" # ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee -fixed_parameters = (init_incidence = log(100.0), σ²_RW = 0.1^2, init_rw = 0.0) +fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) + +# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 +md" +We fix these parameters using `fix`, and generate a random epidemic. +" # ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a -cond_toy = fix(full_epi_aware_mdl, fixed_parameters) +cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) # ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 -random_epidemic = rand(cond_toy) +random_epidemic = rand(cond_generative_model) # ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b -gen = generated_quantities(cond_toy, random_epidemic) +true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t # ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 let - plot(gen.I_t, + plot(true_infections, label = "I_t", xlabel = "Time", ylabel = "Infections", @@ -197,6 +226,11 @@ end # ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 md" ## Inference +Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. + +However, we now treat the generated data as `truth_data` and make inference without fixing any other parameters. + +We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. " # ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 @@ -205,23 +239,31 @@ truth_data = random_epidemic.y_t # ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d inference_mdl = fix( make_epi_aware(truth_data, time_horizon; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6), - (init_rw = 0.0,)) + latent_model_model = rwp, observation_model = obs_model), + (rw_init = 0.0,) +) # ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c chn = sample(inference_mdl, NUTS(; adtype = AutoReverseDiff(true)), MCMCThreads(), - 500, - 2; - n_warmup = 100, + 250, + 4; drop_warmup = true) +# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 +md" +### Predictive plotting + +We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). + +Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. +" + # ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 let - post_check_mdl = fix(full_epi_aware_mdl, (init_rw = 0.0,)) - post_check_y_t = mapreduce(hcat, generated_quantities(full_epi_aware_mdl, chn)) do gen + post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) + post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen gen.generated_y_t end @@ -235,27 +277,77 @@ let xlabel = "Time", ylabel = "Cases", title = "Post. predictive checking: cases", - ylims = (-0.5, maximum(truth_data) * 2.5), + ylims = (-0.5, maximum(truth_data) * 1.5), c = :green) p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p2, gen.I_t, + scatter!(p2, true_infections, lab = "Actual infections", xlabel = "Time", ylabel = "Unobserved Infections", title = "Post. predictions: infections", - ylims = (-0.5, maximum(gen.I_t) * 1.5), + ylims = (-0.5, maximum(true_infections) * 1.5), c = :red) - plot(p1, p2, layout = (2, 1)) + plot(p1, p2, + layout = (1, 2), + size = (700, 400)) end +# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +md" +As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. +" + # ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 let - var_samples = chn[:σ²_RW] |> vec - histogram(var_samples, bins = 50, norm = :pdf) - vline!([fixed_parameters.:σ²_RW]) - plot!(rwp.var_prior) + parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) + + plts = map(parameters_to_plot) do name + var_samples = chn[name] |> vec + histogram(var_samples, + bins = 50, + norm = :pdf, + lw = 0, + fillalpha = 0.5, + lab = "MCMC") + vline!([getfield(random_epidemic, name)], lab = "True value") + title!(string(name)) + end + plot(plts..., layout = (2, 1)) +end + +# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b +md" +## Reproductive number back-calculation + +As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this note. However, we can back-calculate the implied $\mathcal{R}(t)$ values, conditional on the specified generation interval being correct. + +Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. +" + +# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 +let + n = epi_model.data.len_gen_int + Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) + for t in (n + 1):length(true_infections)] + true_Rt = true_infections[(n + 1):end] ./ Rt_denom + + predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + _It = gen.I_t + _Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)]) + for t in (n + 1):length(_It)] + Rt = _It[(n + 1):end] ./ _Rt_denom + end + + plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "") + plot!(plt, (n + 1):time_horizon, true_Rt, + lab = "true Rt", + xlabel = "Time", + ylabel = "Rt", + title = "Post. predictions: reproductive number", + c = :red, + lw = 2) end # ╔═╡ Cell order: @@ -272,12 +364,16 @@ end # ╠═6639e66f-7725-4976-81b2-6472419d1a62 # ╟─df5e59f8-3185-4bed-9cca-7c266df17cec # ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 # ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f # ╠═448669bc-99f4-4823-b15e-fcc9040ba31b # ╟─e49713e8-4840-4083-8e3f-fc52d791be7b # ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a # ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 # ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a # ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 # ╠═d073e63b-62da-4743-ace0-78ef7806bc0b @@ -286,5 +382,9 @@ end # ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 # ╠═b4033728-b321-4100-8194-1fd9fe2d268d # ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 # ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 # ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╟─81efe8ca-b753-4a12-bafc-a887a999377b +# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 From 1c740db15d505e53408e47ba8a24f73380f37a56 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:31:17 +0000 Subject: [PATCH 17/40] Add generated Pluto notebooks to gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ffa29724b..95595a137 100644 --- a/.gitignore +++ b/.gitignore @@ -381,3 +381,6 @@ docs/site/ /Manifest.toml .DS_Store .vscode/settings.json + +#Ignore generated Pluto notebooks +EpiAware/docs/src/examples/*.md From a37c77ad15741ade59df76b0307ac9aa666ebb22 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:33:33 +0000 Subject: [PATCH 18/40] remove old version of example --- EpiAware/docs/src/man/getting-started.md | 193 ----------------------- 1 file changed, 193 deletions(-) delete mode 100644 EpiAware/docs/src/man/getting-started.md diff --git a/EpiAware/docs/src/man/getting-started.md b/EpiAware/docs/src/man/getting-started.md deleted file mode 100644 index 65b6b7a99..000000000 --- a/EpiAware/docs/src/man/getting-started.md +++ /dev/null @@ -1,193 +0,0 @@ -```@meta -EditURL = "../../../test/predictive_checking/toy_model_log_infs_RW.jl" -``` - -# Getting started - -This is a toy model for demonstrating current functionality of EpiAware package. - -## Generative Model without data - -### Latent Process - -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. - -_Unfixed parameters_: -- `σ²_RW`: The variance of the random walk process. Current defauly prior is -- `init_rw_value`: The initial value of the random walk process. -- `ϵ_t`: The random noise vector. - -```math -\begin{align} -X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ -X(0) &\sim \mathcal{N}(0, 1) \\ -\epsilon_t &\sim \mathcal{N}(0, 1) \\ -\sigma_{RW} &\sim \text{HalfNormal}(0.05). -\end{align} -``` - -### Log-Infections Model - -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. - -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. - -```math -\log I_t = \exp(X(t)). -``` - -### Observation model - -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. - -```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), -r &\sim \text{Gamma}(3, 0.05/3). -\end{align} -``` - -## Load dependencies - -This script should be run from Test environment mode. If not, run the following command: - - -````@example toy_model_log_infs_RW -using EpiAware -using Turing -using Distributions -using StatsPlots -using Random -using DynamicPPL -using Statistics -using DataFramesMeta -using CSV # For outputting the MCMC chain - -Random.seed!(0) -```` - -## Create an `EpiModel` struct - -- Medium length generation interval distribution. -- Median 2 day, std 4.3 day delay distribution. - -````@example toy_model_log_infs_RW -truth_GI = Gamma(2, 5) -model_data = EpiData(truth_GI, - D_gen = 10.0) - -log_I0_prior = Normal(0.0, 1.0) -epi_model = DirectInfections(model_data, log_I0_prior) -```` - -## Define the data generating process - -In this case we use the `DirectInfections` model. - -````@example toy_model_log_infs_RW -rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.01), 0.0, 0.5)) - -#Define the observation model - no delay model -time_horizon = 100 -obs_model = EpiAware.DelayObservations([1.0], - time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) -```` - -## Generate a `Turing` `Model` -We don't have observed data, so we use `missing` value for `y_t`. - -````@example toy_model_log_infs_RW -log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) -```` - -## Sample from the model -I define a fixed version of the model with initial infections set to 1 and variance of the random walk process set to 0.1. -We can sample from the model using the `rand` function, and plot the generated infections against generated cases. - -We can get the generated infections using `generated_quantities` function. Because the observed -cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled -process. - -````@example toy_model_log_infs_RW -cond_toy = fix(log_infs_model, (init = log(1.0), σ²_RW = 0.1)) -random_epidemic = rand(cond_toy) -gen = generated_quantities(cond_toy, random_epidemic) - -plot(gen.I_t, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") -scatter!(random_epidemic.y_t, lab = "generated cases") -```` - -## Inference - -We treat the generated data as observed data and attempt to infer underlying infections. - -````@example toy_model_log_infs_RW -truth_data = random_epidemic.y_t - -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) -@time chn = sample(model, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - drop_warmup = true) -```` - -## Postior predictive checking - -We check the posterior predictive checking by plotting the predicted cases against the observed cases. - -````@example toy_model_log_infs_RW -predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.generated_y_t -end - -plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") -scatter!(truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(truth_data) * 2.5)) -```` - -## Underlying inferred infections - -````@example toy_model_log_infs_RW -predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.I_t -end - -plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") -scatter!(gen.I_t, - lab = "Actual infections", - xlabel = "Time", - ylabel = "Unobserved Infections", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(gen.I_t) * 1.5)) -```` - -## Outputing the MCMC chain -We can use `spread_draws` to convert the MCMC chain into a tidybayes format. - -````@example toy_model_log_infs_RW -df_chn = spread_draws(chn) -save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv") -CSV.write(save_path, df_chn) -```` - ---- - -*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* From 66782b84f3347e1f5d0fe9044b328e8a830f060f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:33:51 +0000 Subject: [PATCH 19/40] add Pluto and PlutoStaticHTML as deps --- EpiAware/docs/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 4ec707df0..900ecfe73 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -6,6 +6,8 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" +PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" From 685361480ed86c84b0eb2ca8fbca0aab1291ec18 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:35:03 +0000 Subject: [PATCH 20/40] Adapt make and build files to generate Pluto notebooks statically and then render into docs --- EpiAware/docs/build.jl | 24 ++++++++++++++++++++++++ EpiAware/docs/make.jl | 10 +++++++++- EpiAware/docs/pages.jl | 2 +- 3 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 EpiAware/docs/build.jl diff --git a/EpiAware/docs/build.jl b/EpiAware/docs/build.jl new file mode 100644 index 000000000..21d1ddcb4 --- /dev/null +++ b/EpiAware/docs/build.jl @@ -0,0 +1,24 @@ + +"""Run all Pluto notebooks (".jl" files) in `tutorials_dir` and write outputs to HTML files.""" +function build(target_subdir; _module = EpiAware) + target_dir = joinpath(pkgdir(_module), "docs", "src", target_subdir) + + @info "Building notebooks in $target_subdir" + # Evaluate notebooks in the same process to avoid having to recompile from scratch each time. + # This is similar to how Documenter and Franklin evaluate code. + # Note that things like method overrides and other global changes may leak between notebooks! + use_distributed = false + output_format = documenter_output + bopts = BuildOptions(target_dir; use_distributed, output_format) + build_notebooks(bopts) + return nothing +end + +"Return Markdown file links which can be passed to Documenter.jl." +function markdown_files(notebook_titles, target_subdir) + md_files = map(notebook_titles) do title + file = lowercase(replace(title, " " => '_')) + return joinpath(target_subdir, "$file.md") + end + return md_files +end diff --git a/EpiAware/docs/make.jl b/EpiAware/docs/make.jl index 115806098..b25a02618 100644 --- a/EpiAware/docs/make.jl +++ b/EpiAware/docs/make.jl @@ -1,8 +1,13 @@ using Documenter using EpiAware +using Pluto: Configuration.CompilerOptions +using PlutoStaticHTML include("changelog.jl") include("pages.jl") +include("build.jl") + +# build("examples") makedocs(; sitename = "EpiAware.jl", authors = "Samuel Brand, Zachary Susswein, Sam Abbott, and contributors", @@ -11,7 +16,10 @@ makedocs(; sitename = "EpiAware.jl", modules = [EpiAware], pages = pages, format = Documenter.HTML( - prettyurls = get(ENV, "CI", nothing) == "true" + prettyurls = get(ENV, "CI", nothing) == "true", + mathengine = Documenter.MathJax3(), + size_threshold = 600 * 2^10, + size_threshold_warn = 200 * 2^10 ) ) diff --git a/EpiAware/docs/pages.jl b/EpiAware/docs/pages.jl index db333fa73..9d379ceb6 100644 --- a/EpiAware/docs/pages.jl +++ b/EpiAware/docs/pages.jl @@ -3,7 +3,7 @@ pages = [ "Manual" => Any[ "Guide" => "man/guide.md", "Examples" => [ - "Getting started" => "man/getting-started.md" + "Getting started" => "examples/getting_started.md" ] ], "Reference" => Any[ From 2fbe618f248735cf5878cd1c89adc4a700983abe Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:51:56 +0000 Subject: [PATCH 21/40] include build step for rendering Pluto notebooks --- EpiAware/docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/make.jl b/EpiAware/docs/make.jl index b25a02618..08c038d63 100644 --- a/EpiAware/docs/make.jl +++ b/EpiAware/docs/make.jl @@ -7,7 +7,7 @@ include("changelog.jl") include("pages.jl") include("build.jl") -# build("examples") +build("examples") makedocs(; sitename = "EpiAware.jl", authors = "Samuel Brand, Zachary Susswein, Sam Abbott, and contributors", From d97c92ac6075de4d577031a4f4703b1ea767e9b5 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Tue, 5 Mar 2024 13:52:11 +0000 Subject: [PATCH 22/40] remove old version of the getting started example --- .../toy_model_log_infs_RW.jl | 193 ------------------ 1 file changed, 193 deletions(-) delete mode 100644 EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl diff --git a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl b/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl deleted file mode 100644 index cce09e19a..000000000 --- a/EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl +++ /dev/null @@ -1,193 +0,0 @@ -#= -# Toy model for running analysis: - -This is a toy model for demonstrating current functionality of EpiAware package. - -## Generative Model without data - -### Latent Process - -The latent process is a random walk defined by a Turing model `random_walk` of specified length `n`. - -_Unfixed parameters_: -- `σ²_RW`: The variance of the random walk process. Current defauly prior is -- `init_rw_value`: The initial value of the random walk process. -- `ϵ_t`: The random noise vector. - -```math -\begin{align} -X(t) &= X(0) + \sigma_{RW} \sum_{t= 1}^n \epsilon_t \\ -X(0) &\sim \mathcal{N}(0, 1) \\ -\epsilon_t &\sim \mathcal{N}(0, 1) \\ -\sigma_{RW} &\sim \text{HalfNormal}(0.05). -\end{align} -``` - -### Log-Infections Model - -The log-infections model is defined by a Turing model `log_infections` that takes the observed data `y_t` (or `missing` value), -an `EpiModel` object `epi_model`, and a `latent_model` model. In this case the latent process is a random walk model. - -It also accepts optional arguments for the `process_priors`, `transform_function`, `pos_shift`, `neg_bin_cluster_factor`, and `neg_bin_cluster_factor_prior`. - -```math -\log I_t = \exp(X(t)). -``` - -### Observation model - -The observation model is a negative binomial distribution with mean `μ` and cluster factor `r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. The delay kernel is contained in an `EpiModel` struct. - -```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_s\geq 0 K[t, t-s] I(s), r), -r &\sim \text{Gamma}(3, 0.05/3). -\end{align} -``` - -## Load dependencies - -This script should be run from Test environment mode. If not, run the following command: - -```julia -using TestEnv # Run in Test environment mode -TestEnv.activate() -``` - -=# - -# using TestEnv # Run in Test environment mode -# TestEnv.activate() - -using EpiAware -using Turing -using Distributions -using StatsPlots -using Random -using DynamicPPL -using Statistics -using DataFramesMeta -using CSV # For outputting the MCMC chain - -Random.seed!(0) - -#= -## Create an `EpiModel` struct - -- Medium length generation interval distribution. -- Median 2 day, std 4.3 day delay distribution. -=# - -truth_GI = Gamma(2, 5) -model_data = EpiData(truth_GI, - D_gen = 10.0) - -log_I0_prior = Normal(0.0, 10.0) -epi_model = DirectInfections(model_data, log_I0_prior) - -#= -## Define the data generating process - -In this case we use the `DirectInfections` model. -=# - -rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.01), 0.0, 0.5)) - -#Define the observation model - no delay model -time_horizon = 100 -obs_model = EpiAware.DelayObservations([1.0], - time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)) - -#= -## Generate a `Turing` `Model` -We don't have observed data, so we use `missing` value for `y_t`. -=# - -log_infs_model = make_epi_aware(missing, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) - -#= -## Sample from the model -I define a fixed version of the model with initial infections set to 1 and variance of the random walk process set to 0.1. -We can sample from the model using the `rand` function, and plot the generated infections against generated cases. - -We can get the generated infections using `generated_quantities` function. Because the observed -cases are "defined" with a `~` operator they can be accessed directly from the randomly sampled -process. -=# - -cond_toy = fix(log_infs_model, (init = log(1.0), σ²_RW = 0.1)) -random_epidemic = rand(cond_toy) -gen = generated_quantities(cond_toy, random_epidemic) - -plot(gen.I_t, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") -scatter!(random_epidemic.y_t, lab = "generated cases") - -#= -## Inference - -We treat the generated data as observed data and attempt to infer underlying infections. -=# - -truth_data = random_epidemic.y_t - -model = make_epi_aware(truth_data, time_horizon, ; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model, - pos_shift = 1e-6) -@time chn = sample(model, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - drop_warmup = true) - -#= -## Postior predictive checking - -We check the posterior predictive checking by plotting the predicted cases against the observed cases. -=# - -predicted_y_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.generated_y_t -end - -plot(predicted_y_t, c = :grey, alpha = 0.05, lab = "") -scatter!(truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(truth_data) * 2.5)) - -#= -## Underlying inferred infections -=# - -predicted_I_t = mapreduce(hcat, generated_quantities(log_infs_model, chn)) do gen - gen.I_t -end - -plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") -scatter!(gen.I_t, - lab = "Actual infections", - xlabel = "Time", - ylabel = "Unobserved Infections", - title = "Posterior Predictive Checking", - ylims = (-0.5, maximum(gen.I_t) * 1.5)) - -#= -## Outputing the MCMC chain -We can use `spread_draws` to convert the MCMC chain into a tidybayes format. -=# - -df_chn = spread_draws(chn) -save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv") -CSV.write(save_path, df_chn) From e805792302bafa254a930c548c1051faea80caee Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Wed, 6 Mar 2024 11:43:53 +0000 Subject: [PATCH 23/40] Update EpiAware/docs/src/examples/getting_started.jl --- EpiAware/docs/src/examples/getting_started.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 200c20c27..4c4bff630 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -186,7 +186,7 @@ By giving `missing` to the first argument, we indicate that case data will be _g # ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, - latent_model_model = rwp, + latent_model = rwp, observation_model = obs_model) # ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a From 14b5d30f229c3abbd3a4ceb3ff45e160dff50e04 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 11:59:44 +0000 Subject: [PATCH 24/40] Remove unnecessary `CSV` dep in test --- EpiAware/test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml index 0271e5375..e927523a8 100644 --- a/EpiAware/test/Project.toml +++ b/EpiAware/test/Project.toml @@ -1,6 +1,5 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" From f825d381128f7e5af66df7f20b96ab107c135fa2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 12:06:36 +0000 Subject: [PATCH 25/40] fix getting started --- EpiAware/docs/src/examples/getting_started.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 4c4bff630..dafc99039 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -239,7 +239,7 @@ truth_data = random_epidemic.y_t # ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d inference_mdl = fix( make_epi_aware(truth_data, time_horizon; epi_model = epi_model, - latent_model_model = rwp, observation_model = obs_model), + latent_model = rwp, observation_model = obs_model), (rw_init = 0.0,) ) From ac788d04a0494e87fd66ccd8c4943437aaba1276 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 15:18:24 +0000 Subject: [PATCH 26/40] pathfinder first pass --- EpiAware/Project.toml | 1 + EpiAware/docs/Project.toml | 2 + .../examples/getting_started_pf backup 1.jl | 415 +++++++++++++++++ .../docs/src/examples/getting_started_pf.jl | 437 ++++++++++++++++++ EpiAware/src/inference-methods.jl | 19 + EpiAware/src/models.jl | 35 ++ 6 files changed, 909 insertions(+) create mode 100644 EpiAware/docs/src/examples/getting_started_pf backup 1.jl create mode 100644 EpiAware/docs/src/examples/getting_started_pf.jl create mode 100644 EpiAware/src/inference-methods.jl diff --git a/EpiAware/Project.toml b/EpiAware/Project.toml index cb35f98bc..59b673aaf 100644 --- a/EpiAware/Project.toml +++ b/EpiAware/Project.toml @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/EpiAware/docs/Project.toml b/EpiAware/docs/Project.toml index 900ecfe73..34efc468c 100644 --- a/EpiAware/docs/Project.toml +++ b/EpiAware/docs/Project.toml @@ -6,8 +6,10 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/EpiAware/docs/src/examples/getting_started_pf backup 1.jl b/EpiAware/docs/src/examples/getting_started_pf backup 1.jl new file mode 100644 index 000000000..78a7db0c0 --- /dev/null +++ b/EpiAware/docs/src/examples/getting_started_pf backup 1.jl @@ -0,0 +1,415 @@ +### A Pluto.jl notebook ### +# v0.19.40 + +using Markdown +using InteractiveUtils + +# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# hideall +let + docs_dir = dirname(dirname(@__DIR__)) + pkg_dir = dirname(docs_dir) + + using Pkg: Pkg + Pkg.activate(docs_dir) + Pkg.develop(; path = pkg_dir) + Pkg.resolve() + Pkg.instantiate() +end; + +# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f +begin + using EpiAware + using Turing + using Distributions + using StatsPlots + using Random + using DynamicPPL + using Statistics + using DataFramesMeta + using LinearAlgebra + using Pathfinder +end + +# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 +md" +# Getting stated with `EpiAware` + +This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. + +## `EpiAware` models +The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. + +#### Mathematical definition +```math +\begin{align} +Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\ +I_0 &\sim f_0(\theta_I), \\ +I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\ +y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}). +\end{align} +``` +Where, $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$. $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections. $g_I$ is distribution of new infections conditional on infections and latent process in the past. Note that we assume that new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step. + +#### Code structure outline + +An `EpiAware` model in code is created from three modular components: + +- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$. +- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$. +- An `ObservationModel`: This defines the observation model. This chooses $f_O({I_s}_{s \leq t}, \theta_{O})$ + +#### Reproductive number +`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$, however, this is often a quantity of interest. When not directly used we will typically back-calculate $\mathcal{R}_t$ from the generated infections: + +```math +\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }. +``` + +Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. +" + +# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 +md" +## Random walk `LatentModel` + +As an example, we choose the latent process as a random walk with parameters $\theta_Z$: + +- ``Z_0``: Initial position. +- ``\sigma^2_{Z}``: The step-size variance. + +Conditional on the parameters the random walk is then generated by white noise: +```math +\begin{align} +Z_t &= Z_0 + \sigma_{RW} \sum_{t= 1}^T \epsilon_t, \\ +\epsilon_t &\sim \mathcal{N}(0,1). +\end{align} +``` + +In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors, +```math +\begin{align} +Z_0 &\sim \mathcal{N}(0,1),\\ +\sigma^2_Z &\sim \text{HalfNormal}(0.01). +\end{align} +``` +" + +# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 +rwp = EpiAware.RandomWalk(Normal(), + truncated(Normal(0.0, 0.02), 0.0, Inf)) + +# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +md" +## Direct infection `EpiModel` + +This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$. +```math +\log I_t = \log I_0 + Z_t. +``` + +As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$. + +" + +# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 +truth_GI = Gamma(2, 5) + +# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 +md" +The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. +" + +# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +model_data = EpiData(truth_GI, D_gen = 10.0) + +# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc +md" +We can supply a prior for the initial log_infections. +" + +# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 +log_I0_prior = Normal(log(100.0), 1.0) + +# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec +md" +And construct the `EpiModel`. +" + +# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 +epi_model = DirectInfections(model_data, log_I0_prior) + +# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa +md" + + +### Delayed Observations `ObservationModel` + +The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. + +```math +\begin{align} +y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ +1 / r &\sim \text{Gamma}(3, 0.05/3). +\end{align} +``` +" + +# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 +md" +We also set up the inference to occur over 100 days. +" + +# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +time_horizon = 100 + +# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +md" +We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. +" + +# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b +obs_model = EpiAware.DelayObservations( + fill(0.25, 4), + time_horizon, + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) +) + +# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b +md" +## Generate cases from the `EpiAware` model + +Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`. + +By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. +" + +# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d +full_epi_aware_mdl = make_epi_aware(missing, time_horizon; + epi_model = epi_model, + latent_model = rwp, + observation_model = obs_model) + +# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a +md" +We choose some fixed parameters: +- Initial incidence is 100. +- In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. +" + +# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee +fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) + +# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 +md" +We fix these parameters using `fix`, and generate a random epidemic. +" + +# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a +cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) + +# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 +random_epidemic = rand(cond_generative_model) + +# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b +true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t + +# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +let + plot(true_infections, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") + scatter!(random_epidemic.y_t, lab = "generated cases") +end + +# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +md" +## Inference +Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. + +However, we now treat the generated data as `truth_data` and make inference without fixing any other parameters. + +We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. +" + +# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 +truth_data = random_epidemic.y_t + +# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d +inference_mdl = fix( + make_epi_aware(truth_data, time_horizon; epi_model = epi_model, + latent_model = rwp, observation_model = obs_model), + (rw_init = 0.0,) +) + +# ╔═╡ 619b6a45-b202-47f3-97cf-1ab487278674 +safe_inference_mdl = fix( + make_epi_aware(truth_data, time_horizon, Val(:safe_mode); epi_model = epi_model, + latent_model = rwp, observation_model = obs_model), + (rw_init = 0.0,) +) + +# ╔═╡ 7da045fc-49ba-4a9a-8e5a-f9edb6418f7a +mpf = multipathfinder(safe_inference_mdl, 50; nruns = 4) + +# ╔═╡ 3224f468-63d4-483a-ac65-1a584d23d92c +init_params = collect.(eachrow(mpf.draws_transformed.value[1:4, :, 1])) + +# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +chn = sample(inference_mdl, + NUTS(; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + init_params = init_params, + drop_warmup = true) + +# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 +md" +### Predictive plotting + +We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). + +Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. +" + +# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +#=╠═╡ +let + post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) + post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen + gen.generated_y_t + end + + predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + gen.I_t + end + + p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p1, truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Post. predictive checking: cases", + ylims = (-0.5, maximum(truth_data) * 1.5), + c = :green) + + p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p2, true_infections, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Unobserved Infections", + title = "Post. predictions: infections", + ylims = (-0.5, maximum(true_infections) * 1.5), + c = :red) + + plot(p1, p2, + layout = (1, 2), + size = (700, 400)) +end + ╠═╡ =# + +# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +md" +As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. +" + +# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 +#=╠═╡ +let + parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) + + plts = map(parameters_to_plot) do name + var_samples = chn[name] |> vec + histogram(var_samples, + bins = 50, + norm = :pdf, + lw = 0, + fillalpha = 0.5, + lab = "MCMC") + vline!([getfield(random_epidemic, name)], lab = "True value") + title!(string(name)) + end + plot(plts..., layout = (2, 1)) +end + ╠═╡ =# + +# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b +md" +## Reproductive number back-calculation + +As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this note. However, we can back-calculate the implied $\mathcal{R}(t)$ values, conditional on the specified generation interval being correct. + +Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. +" + +# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 +#=╠═╡ +let + n = epi_model.data.len_gen_int + Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) + for t in (n + 1):length(true_infections)] + true_Rt = true_infections[(n + 1):end] ./ Rt_denom + + predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + _It = gen.I_t + _Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)]) + for t in (n + 1):length(_It)] + Rt = _It[(n + 1):end] ./ _Rt_denom + end + + plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "") + plot!(plt, (n + 1):time_horizon, true_Rt, + lab = "true Rt", + xlabel = "Time", + ylabel = "Rt", + title = "Post. predictions: reproductive number", + c = :red, + lw = 2) +end + ╠═╡ =# + +# ╔═╡ Cell order: +# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╠═56ae496b-0094-460b-89cb-526627991717 +# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╠═6639e66f-7725-4976-81b2-6472419d1a62 +# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec +# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 +# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b +# ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a +# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 +# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b +# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 +# ╠═b4033728-b321-4100-8194-1fd9fe2d268d +# ╠═619b6a45-b202-47f3-97cf-1ab487278674 +# ╠═7da045fc-49ba-4a9a-8e5a-f9edb6418f7a +# ╠═3224f468-63d4-483a-ac65-1a584d23d92c +# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 +# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╟─81efe8ca-b753-4a12-bafc-a887a999377b +# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 diff --git a/EpiAware/docs/src/examples/getting_started_pf.jl b/EpiAware/docs/src/examples/getting_started_pf.jl new file mode 100644 index 000000000..e97fb72fd --- /dev/null +++ b/EpiAware/docs/src/examples/getting_started_pf.jl @@ -0,0 +1,437 @@ +### A Pluto.jl notebook ### +# v0.19.40 + +using Markdown +using InteractiveUtils + +# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# hideall +let + docs_dir = dirname(dirname(@__DIR__)) + pkg_dir = dirname(docs_dir) + + using Pkg: Pkg + Pkg.activate(docs_dir) + Pkg.develop(; path = pkg_dir) + Pkg.resolve() + Pkg.instantiate() +end; + +# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f +begin + using EpiAware + using Turing + using Distributions + using StatsPlots + using Random + using DynamicPPL + using Statistics + using DataFramesMeta + using LinearAlgebra + using Pathfinder + using Transducers + + Random.seed!(1) +end + +# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 +md" +# Getting stated with `EpiAware` + +This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. + +## `EpiAware` models +The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. + +#### Mathematical definition +```math +\begin{align} +Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\ +I_0 &\sim f_0(\theta_I), \\ +I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\ +y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}). +\end{align} +``` +Where, $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$. $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections. $g_I$ is distribution of new infections conditional on infections and latent process in the past. Note that we assume that new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step. + +#### Code structure outline + +An `EpiAware` model in code is created from three modular components: + +- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$. +- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$. +- An `ObservationModel`: This defines the observation model. This chooses $f_O({I_s}_{s \leq t}, \theta_{O})$ + +#### Reproductive number +`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$, however, this is often a quantity of interest. When not directly used we will typically back-calculate $\mathcal{R}_t$ from the generated infections: + +```math +\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }. +``` + +Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. +" + +# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 +md" +## Random walk `LatentModel` + +As an example, we choose the latent process as a random walk with parameters $\theta_Z$: + +- ``Z_0``: Initial position. +- ``\sigma^2_{Z}``: The step-size variance. + +Conditional on the parameters the random walk is then generated by white noise: +```math +\begin{align} +Z_t &= Z_0 + \sigma_{RW} \sum_{t= 1}^T \epsilon_t, \\ +\epsilon_t &\sim \mathcal{N}(0,1). +\end{align} +``` + +In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors, +```math +\begin{align} +Z_0 &\sim \mathcal{N}(0,1),\\ +\sigma^2_Z &\sim \text{HalfNormal}(0.01). +\end{align} +``` +" + +# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 +rwp = EpiAware.RandomWalk(Normal(), + truncated(Normal(0.0, 0.02), 0.0, Inf)) + +# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +md" +## Direct infection `EpiModel` + +This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$. +```math +\log I_t = \log I_0 + Z_t. +``` + +As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$. + +" + +# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 +truth_GI = Gamma(2, 5) + +# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 +md" +The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. +" + +# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +model_data = EpiData(truth_GI, D_gen = 10.0) + +# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc +md" +We can supply a prior for the initial log_infections. +" + +# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 +log_I0_prior = Normal(log(100.0), 1.0) + +# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec +md" +And construct the `EpiModel`. +" + +# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 +epi_model = DirectInfections(model_data, log_I0_prior) + +# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa +md" + + +### Delayed Observations `ObservationModel` + +The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented +as the action of a sparse kernel on the infections $I(t)$. + +```math +\begin{align} +y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ +1 / r &\sim \text{Gamma}(3, 0.05/3). +\end{align} +``` +" + +# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 +md" +We also set up the inference to occur over 100 days. +" + +# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +time_horizon = 100 + +# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +md" +We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. +" + +# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b +obs_model = EpiAware.DelayObservations( + fill(0.25, 4), + time_horizon, + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) +) + +# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b +md" +## Generate cases from the `EpiAware` model + +Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`. + +By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. +" + +# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d +full_epi_aware_mdl = make_epi_aware(missing, time_horizon; + epi_model = epi_model, + latent_model = rwp, + observation_model = obs_model) + +# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a +md" +We choose some fixed parameters: +- Initial incidence is 100. +- In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. +" + +# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee +fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) + +# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 +md" +We fix these parameters using `fix`, and generate a random epidemic. +" + +# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a +cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) + +# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 +random_epidemic = rand(cond_generative_model) + +# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b +true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t + +# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +let + plot(true_infections, + label = "I_t", + xlabel = "Time", + ylabel = "Infections", + title = "Generated Infections") + scatter!(random_epidemic.y_t, lab = "generated cases") +end + +# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +md" +## Inference +Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. + +However, we now treat the generated data as `truth_data` and make inference without fixing any other parameters. + +We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. +" + +# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 +truth_data = random_epidemic.y_t + +# ╔═╡ bd4e3762-a711-43b7-8dbe-d5a1f690f021 +function _epi_aware(y_t, + time_steps; + epi_model::AbstractEpiModel, + latent_model::AbstractLatentModel, + observation_model::AbstractObservationModel, + nsamples, + nchains, + pf_ndraws = 50, + pf_nruns = 4, + fixed_parameters = (;), + pos_shift = 1e-6, + executor = Transducers.ThreadedEx(), + adtype = AutoReverseDiff(true), + maxiters = 10, + kwargs...) + gen_mdl = make_epi_aware(missing, time_horizon; epi_model, + latent_model, observation_model, pos_shift) |> + mdl -> fix(mdl, fixed_parameters) + + mdl = make_epi_aware(y_t, time_horizon; epi_model, + latent_model, observation_model, pos_shift) |> + mdl -> fix(mdl, fixed_parameters) + + safe_mdl = make_epi_aware(y_t, time_horizon, Val(:safe_mode); epi_model, + latent_model, observation_model, pos_shift) |> + mdl -> fix(mdl, fixed_parameters) + + mpf = multipathfinder(safe_mdl, max(pf_ndraws, nchains); + nruns = pf_nruns, + executor, + maxiters, + kwargs...) + + init_params = collect.(eachrow(mpf.draws_transformed.value[1:nchains, :, 1])) + + chn = sample(mdl, + NUTS(; adtype), + MCMCThreads(), + nsamples ÷ nchains, + nchains; + init_params = init_params, + drop_warmup = true) + + return chn, (; pathfinder_res = mpf, + inference_mdl = mdl, + generative_mdl = gen_mdl) +end + +# ╔═╡ d0ed77f0-b27b-49af-a682-5ad567fe2d45 +@time chn, epi_mdls = _epi_aware( + truth_data, time_horizon; epi_model = epi_model, latent_model = rwp, + observation_model = obs_model, nsamples = 1000, nchains = 4, + fixed_parameters = (rw_init = 0.0,)) + +# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 +md" +### Predictive plotting + +We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). + +Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. +" + +# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +let + post_check_mdl = epi_mdls.generative_mdl + inference_mdl = epi_mdls.inference_mdl + post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen + gen.generated_y_t + end + + predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + gen.I_t + end + + p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p1, truth_data, + lab = "Observed cases", + xlabel = "Time", + ylabel = "Cases", + title = "Post. predictive checking: cases", + ylims = (-0.5, maximum(truth_data) * 1.5), + c = :green) + + p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") + scatter!(p2, true_infections, + lab = "Actual infections", + xlabel = "Time", + ylabel = "Unobserved Infections", + title = "Post. predictions: infections", + ylims = (-0.5, maximum(true_infections) * 1.5), + c = :red) + + plot(p1, p2, + layout = (1, 2), + size = (700, 400)) +end + +# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +md" +As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. +" + +# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 +let + parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) + + plts = map(parameters_to_plot) do name + var_samples = chn[name] |> vec + histogram(var_samples, + bins = 50, + norm = :pdf, + lw = 0, + fillalpha = 0.5, + lab = "MCMC") + vline!([getfield(random_epidemic, name)], lab = "True value") + title!(string(name)) + end + plot(plts..., layout = (2, 1)) +end + +# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b +md" +## Reproductive number back-calculation + +As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this note. However, we can back-calculate the implied $\mathcal{R}(t)$ values, conditional on the specified generation interval being correct. + +Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. +" + +# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 +let + inference_mdl = epi_mdls.inference_mdl + n = epi_model.data.len_gen_int + Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) + for t in (n + 1):length(true_infections)] + true_Rt = true_infections[(n + 1):end] ./ Rt_denom + + predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen + _It = gen.I_t + _Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)]) + for t in (n + 1):length(_It)] + Rt = _It[(n + 1):end] ./ _Rt_denom + end + + plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "") + plot!(plt, (n + 1):time_horizon, true_Rt, + lab = "true Rt", + xlabel = "Time", + ylabel = "Rt", + title = "Post. predictions: reproductive number", + c = :red, + lw = 2) +end + +# ╔═╡ Cell order: +# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╠═56ae496b-0094-460b-89cb-526627991717 +# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╠═6639e66f-7725-4976-81b2-6472419d1a62 +# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec +# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 +# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b +# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b +# ╠═abeff860-58c3-4644-9325-66ffd4446b6d +# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a +# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 +# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b +# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 +# ╠═bd4e3762-a711-43b7-8dbe-d5a1f690f021 +# ╠═d0ed77f0-b27b-49af-a682-5ad567fe2d45 +# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 +# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╟─81efe8ca-b753-4a12-bafc-a887a999377b +# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl new file mode 100644 index 000000000..e3accbce3 --- /dev/null +++ b/EpiAware/src/inference-methods.jl @@ -0,0 +1,19 @@ +@model function safe_mode_model(model, + y_t, + time_steps; + epi_model::AbstractEpiModel, + latent_model::AbstractLatentModel, + observation_model::AbstractObservationModel, + pos_shift = 1e-6) + try + @submodel model(y_t, + time_steps; + epi_model = epi_model, + latent_model = latent_model, + observation_model = observation_model, + pos_shift = pos_shift) + catch + Turing.@addlogprob! -Inf + return + end +end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index a4ea306a5..c85ff6b7c 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -25,3 +25,38 @@ latent_model, process_aux = merge(latent_model_aux, generated_y_t_aux)) end + +@model function make_epi_aware(y_t, + time_steps, + ::Val{:safe_mode}; + epi_model::AbstractEpiModel, + latent_model::AbstractLatentModel, + observation_model::AbstractObservationModel, + pos_shift = 1e-6) + try + #Latent process + @submodel Z_t, latent_model_aux = generate_latent( + latent_model, + time_steps) + + #Transform into infections + @submodel I_t = generate_latent_infs(epi_model, Z_t) + + #Predictive distribution of ascerted cases + @submodel generated_y_t, generated_y_t_aux = generate_observations( + observation_model, + y_t, + I_t; + pos_shift = pos_shift) + + #Generate quantities + return (; + generated_y_t, + I_t, + latent_model, + process_aux = merge(latent_model_aux, generated_y_t_aux)) + catch + Turing.@addlogprob! -Inf + return + end +end From f6b892628cb495813eefcf7557e5942bac44b210 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 16:15:47 +0000 Subject: [PATCH 27/40] add deps which are required for `_epi_aware` --- EpiAware/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/EpiAware/Project.toml b/EpiAware/Project.toml index 59b673aaf..c6cce63ce 100644 --- a/EpiAware/Project.toml +++ b/EpiAware/Project.toml @@ -7,6 +7,7 @@ version = "0.1.0-DEV" DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Optim = "429524aa-4258-5aef-a3af-852621145aeb" @@ -16,6 +17,7 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] From 2586edcd9c61c38590a33775558769b9230196e3 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 16:17:42 +0000 Subject: [PATCH 28/40] add `_epi_aware` function --- EpiAware/docs/src/examples/getting_started.jl | 34 +++++------ EpiAware/src/EpiAware.jl | 3 +- EpiAware/src/inference-methods.jl | 60 ++++++++++++++----- 3 files changed, 64 insertions(+), 33 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index dafc99039..e2949d3f7 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -1,3 +1,4 @@ + ### A Pluto.jl notebook ### # v0.19.40 @@ -13,6 +14,7 @@ let using Pkg: Pkg Pkg.activate(docs_dir) Pkg.develop(; path = pkg_dir) + Pkg.resolve() Pkg.instantiate() end; @@ -27,6 +29,10 @@ begin using Statistics using DataFramesMeta using LinearAlgebra + using Pathfinder + using Transducers + + Random.seed!(1234) end # ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 @@ -236,20 +242,11 @@ We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default wa # ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 truth_data = random_epidemic.y_t -# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d -inference_mdl = fix( - make_epi_aware(truth_data, time_horizon; epi_model = epi_model, - latent_model = rwp, observation_model = obs_model), - (rw_init = 0.0,) -) - -# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -chn = sample(inference_mdl, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - drop_warmup = true) +# ╔═╡ d0ed77f0-b27b-49af-a682-5ad567fe2d45 +@time chn, epi_mdls = EpiAware._epi_aware( + truth_data, time_horizon; epi_model = epi_model, latent_model = rwp, + observation_model = obs_model, nsamples = 1000, nchains = 4, + fixed_parameters = (rw_init = 0.0,)) # ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 md" @@ -262,7 +259,8 @@ Because we are using synthetic data we can also plot the model predictions for t # ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 let - post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) + post_check_mdl = epi_mdls.generative_mdl + inference_mdl = epi_mdls.inference_mdl post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen gen.generated_y_t end @@ -328,6 +326,7 @@ Here we spaghetti plot posterior sampled time-varying reproductive numbers again # ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 let + inference_mdl = epi_mdls.inference_mdl n = epi_model.data.len_gen_int Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) for t in (n + 1):length(true_infections)] @@ -351,7 +350,7 @@ let end # ╔═╡ Cell order: -# ╟─c593a2a0-d7f5-11ee-0931-d9f65ae84a72 +# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72 # ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 # ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f # ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 @@ -380,8 +379,7 @@ end # ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 # ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 # ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 -# ╠═b4033728-b321-4100-8194-1fd9fe2d268d -# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c +# ╠═d0ed77f0-b27b-49af-a682-5ad567fe2d45 # ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 # ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 # ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index a33275686..3443d7f75 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -34,7 +34,7 @@ module EpiAware using Distributions, Turing, LogExpFunctions, LinearAlgebra, SparseArrays, Random, ReverseDiff, Optim, Parameters, QuadGK, DataFramesMeta, - DocStringExtensions + DocStringExtensions, Pathfinder, DynamicPPL, Transducers # Exported abstract types export AbstractModel, AbstractEpiModel, AbstractLatentModel, @@ -61,5 +61,6 @@ include("utilities.jl") include("latent-models.jl") include("observation-models.jl") include("models.jl") +include("inference-methods.jl") end diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index e3accbce3..6a783ad1d 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -1,19 +1,51 @@ -@model function safe_mode_model(model, - y_t, +""" + +Do inference on an EpiAware model. +""" +function _epi_aware(y_t, time_steps; epi_model::AbstractEpiModel, latent_model::AbstractLatentModel, observation_model::AbstractObservationModel, - pos_shift = 1e-6) - try - @submodel model(y_t, - time_steps; - epi_model = epi_model, - latent_model = latent_model, - observation_model = observation_model, - pos_shift = pos_shift) - catch - Turing.@addlogprob! -Inf - return - end + nsamples, + nchains, + pf_ndraws = 10, + pf_nruns = 10, + fixed_parameters = (;), + pos_shift = 1e-6, + executor = Transducers.ThreadedEx(), + adtype = AutoReverseDiff(true), + maxiters = 10, + kwargs...) + gen_mdl = make_epi_aware(missing, time_steps; epi_model, + latent_model, observation_model, pos_shift) |> + mdl -> fix(mdl, fixed_parameters) + + mdl = make_epi_aware(y_t, time_steps; epi_model, + latent_model, observation_model, pos_shift) |> + mdl -> fix(mdl, fixed_parameters) + + safe_mdl = make_epi_aware(y_t, time_steps, Val(:safe_mode); epi_model, + latent_model, observation_model, pos_shift) |> + mdl -> fix(mdl, fixed_parameters) + + mpf = multipathfinder(safe_mdl, max(pf_ndraws, nchains); + nruns = pf_nruns, + executor, + maxiters, + kwargs...) + + init_params = collect.(eachrow(mpf.draws_transformed.value[(end - nchains):end, :, 1])) + + chn = sample(mdl, + NUTS(; adtype), + MCMCThreads(), + nsamples ÷ nchains, + nchains; + init_params = init_params, + drop_warmup = true) + + return chn, (; pathfinder_res = mpf, + inference_mdl = mdl, + generative_mdl = gen_mdl) end From c21df4ce17bcbc12566a89f223141463d9cad6e2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 16:18:02 +0000 Subject: [PATCH 29/40] Add _epi_aware unit test --- EpiAware/test/Project.toml | 1 + EpiAware/test/test_inference-methods.jl | 55 +++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 EpiAware/test/test_inference-methods.jl diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml index e927523a8..eb8752d5f 100644 --- a/EpiAware/test/Project.toml +++ b/EpiAware/test/Project.toml @@ -11,4 +11,5 @@ StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" +Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl new file mode 100644 index 000000000..73a3c8148 --- /dev/null +++ b/EpiAware/test/test_inference-methods.jl @@ -0,0 +1,55 @@ +@testitem "Testing _epi_aware function" begin + using Transducers, Turing + time_steps = 20 + + y_t = fill(10, time_steps) + nsamples = 10 + nchains = 2 + pf_ndraws = 10 + pf_nruns = 10 + fixed_parameters = (;) + pos_shift = 1e-6 + executor = Transducers.ThreadedEx() + adtype = AutoReverseDiff(true) + maxiters = 10 + + #Define the epi_model + epi_model = DirectInfections(EpiData([0.2, 0.3, 0.5], exp), Normal()) + + #Define the latent process model + rwp = EpiAware.RandomWalk(Normal(0.0, 1.0), + truncated(Normal(0.0, 0.05), 0.0, Inf)) + + #Define the observation model + delay_distribution = Gamma(2.0, 5 / 2) + time_horizon = time_steps + D_delay = 14.0 + Δd = 1.0 + + obs_model = EpiAware.DelayObservations(delay_distribution = delay_distribution, + time_horizon = time_horizon, + neg_bin_cluster_factor_prior = Gamma(5, 0.05 / 5), + D_delay = D_delay, + Δd = Δd) + + # Call the _epi_aware function to check this runs + chn, results = EpiAware._epi_aware(y_t, time_steps; + epi_model = epi_model, + latent_model = rwp, + observation_model = obs_model, + nsamples = nsamples, + nchains = nchains, + pf_ndraws = pf_ndraws, + pf_nruns = pf_nruns, + fixed_parameters = fixed_parameters, + pos_shift = pos_shift, + executor = executor, + adtype = adtype, + maxiters = maxiters) + + # Perform assertions to check the correctness of the results + @test size(chn, 1) == nsamples ÷ nchains + @test haskey(results, :pathfinder_res) + @test haskey(results, :inference_mdl) + @test haskey(results, :generative_mdl) +end From 30173c6f289b24f56324969d78441e637472ebc3 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 16:21:45 +0000 Subject: [PATCH 30/40] add project dep compat --- EpiAware/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/EpiAware/Project.toml b/EpiAware/Project.toml index c6cce63ce..33006ea15 100644 --- a/EpiAware/Project.toml +++ b/EpiAware/Project.toml @@ -24,13 +24,16 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" DataFramesMeta = "0.14" Distributions = "0.25" DocStringExtensions = "0.9" +DynamicPPL = "0.24" LinearAlgebra = "1.9" LogExpFunctions = "0.3" Optim = "1.9" Parameters = "0.12" +Pathfinder = "0.8" QuadGK = "2.9" Random = "1.9" ReverseDiff = "1.15" SparseArrays = "1.10" +Transducers = "0.4" Turing = "0.30" julia = "1.10" From dbb655f4d724a5cb7bc46be31274c78f885e1616 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 16:28:54 +0000 Subject: [PATCH 31/40] remove old notebooks so doesn't waste time trying to render --- .../examples/getting_started_pf backup 1.jl | 415 ----------------- .../docs/src/examples/getting_started_pf.jl | 437 ------------------ 2 files changed, 852 deletions(-) delete mode 100644 EpiAware/docs/src/examples/getting_started_pf backup 1.jl delete mode 100644 EpiAware/docs/src/examples/getting_started_pf.jl diff --git a/EpiAware/docs/src/examples/getting_started_pf backup 1.jl b/EpiAware/docs/src/examples/getting_started_pf backup 1.jl deleted file mode 100644 index 78a7db0c0..000000000 --- a/EpiAware/docs/src/examples/getting_started_pf backup 1.jl +++ /dev/null @@ -1,415 +0,0 @@ -### A Pluto.jl notebook ### -# v0.19.40 - -using Markdown -using InteractiveUtils - -# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# hideall -let - docs_dir = dirname(dirname(@__DIR__)) - pkg_dir = dirname(docs_dir) - - using Pkg: Pkg - Pkg.activate(docs_dir) - Pkg.develop(; path = pkg_dir) - Pkg.resolve() - Pkg.instantiate() -end; - -# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f -begin - using EpiAware - using Turing - using Distributions - using StatsPlots - using Random - using DynamicPPL - using Statistics - using DataFramesMeta - using LinearAlgebra - using Pathfinder -end - -# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 -md" -# Getting stated with `EpiAware` - -This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. - -## `EpiAware` models -The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. - -#### Mathematical definition -```math -\begin{align} -Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\ -I_0 &\sim f_0(\theta_I), \\ -I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\ -y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}). -\end{align} -``` -Where, $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$. $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections. $g_I$ is distribution of new infections conditional on infections and latent process in the past. Note that we assume that new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step. - -#### Code structure outline - -An `EpiAware` model in code is created from three modular components: - -- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$. -- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$. -- An `ObservationModel`: This defines the observation model. This chooses $f_O({I_s}_{s \leq t}, \theta_{O})$ - -#### Reproductive number -`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$, however, this is often a quantity of interest. When not directly used we will typically back-calculate $\mathcal{R}_t$ from the generated infections: - -```math -\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }. -``` - -Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. -" - -# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 -md" -## Random walk `LatentModel` - -As an example, we choose the latent process as a random walk with parameters $\theta_Z$: - -- ``Z_0``: Initial position. -- ``\sigma^2_{Z}``: The step-size variance. - -Conditional on the parameters the random walk is then generated by white noise: -```math -\begin{align} -Z_t &= Z_0 + \sigma_{RW} \sum_{t= 1}^T \epsilon_t, \\ -\epsilon_t &\sim \mathcal{N}(0,1). -\end{align} -``` - -In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors, -```math -\begin{align} -Z_0 &\sim \mathcal{N}(0,1),\\ -\sigma^2_Z &\sim \text{HalfNormal}(0.01). -\end{align} -``` -" - -# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 -rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.02), 0.0, Inf)) - -# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 -md" -## Direct infection `EpiModel` - -This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$. -```math -\log I_t = \log I_0 + Z_t. -``` - -As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$. - -" - -# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 -truth_GI = Gamma(2, 5) - -# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 -md" -The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. -" - -# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -model_data = EpiData(truth_GI, D_gen = 10.0) - -# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc -md" -We can supply a prior for the initial log_infections. -" - -# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 -log_I0_prior = Normal(log(100.0), 1.0) - -# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec -md" -And construct the `EpiModel`. -" - -# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 -epi_model = DirectInfections(model_data, log_I0_prior) - -# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa -md" - - -### Delayed Observations `ObservationModel` - -The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. - -```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ -1 / r &\sim \text{Gamma}(3, 0.05/3). -\end{align} -``` -" - -# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 -md" -We also set up the inference to occur over 100 days. -" - -# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 -time_horizon = 100 - -# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f -md" -We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. -" - -# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b -obs_model = EpiAware.DelayObservations( - fill(0.25, 4), - time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) -) - -# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b -md" -## Generate cases from the `EpiAware` model - -Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`. - -By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. -" - -# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d -full_epi_aware_mdl = make_epi_aware(missing, time_horizon; - epi_model = epi_model, - latent_model = rwp, - observation_model = obs_model) - -# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a -md" -We choose some fixed parameters: -- Initial incidence is 100. -- In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. -" - -# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee -fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) - -# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 -md" -We fix these parameters using `fix`, and generate a random epidemic. -" - -# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a -cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) - -# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 -random_epidemic = rand(cond_generative_model) - -# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b -true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t - -# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -let - plot(true_infections, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") - scatter!(random_epidemic.y_t, lab = "generated cases") -end - -# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -md" -## Inference -Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. - -However, we now treat the generated data as `truth_data` and make inference without fixing any other parameters. - -We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. -" - -# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 -truth_data = random_epidemic.y_t - -# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d -inference_mdl = fix( - make_epi_aware(truth_data, time_horizon; epi_model = epi_model, - latent_model = rwp, observation_model = obs_model), - (rw_init = 0.0,) -) - -# ╔═╡ 619b6a45-b202-47f3-97cf-1ab487278674 -safe_inference_mdl = fix( - make_epi_aware(truth_data, time_horizon, Val(:safe_mode); epi_model = epi_model, - latent_model = rwp, observation_model = obs_model), - (rw_init = 0.0,) -) - -# ╔═╡ 7da045fc-49ba-4a9a-8e5a-f9edb6418f7a -mpf = multipathfinder(safe_inference_mdl, 50; nruns = 4) - -# ╔═╡ 3224f468-63d4-483a-ac65-1a584d23d92c -init_params = collect.(eachrow(mpf.draws_transformed.value[1:4, :, 1])) - -# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -chn = sample(inference_mdl, - NUTS(; adtype = AutoReverseDiff(true)), - MCMCThreads(), - 250, - 4; - init_params = init_params, - drop_warmup = true) - -# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 -md" -### Predictive plotting - -We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). - -Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. -" - -# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -#=╠═╡ -let - post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) - post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen - gen.generated_y_t - end - - predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen - gen.I_t - end - - p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p1, truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Post. predictive checking: cases", - ylims = (-0.5, maximum(truth_data) * 1.5), - c = :green) - - p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p2, true_infections, - lab = "Actual infections", - xlabel = "Time", - ylabel = "Unobserved Infections", - title = "Post. predictions: infections", - ylims = (-0.5, maximum(true_infections) * 1.5), - c = :red) - - plot(p1, p2, - layout = (1, 2), - size = (700, 400)) -end - ╠═╡ =# - -# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 -md" -As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. -" - -# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 -#=╠═╡ -let - parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) - - plts = map(parameters_to_plot) do name - var_samples = chn[name] |> vec - histogram(var_samples, - bins = 50, - norm = :pdf, - lw = 0, - fillalpha = 0.5, - lab = "MCMC") - vline!([getfield(random_epidemic, name)], lab = "True value") - title!(string(name)) - end - plot(plts..., layout = (2, 1)) -end - ╠═╡ =# - -# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b -md" -## Reproductive number back-calculation - -As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this note. However, we can back-calculate the implied $\mathcal{R}(t)$ values, conditional on the specified generation interval being correct. - -Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. -" - -# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 -#=╠═╡ -let - n = epi_model.data.len_gen_int - Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) - for t in (n + 1):length(true_infections)] - true_Rt = true_infections[(n + 1):end] ./ Rt_denom - - predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen - _It = gen.I_t - _Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)]) - for t in (n + 1):length(_It)] - Rt = _It[(n + 1):end] ./ _Rt_denom - end - - plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "") - plot!(plt, (n + 1):time_horizon, true_Rt, - lab = "true Rt", - xlabel = "Time", - ylabel = "Rt", - title = "Post. predictions: reproductive number", - c = :red, - lw = 2) -end - ╠═╡ =# - -# ╔═╡ Cell order: -# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 -# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f -# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 -# ╠═56ae496b-0094-460b-89cb-526627991717 -# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 -# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 -# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 -# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc -# ╠═6639e66f-7725-4976-81b2-6472419d1a62 -# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec -# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa -# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 -# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 -# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f -# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b -# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b -# ╠═abeff860-58c3-4644-9325-66ffd4446b6d -# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a -# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee -# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 -# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a -# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 -# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b -# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 -# ╠═b4033728-b321-4100-8194-1fd9fe2d268d -# ╠═619b6a45-b202-47f3-97cf-1ab487278674 -# ╠═7da045fc-49ba-4a9a-8e5a-f9edb6418f7a -# ╠═3224f468-63d4-483a-ac65-1a584d23d92c -# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c -# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 -# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 -# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 -# ╟─81efe8ca-b753-4a12-bafc-a887a999377b -# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 diff --git a/EpiAware/docs/src/examples/getting_started_pf.jl b/EpiAware/docs/src/examples/getting_started_pf.jl deleted file mode 100644 index e97fb72fd..000000000 --- a/EpiAware/docs/src/examples/getting_started_pf.jl +++ /dev/null @@ -1,437 +0,0 @@ -### A Pluto.jl notebook ### -# v0.19.40 - -using Markdown -using InteractiveUtils - -# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# hideall -let - docs_dir = dirname(dirname(@__DIR__)) - pkg_dir = dirname(docs_dir) - - using Pkg: Pkg - Pkg.activate(docs_dir) - Pkg.develop(; path = pkg_dir) - Pkg.resolve() - Pkg.instantiate() -end; - -# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f -begin - using EpiAware - using Turing - using Distributions - using StatsPlots - using Random - using DynamicPPL - using Statistics - using DataFramesMeta - using LinearAlgebra - using Pathfinder - using Transducers - - Random.seed!(1) -end - -# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 -md" -# Getting stated with `EpiAware` - -This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. - -## `EpiAware` models -The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. - -#### Mathematical definition -```math -\begin{align} -Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\ -I_0 &\sim f_0(\theta_I), \\ -I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\ -y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}). -\end{align} -``` -Where, $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$. $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections. $g_I$ is distribution of new infections conditional on infections and latent process in the past. Note that we assume that new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step. - -#### Code structure outline - -An `EpiAware` model in code is created from three modular components: - -- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$. -- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$. -- An `ObservationModel`: This defines the observation model. This chooses $f_O({I_s}_{s \leq t}, \theta_{O})$ - -#### Reproductive number -`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$, however, this is often a quantity of interest. When not directly used we will typically back-calculate $\mathcal{R}_t$ from the generated infections: - -```math -\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }. -``` - -Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. -" - -# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 -md" -## Random walk `LatentModel` - -As an example, we choose the latent process as a random walk with parameters $\theta_Z$: - -- ``Z_0``: Initial position. -- ``\sigma^2_{Z}``: The step-size variance. - -Conditional on the parameters the random walk is then generated by white noise: -```math -\begin{align} -Z_t &= Z_0 + \sigma_{RW} \sum_{t= 1}^T \epsilon_t, \\ -\epsilon_t &\sim \mathcal{N}(0,1). -\end{align} -``` - -In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors, -```math -\begin{align} -Z_0 &\sim \mathcal{N}(0,1),\\ -\sigma^2_Z &\sim \text{HalfNormal}(0.01). -\end{align} -``` -" - -# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 -rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.02), 0.0, Inf)) - -# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 -md" -## Direct infection `EpiModel` - -This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$. -```math -\log I_t = \log I_0 + Z_t. -``` - -As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$. - -" - -# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 -truth_GI = Gamma(2, 5) - -# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 -md" -The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. -" - -# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -model_data = EpiData(truth_GI, D_gen = 10.0) - -# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc -md" -We can supply a prior for the initial log_infections. -" - -# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 -log_I0_prior = Normal(log(100.0), 1.0) - -# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec -md" -And construct the `EpiModel`. -" - -# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 -epi_model = DirectInfections(model_data, log_I0_prior) - -# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa -md" - - -### Delayed Observations `ObservationModel` - -The observation model is a negative binomial distribution with mean `μ` and cluster factor `1 / r`. Delays are implemented -as the action of a sparse kernel on the infections $I(t)$. - -```math -\begin{align} -y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ -1 / r &\sim \text{Gamma}(3, 0.05/3). -\end{align} -``` -" - -# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 -md" -We also set up the inference to occur over 100 days. -" - -# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 -time_horizon = 100 - -# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f -md" -We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. -" - -# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b -obs_model = EpiAware.DelayObservations( - fill(0.25, 4), - time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) -) - -# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b -md" -## Generate cases from the `EpiAware` model - -Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`. - -By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. -" - -# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d -full_epi_aware_mdl = make_epi_aware(missing, time_horizon; - epi_model = epi_model, - latent_model = rwp, - observation_model = obs_model) - -# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a -md" -We choose some fixed parameters: -- Initial incidence is 100. -- In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. -" - -# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee -fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) - -# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 -md" -We fix these parameters using `fix`, and generate a random epidemic. -" - -# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a -cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) - -# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 -random_epidemic = rand(cond_generative_model) - -# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b -true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t - -# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -let - plot(true_infections, - label = "I_t", - xlabel = "Time", - ylabel = "Infections", - title = "Generated Infections") - scatter!(random_epidemic.y_t, lab = "generated cases") -end - -# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -md" -## Inference -Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. - -However, we now treat the generated data as `truth_data` and make inference without fixing any other parameters. - -We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. -" - -# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 -truth_data = random_epidemic.y_t - -# ╔═╡ bd4e3762-a711-43b7-8dbe-d5a1f690f021 -function _epi_aware(y_t, - time_steps; - epi_model::AbstractEpiModel, - latent_model::AbstractLatentModel, - observation_model::AbstractObservationModel, - nsamples, - nchains, - pf_ndraws = 50, - pf_nruns = 4, - fixed_parameters = (;), - pos_shift = 1e-6, - executor = Transducers.ThreadedEx(), - adtype = AutoReverseDiff(true), - maxiters = 10, - kwargs...) - gen_mdl = make_epi_aware(missing, time_horizon; epi_model, - latent_model, observation_model, pos_shift) |> - mdl -> fix(mdl, fixed_parameters) - - mdl = make_epi_aware(y_t, time_horizon; epi_model, - latent_model, observation_model, pos_shift) |> - mdl -> fix(mdl, fixed_parameters) - - safe_mdl = make_epi_aware(y_t, time_horizon, Val(:safe_mode); epi_model, - latent_model, observation_model, pos_shift) |> - mdl -> fix(mdl, fixed_parameters) - - mpf = multipathfinder(safe_mdl, max(pf_ndraws, nchains); - nruns = pf_nruns, - executor, - maxiters, - kwargs...) - - init_params = collect.(eachrow(mpf.draws_transformed.value[1:nchains, :, 1])) - - chn = sample(mdl, - NUTS(; adtype), - MCMCThreads(), - nsamples ÷ nchains, - nchains; - init_params = init_params, - drop_warmup = true) - - return chn, (; pathfinder_res = mpf, - inference_mdl = mdl, - generative_mdl = gen_mdl) -end - -# ╔═╡ d0ed77f0-b27b-49af-a682-5ad567fe2d45 -@time chn, epi_mdls = _epi_aware( - truth_data, time_horizon; epi_model = epi_model, latent_model = rwp, - observation_model = obs_model, nsamples = 1000, nchains = 4, - fixed_parameters = (rw_init = 0.0,)) - -# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 -md" -### Predictive plotting - -We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). - -Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. -" - -# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -let - post_check_mdl = epi_mdls.generative_mdl - inference_mdl = epi_mdls.inference_mdl - post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen - gen.generated_y_t - end - - predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen - gen.I_t - end - - p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p1, truth_data, - lab = "Observed cases", - xlabel = "Time", - ylabel = "Cases", - title = "Post. predictive checking: cases", - ylims = (-0.5, maximum(truth_data) * 1.5), - c = :green) - - p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "") - scatter!(p2, true_infections, - lab = "Actual infections", - xlabel = "Time", - ylabel = "Unobserved Infections", - title = "Post. predictions: infections", - ylims = (-0.5, maximum(true_infections) * 1.5), - c = :red) - - plot(p1, p2, - layout = (1, 2), - size = (700, 400)) -end - -# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 -md" -As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. -" - -# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 -let - parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) - - plts = map(parameters_to_plot) do name - var_samples = chn[name] |> vec - histogram(var_samples, - bins = 50, - norm = :pdf, - lw = 0, - fillalpha = 0.5, - lab = "MCMC") - vline!([getfield(random_epidemic, name)], lab = "True value") - title!(string(name)) - end - plot(plts..., layout = (2, 1)) -end - -# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b -md" -## Reproductive number back-calculation - -As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this note. However, we can back-calculate the implied $\mathcal{R}(t)$ values, conditional on the specified generation interval being correct. - -Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. -" - -# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 -let - inference_mdl = epi_mdls.inference_mdl - n = epi_model.data.len_gen_int - Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) - for t in (n + 1):length(true_infections)] - true_Rt = true_infections[(n + 1):end] ./ Rt_denom - - predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen - _It = gen.I_t - _Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)]) - for t in (n + 1):length(_It)] - Rt = _It[(n + 1):end] ./ _Rt_denom - end - - plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "") - plot!(plt, (n + 1):time_horizon, true_Rt, - lab = "true Rt", - xlabel = "Time", - ylabel = "Rt", - title = "Post. predictions: reproductive number", - c = :red, - lw = 2) -end - -# ╔═╡ Cell order: -# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 -# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f -# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 -# ╠═56ae496b-0094-460b-89cb-526627991717 -# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 -# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 -# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 -# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc -# ╠═6639e66f-7725-4976-81b2-6472419d1a62 -# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec -# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa -# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 -# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 -# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f -# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b -# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b -# ╠═abeff860-58c3-4644-9325-66ffd4446b6d -# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a -# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee -# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 -# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a -# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 -# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b -# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 -# ╠═bd4e3762-a711-43b7-8dbe-d5a1f690f021 -# ╠═d0ed77f0-b27b-49af-a682-5ad567fe2d45 -# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 -# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 -# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 -# ╟─81efe8ca-b753-4a12-bafc-a887a999377b -# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 From 5611ebc52b56b301cbc1c4a5e3d02506758d0d68 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Wed, 6 Mar 2024 17:20:19 +0000 Subject: [PATCH 32/40] Update getting started to include inference loop function --- EpiAware/docs/src/examples/getting_started.jl | 166 +++++++++--------- 1 file changed, 84 insertions(+), 82 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index e2949d3f7..c7facf03f 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -1,12 +1,10 @@ - ### A Pluto.jl notebook ### # v0.19.40 using Markdown using InteractiveUtils -# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# hideall +# ╔═╡ 4680eed2-dbd8-11ee-17d6-2552317711c1 let docs_dir = dirname(dirname(@__DIR__)) pkg_dir = dirname(docs_dir) @@ -16,9 +14,9 @@ let Pkg.develop(; path = pkg_dir) Pkg.resolve() Pkg.instantiate() -end; +end -# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f +# ╔═╡ 914b0cd7-86ca-4227-96cb-7cd887956833 begin using EpiAware using Turing @@ -32,15 +30,18 @@ begin using Pathfinder using Transducers - Random.seed!(1234) -end + Random.seed!(1) +end; -# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61 +# ╔═╡ 7c22d80b-2d52-4935-b62d-cbc64437195c md" # Getting stated with `EpiAware` This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. +" +# ╔═╡ 683743e6-9426-4b95-994b-4a579aa2564d +md" ## `EpiAware` models The models we consider are discrete-time $t = 1,\dots, T$ with a latent random process, $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of infection to observation delay. @@ -73,7 +74,7 @@ An `EpiAware` model in code is created from three modular components: Where $g_s$ is a discrete generation interval. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval. " -# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38 +# ╔═╡ 0b554dd9-79c7-44bc-9cbf-ea56439cb80d md" ## Random walk `LatentModel` @@ -99,11 +100,11 @@ Z_0 &\sim \mathcal{N}(0,1),\\ ``` " -# ╔═╡ 56ae496b-0094-460b-89cb-526627991717 +# ╔═╡ 65540fb2-f97f-49dc-91f4-af26b803994e rwp = EpiAware.RandomWalk(Normal(), - truncated(Normal(0.0, 0.02), 0.0, Inf)) + truncated(Normal(0.0, 0.02), 0.0, 0.5)) -# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44 +# ╔═╡ 3ed6eb84-d0e1-4f09-9fa0-4021d0f79f88 md" ## Direct infection `EpiModel` @@ -116,34 +117,34 @@ As discussed above, we still ask for a defined generation interval, which can be " -# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218 +# ╔═╡ 0a532e52-8305-470a-8462-2aa023b724a2 truth_GI = Gamma(2, 5) -# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673 +# ╔═╡ ec24d355-0158-45e9-9584-ed89bbb17b31 md" The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`. " -# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 +# ╔═╡ 8452c589-aceb-41b7-978e-918b83db58d3 model_data = EpiData(truth_GI, D_gen = 10.0) -# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc +# ╔═╡ 6028ccd2-428f-4737-9fa6-ab5bf17631bd md" We can supply a prior for the initial log_infections. " -# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62 +# ╔═╡ bd6714a0-6e70-4602-993e-12238e1f37f2 log_I0_prior = Normal(log(100.0), 1.0) -# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec +# ╔═╡ af012cc5-02ea-47f4-8545-cf54f2c6f6cc md" And construct the `EpiModel`. " -# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012 +# ╔═╡ 18e51238-4038-4022-9b22-0e51ed51ea0a epi_model = DirectInfections(model_data, log_I0_prior) -# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa +# ╔═╡ d75e7957-a020-46f6-8d40-e2ac1abb917f md" @@ -160,27 +161,27 @@ y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\ ``` " -# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8 +# ╔═╡ 74771683-48cb-456f-8089-65c2fe2fdef2 md" We also set up the inference to occur over 100 days. " -# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 +# ╔═╡ cba7072f-aaaa-40fe-8e5f-f22d98fbdb30 time_horizon = 100 -# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f +# ╔═╡ 6c39de90-4791-42f0-b863-228753618c8a md" We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability. " -# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b +# ╔═╡ b4a420af-71e6-4094-9795-9544dd5f34a2 obs_model = EpiAware.DelayObservations( fill(0.25, 4), time_horizon, - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) + truncated(Gamma(5, 0.05 / 5), 1e-3, 0.2) ) -# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b +# ╔═╡ 5b667f47-8c90-42ca-8253-998a3ad3878d md" ## Generate cases from the `EpiAware` model @@ -189,37 +190,37 @@ Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implem By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed. " -# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d +# ╔═╡ 12632c1c-b233-4990-9e7e-9add1bb9a8ee full_epi_aware_mdl = make_epi_aware(missing, time_horizon; epi_model = epi_model, latent_model = rwp, observation_model = obs_model) -# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a +# ╔═╡ 031a55e8-617b-4da4-859c-94b660ec424d md" We choose some fixed parameters: - Initial incidence is 100. - In the direct infection model, the initial incidence and in the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$. " -# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee +# ╔═╡ 775c7295-a4ea-484e-9ad9-291df3f6ffe8 fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0)) -# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0 +# ╔═╡ 5ce1855e-9f9e-4feb-8df3-8ec6139ac943 md" We fix these parameters using `fix`, and generate a random epidemic. " -# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a +# ╔═╡ 5b86a63b-677c-4125-b0be-1527b73b91bd cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters) -# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12 +# ╔═╡ 4a1fc7bb-82a0-4643-a18b-d331a31c1390 random_epidemic = rand(cond_generative_model) -# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b +# ╔═╡ e571e7b6-0e26-4855-ae90-05a18be6ff38 true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t -# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543 +# ╔═╡ 88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 let plot(true_infections, label = "I_t", @@ -229,7 +230,7 @@ let scatter!(random_epidemic.y_t, lab = "generated cases") end -# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9 +# ╔═╡ 2f90bee6-067d-4267-beb9-356e4a4d714c md" ## Inference Fixing $Z_0 = 0$ for the random walk was based on inference principles; in this model $Z_0$ and $\log I_0$ are non-identifiable. @@ -239,16 +240,16 @@ However, we now treat the generated data as `truth_data` and make inference with We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default warm-up steps. " -# ╔═╡ c8ce0d46-a160-4c40-a055-69b3d10d1770 +# ╔═╡ 7e48a4c5-cd30-4377-8a98-e0c23f2dc31e truth_data = random_epidemic.y_t -# ╔═╡ d0ed77f0-b27b-49af-a682-5ad567fe2d45 -@time chn, epi_mdls = EpiAware._epi_aware( +# ╔═╡ 6b32f804-7534-4c98-9788-b0fd22771d43 +chn, epi_mdls = EpiAware._epi_aware( truth_data, time_horizon; epi_model = epi_model, latent_model = rwp, - observation_model = obs_model, nsamples = 1000, nchains = 4, - fixed_parameters = (rw_init = 0.0,)) + observation_model = obs_model, nsamples = 1000, nchains = 4, pf_nruns = 4, pf_ndraws = 50, + fixed_parameters = (rw_init = 0.0,)); -# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1 +# ╔═╡ 2e42cb30-b087-4ae1-9b8f-95d103e1c290 md" ### Predictive plotting @@ -257,7 +258,7 @@ We can spaghetti plot generated case data from the version of the model _which h Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. " -# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 +# ╔═╡ e74fc652-cd5f-4764-a416-caa8bab0bf0c let post_check_mdl = epi_mdls.generative_mdl inference_mdl = epi_mdls.inference_mdl @@ -292,12 +293,12 @@ let size = (700, 400)) end -# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80 +# ╔═╡ 96df9c68-b2e2-4669-b420-5ef23c77aee7 md" As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. " -# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3 +# ╔═╡ 04d741d8-a2ff-48eb-90b1-e4da416eb582 let parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor) @@ -315,7 +316,7 @@ let plot(plts..., layout = (2, 1)) end -# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b +# ╔═╡ 42763332-096d-40eb-a152-96e858992ed4 md" ## Reproductive number back-calculation @@ -324,7 +325,7 @@ As mentioned at the top, we _don't_ directly use the concept of reproductive num Here we spaghetti plot posterior sampled time-varying reproductive numbers against the actual. " -# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099 +# ╔═╡ 3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b let inference_mdl = epi_mdls.inference_mdl n = epi_model.data.len_gen_int @@ -350,39 +351,40 @@ let end # ╔═╡ Cell order: -# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72 -# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61 -# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f -# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38 -# ╠═56ae496b-0094-460b-89cb-526627991717 -# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44 -# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218 -# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673 -# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43 -# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc -# ╠═6639e66f-7725-4976-81b2-6472419d1a62 -# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec -# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012 -# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa -# ╟─e813d547-6100-4c43-b84c-8cebe306bda8 -# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0 -# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f -# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b -# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b -# ╠═abeff860-58c3-4644-9325-66ffd4446b6d -# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a -# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee -# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0 -# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a -# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12 -# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b -# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543 -# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9 -# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770 -# ╠═d0ed77f0-b27b-49af-a682-5ad567fe2d45 -# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1 -# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5 -# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80 -# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3 -# ╟─81efe8ca-b753-4a12-bafc-a887a999377b -# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099 +# ╟─4680eed2-dbd8-11ee-17d6-2552317711c1 +# ╟─914b0cd7-86ca-4227-96cb-7cd887956833 +# ╟─7c22d80b-2d52-4935-b62d-cbc64437195c +# ╟─683743e6-9426-4b95-994b-4a579aa2564d +# ╟─0b554dd9-79c7-44bc-9cbf-ea56439cb80d +# ╠═65540fb2-f97f-49dc-91f4-af26b803994e +# ╟─3ed6eb84-d0e1-4f09-9fa0-4021d0f79f88 +# ╠═0a532e52-8305-470a-8462-2aa023b724a2 +# ╟─ec24d355-0158-45e9-9584-ed89bbb17b31 +# ╠═8452c589-aceb-41b7-978e-918b83db58d3 +# ╟─6028ccd2-428f-4737-9fa6-ab5bf17631bd +# ╠═bd6714a0-6e70-4602-993e-12238e1f37f2 +# ╟─af012cc5-02ea-47f4-8545-cf54f2c6f6cc +# ╠═18e51238-4038-4022-9b22-0e51ed51ea0a +# ╟─d75e7957-a020-46f6-8d40-e2ac1abb917f +# ╟─74771683-48cb-456f-8089-65c2fe2fdef2 +# ╠═cba7072f-aaaa-40fe-8e5f-f22d98fbdb30 +# ╟─6c39de90-4791-42f0-b863-228753618c8a +# ╠═b4a420af-71e6-4094-9795-9544dd5f34a2 +# ╟─5b667f47-8c90-42ca-8253-998a3ad3878d +# ╠═12632c1c-b233-4990-9e7e-9add1bb9a8ee +# ╟─031a55e8-617b-4da4-859c-94b660ec424d +# ╠═775c7295-a4ea-484e-9ad9-291df3f6ffe8 +# ╟─5ce1855e-9f9e-4feb-8df3-8ec6139ac943 +# ╠═5b86a63b-677c-4125-b0be-1527b73b91bd +# ╠═4a1fc7bb-82a0-4643-a18b-d331a31c1390 +# ╠═e571e7b6-0e26-4855-ae90-05a18be6ff38 +# ╟─88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 +# ╟─2f90bee6-067d-4267-beb9-356e4a4d714c +# ╠═7e48a4c5-cd30-4377-8a98-e0c23f2dc31e +# ╠═6b32f804-7534-4c98-9788-b0fd22771d43 +# ╟─2e42cb30-b087-4ae1-9b8f-95d103e1c290 +# ╠═e74fc652-cd5f-4764-a416-caa8bab0bf0c +# ╠═96df9c68-b2e2-4669-b420-5ef23c77aee7 +# ╠═04d741d8-a2ff-48eb-90b1-e4da416eb582 +# ╠═42763332-096d-40eb-a152-96e858992ed4 +# ╠═3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b From 4af04acace68ebedf8ec882c33d6fe8f5f2fc50d Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 11:19:32 +0000 Subject: [PATCH 33/40] manypathfinder approach with new getting started and tests --- EpiAware/docs/src/examples/getting_started.jl | 68 ++++++- EpiAware/src/EpiAware.jl | 3 + EpiAware/src/inference-methods.jl | 142 ++++++++----- EpiAware/src/models.jl | 35 ---- EpiAware/test/Project.toml | 1 + EpiAware/test/test_inference-methods.jl | 188 +++++++++++++----- 6 files changed, 291 insertions(+), 146 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index c7facf03f..047cdbd9c 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -243,11 +243,56 @@ We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default wa # ╔═╡ 7e48a4c5-cd30-4377-8a98-e0c23f2dc31e truth_data = random_epidemic.y_t -# ╔═╡ 6b32f804-7534-4c98-9788-b0fd22771d43 -chn, epi_mdls = EpiAware._epi_aware( - truth_data, time_horizon; epi_model = epi_model, latent_model = rwp, - observation_model = obs_model, nsamples = 1000, nchains = 4, pf_nruns = 4, pf_ndraws = 50, - fixed_parameters = (rw_init = 0.0,)); +# ╔═╡ 272e6798-1151-486f-9667-924dbc63bd69 +inference_mdl = fix( + make_epi_aware(truth_data, time_horizon; + epi_model = epi_model, + latent_model = rwp, + observation_model = obs_model), + (init_rw = 0.0,) +) + +# ╔═╡ 4298f0ec-f6df-42ee-aa28-f7ed60f1e530 +md" +### Initialising inference + +It is possible for the default warm-up process for NUTS to get stuck in low probability or otherwise degenerate regions of parameter space. + +To make NUTS more robust we provide `manypathfinder`, which is built on pathfinder variational inference through the [Pathfinder.jl](https://mlcolab.github.io/Pathfinder.jl/stable/). `manypathfinder` runs `nruns` pathfinder processes on the inference problem and returns the pathfinder run with maximum estimated ELBO. + +`manypathfinder` differs from `Pathfinder.multipathfinder`; `multipathfinder` is aimed at sampling from a potentially non-Gaussian target distribution which is first approximated as a uniformly weighted collection of normal approximations from pathfinder runs. `manypathfinder` is aimed at moving rapidly to a 'good' part of parameter space, and is robust to runs that fail. +" + +# ╔═╡ 40ebd47a-4a08-4a46-a727-26347d3fca51 +best_pf = manypathfinder(inference_mdl; nruns = 20); + +# ╔═╡ b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 +md" +We can use draws from the best pathfinder run to initialise NUTS. +" + +# ╔═╡ cdd805e2-b00c-4522-9261-1819c6a195eb +best_pf.draws_transformed + +# ╔═╡ e847b0b6-9d70-46ba-bec6-1e3fa676a33c +init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1])) + +# ╔═╡ 9734a535-e3d8-4481-9897-f537ad095d21 +md" +**NB: We are running this inference run for speed rather than accuracy as a demonstration. Use a higher target acceptance and more samples in a typical workflow.** +" + +# ╔═╡ 2fdb4ca6-47ba-4a16-95fa-14b2b32cef10 +begin + target_acc_rate = 0.8 + chn = sample(inference_mdl, + NUTS(target_acc_rate; adtype = AutoReverseDiff(true)), + MCMCThreads(), + 250, + 4; + init_params, + drop_warmup = true) +end # ╔═╡ 2e42cb30-b087-4ae1-9b8f-95d103e1c290 md" @@ -260,8 +305,7 @@ Because we are using synthetic data we can also plot the model predictions for t # ╔═╡ e74fc652-cd5f-4764-a416-caa8bab0bf0c let - post_check_mdl = epi_mdls.generative_mdl - inference_mdl = epi_mdls.inference_mdl + post_check_mdl = fix(full_epi_aware_mdl, (init_rw = 0.0,)) post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen gen.generated_y_t end @@ -327,7 +371,6 @@ Here we spaghetti plot posterior sampled time-varying reproductive numbers again # ╔═╡ 3b5a3fa6-fc57-4b3c-b03d-04641bf0e48b let - inference_mdl = epi_mdls.inference_mdl n = epi_model.data.len_gen_int Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)]) for t in (n + 1):length(true_infections)] @@ -381,7 +424,14 @@ end # ╟─88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 # ╟─2f90bee6-067d-4267-beb9-356e4a4d714c # ╠═7e48a4c5-cd30-4377-8a98-e0c23f2dc31e -# ╠═6b32f804-7534-4c98-9788-b0fd22771d43 +# ╠═272e6798-1151-486f-9667-924dbc63bd69 +# ╟─4298f0ec-f6df-42ee-aa28-f7ed60f1e530 +# ╠═40ebd47a-4a08-4a46-a727-26347d3fca51 +# ╟─b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 +# ╠═cdd805e2-b00c-4522-9261-1819c6a195eb +# ╠═e847b0b6-9d70-46ba-bec6-1e3fa676a33c +# ╟─9734a535-e3d8-4481-9897-f537ad095d21 +# ╠═2fdb4ca6-47ba-4a16-95fa-14b2b32cef10 # ╟─2e42cb30-b087-4ae1-9b8f-95d103e1c290 # ╠═e74fc652-cd5f-4764-a416-caa8bab0bf0c # ╠═96df9c68-b2e2-4669-b420-5ef23c77aee7 diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 3443d7f75..0d971220a 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -54,6 +54,9 @@ export generate_latent, generate_latent_infs, generate_observations export create_discrete_pmf, spread_draws, scan, R_to_r, r_to_R, default_rw_priors, default_delay_obs_priors +# Exported inference methods +export manypathfinder + include("docstrings.jl") include("abstract-types.jl") include("epi-models.jl") diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index 6a783ad1d..516b650c3 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -1,51 +1,97 @@ """ +Run pathfinder multiple times and store the results in an array. Fails safely. -Do inference on an EpiAware model. -""" -function _epi_aware(y_t, - time_steps; - epi_model::AbstractEpiModel, - latent_model::AbstractLatentModel, - observation_model::AbstractObservationModel, - nsamples, - nchains, - pf_ndraws = 10, - pf_nruns = 10, - fixed_parameters = (;), - pos_shift = 1e-6, - executor = Transducers.ThreadedEx(), - adtype = AutoReverseDiff(true), - maxiters = 10, - kwargs...) - gen_mdl = make_epi_aware(missing, time_steps; epi_model, - latent_model, observation_model, pos_shift) |> - mdl -> fix(mdl, fixed_parameters) - - mdl = make_epi_aware(y_t, time_steps; epi_model, - latent_model, observation_model, pos_shift) |> - mdl -> fix(mdl, fixed_parameters) - - safe_mdl = make_epi_aware(y_t, time_steps, Val(:safe_mode); epi_model, - latent_model, observation_model, pos_shift) |> - mdl -> fix(mdl, fixed_parameters) - - mpf = multipathfinder(safe_mdl, max(pf_ndraws, nchains); - nruns = pf_nruns, - executor, - maxiters, - kwargs...) - - init_params = collect.(eachrow(mpf.draws_transformed.value[(end - nchains):end, :, 1])) - - chn = sample(mdl, - NUTS(; adtype), - MCMCThreads(), - nsamples ÷ nchains, - nchains; - init_params = init_params, - drop_warmup = true) - - return chn, (; pathfinder_res = mpf, - inference_mdl = mdl, - generative_mdl = gen_mdl) +# Arguments +- `mdl::DynamicPPL.Model`: The `Turing` model to be used for inference. +- `nruns`: The number of times to run the `pathfinder` function. +- `kwargs...`: Additional keyword arguments passed to `pathfinder`. + +# Returns +An array of `PathfinderResult` objects or `Symbol` values indicating success or failure. +""" +function _run_manypathfinder(mdl::DynamicPPL.Model; nruns, kwargs...) + @info "Running pathfinder $nruns times" + pfs = Vector{Union{Pathfinder.PathfinderResult, Symbol}}(undef, nruns) + Threads.@threads for i in 1:nruns + try + pfs[i] = pathfinder(mdl; kwargs...) + catch + pfs[i] = :fail + end + end + return pfs +end + +""" +Continue running the pathfinder algorithm until a pathfinder succeeds or the maximum number +of tries is reached. + +# Arguments +- `pfs`: An array of pathfinder objects. +- `mdl::DynamicPPL.Model`: The model to perform inference on. +- `max_tries`: The maximum number of tries to run the pathfinder algorithm. Default is + `Inf`. +- `kwargs...`: Additional keyword arguments passed to `pathfinder`. + +# Returns +- `pfs`: The updated array of pathfinder objects. + +""" +function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, kwargs...) + tryiter = 1 + while all(pfs .== :fail) && tryiter <= max_tries + new_pf = try + pathfinder(mdl; kwargs...) + catch + :fail + end + pfs = vcat(pfs, new_pf) + tryiter += 1 + end + if all(pfs .== :fail) + @warn "All pathfinder runs failed" + end + return pfs +end + +""" +Selects the pathfinder with the highest ELBO estimate from a list of pathfinders. + +# Arguments +- `pfs`: A list of pathfinders results or `Symbol` values indicating failure. + +# Returns +The pathfinder with the highest ELBO estimate. +""" +function _get_best_elbo_pathfinder(pfs) + elbos = map(pfs) do pf_res + pf_res == :fail ? -Inf : pf_res.elbo_estimates[end].value + end + _, choice_of_pf = findmax(elbos) + return pfs[choice_of_pf] +end + +""" +Run multiple instances of the pathfinder algorithm and returns the pathfinder run with the +largest ELBO estimate. + +## Arguments +- `mdl::DynamicPPL.Model`: The model to perform inference on. +- `nruns::Int`: The number of pathfinder runs to perform. +- `ndraws::Int`: The number of draws per pathfinder run, readjusted to be at least as large + as the number of chains. +- `nchains::Int`: The number of chains that will be initialised by pathfinder draws. +- `maxiters::Int`: The maximum number of optimizer iterations per pathfinder run. +- `max_tries::Int`: The maximum number of extra tries to find a valid pathfinder result. +- `kwargs...`: Additional keyword arguments passed to `pathfinder`. + +## Returns +- `best_pfs::PathfinderResult`: Best pathfinder result by estimated ELBO. +""" +function manypathfinder(mdl::DynamicPPL.Model; nruns = 4, ndraws = 10, + nchains = 4, maxiters = 50, max_tries = 100, kwargs...) + ndraws = max(ndraws, nchains) + _run_manypathfinder(mdl; nruns, ndraws, maxiters, kwargs...) |> + pfs -> _continue_manypathfinder!(pfs, mdl; max_tries, kwargs...) |> + pfs -> _get_best_elbo_pathfinder(pfs) end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index c85ff6b7c..a4ea306a5 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -25,38 +25,3 @@ latent_model, process_aux = merge(latent_model_aux, generated_y_t_aux)) end - -@model function make_epi_aware(y_t, - time_steps, - ::Val{:safe_mode}; - epi_model::AbstractEpiModel, - latent_model::AbstractLatentModel, - observation_model::AbstractObservationModel, - pos_shift = 1e-6) - try - #Latent process - @submodel Z_t, latent_model_aux = generate_latent( - latent_model, - time_steps) - - #Transform into infections - @submodel I_t = generate_latent_infs(epi_model, Z_t) - - #Predictive distribution of ascerted cases - @submodel generated_y_t, generated_y_t_aux = generate_observations( - observation_model, - y_t, - I_t; - pos_shift = pos_shift) - - #Generate quantities - return (; - generated_y_t, - I_t, - latent_model, - process_aux = merge(latent_model_aux, generated_y_t_aux)) - catch - Turing.@addlogprob! -Inf - return - end -end diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml index eb8752d5f..b882dd0ac 100644 --- a/EpiAware/test/Project.toml +++ b/EpiAware/test/Project.toml @@ -5,6 +5,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl index 73a3c8148..b44f0e0fa 100644 --- a/EpiAware/test/test_inference-methods.jl +++ b/EpiAware/test/test_inference-methods.jl @@ -1,55 +1,135 @@ -@testitem "Testing _epi_aware function" begin - using Transducers, Turing - time_steps = 20 - - y_t = fill(10, time_steps) - nsamples = 10 - nchains = 2 - pf_ndraws = 10 - pf_nruns = 10 - fixed_parameters = (;) - pos_shift = 1e-6 - executor = Transducers.ThreadedEx() - adtype = AutoReverseDiff(true) - maxiters = 10 - - #Define the epi_model - epi_model = DirectInfections(EpiData([0.2, 0.3, 0.5], exp), Normal()) - - #Define the latent process model - rwp = EpiAware.RandomWalk(Normal(0.0, 1.0), - truncated(Normal(0.0, 0.05), 0.0, Inf)) - - #Define the observation model - delay_distribution = Gamma(2.0, 5 / 2) - time_horizon = time_steps - D_delay = 14.0 - Δd = 1.0 - - obs_model = EpiAware.DelayObservations(delay_distribution = delay_distribution, - time_horizon = time_horizon, - neg_bin_cluster_factor_prior = Gamma(5, 0.05 / 5), - D_delay = D_delay, - Δd = Δd) - - # Call the _epi_aware function to check this runs - chn, results = EpiAware._epi_aware(y_t, time_steps; - epi_model = epi_model, - latent_model = rwp, - observation_model = obs_model, - nsamples = nsamples, - nchains = nchains, - pf_ndraws = pf_ndraws, - pf_nruns = pf_nruns, - fixed_parameters = fixed_parameters, - pos_shift = pos_shift, - executor = executor, - adtype = adtype, - maxiters = maxiters) - - # Perform assertions to check the correctness of the results - @test size(chn, 1) == nsamples ÷ nchains - @test haskey(results, :pathfinder_res) - @test haskey(results, :inference_mdl) - @test haskey(results, :generative_mdl) +@testitem "Testing _run_manypathfinder function" begin + using Turing, Pathfinder + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + + # Test case 1 + @testset "Test case 1" begin + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = EpiAware._run_manypathfinder( + mdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + + @test length(pfs) == nruns + @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + end + + # Test case 2 + @testset "Test case 2" begin + nruns = 5 + ndraws = 50 + maxiters = 100 + + pfs = EpiAware._run_manypathfinder( + mdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + + @test length(pfs) == nruns + @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + end +end +@testitem "Testing _continue_manypathfinder! function" begin + using Turing, Pathfinder + + @testset "Test case 1" begin + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + + pfs = Vector{Union{PathfinderResult, Symbol}}([:fail, :fail, :fail]) + max_tries = 3 + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = EpiAware._continue_manypathfinder!( + pfs, mdl; max_tries, nruns, ndraws, maxiters) + + @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + end + + @testset "Test case 2" begin + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + + pfs = Vector{Union{PathfinderResult, Symbol}}([:fail, :fail, :fail]) + max_tries = 3 + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = EpiAware._continue_manypathfinder!( + pfs, mdl; max_tries, nruns, ndraws, maxiters) + + @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + end +end +@testitem "Testing _get_best_elbo_pathfinder function" begin + using Pathfinder, Turing + + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + nruns = 10 + ndraws = 100 + maxiters = 50 + + pfs = EpiAware._run_manypathfinder( + mdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + + best_pf = EpiAware._get_best_elbo_pathfinder(pfs) + @test best_pf isa PathfinderResult +end +@testitem "Testing manypathfinder function" begin + using Turing, Pathfinder + + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() + + # Test case 1 + @testset "Test case 1" begin + nruns = 4 + ndraws = 10 + nchains = 4 + maxiters = 50 + max_tries = 100 + + best_pf = manypathfinder(mdl; nruns = nruns, ndraws = ndraws, nchains = nchains, + maxiters = maxiters, max_tries = max_tries) + + @test best_pf isa PathfinderResult + end + + # Test case 2 + @testset "Test case 2" begin + nruns = 2 + ndraws = 5 + nchains = 2 + maxiters = 30 + max_tries = 50 + + best_pf = manypathfinder(mdl; nruns = nruns, ndraws = ndraws, nchains = nchains, + maxiters = maxiters, max_tries = max_tries) + + @test best_pf isa PathfinderResult + end end From 40f3a5b9393d6c757a8f5f6fb3338f95ab019a13 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 11:31:14 +0000 Subject: [PATCH 34/40] change to generated_y_t approach --- EpiAware/docs/src/examples/getting_started.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 7d7cd6bb7..58842531a 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.39 +# v0.19.40 using Markdown using InteractiveUtils @@ -220,6 +220,9 @@ random_epidemic = rand(cond_generative_model) # ╔═╡ e571e7b6-0e26-4855-ae90-05a18be6ff38 true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t +# ╔═╡ 62092c7f-ebe7-428e-baaa-65c34be52371 +generated_obs = generated_quantities(cond_generative_model, random_epidemic).generated_y_t + # ╔═╡ 88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 let plot(true_infections, @@ -241,7 +244,7 @@ We do the inference by MCMC/NUTS using the `Turing` NUTS sampler with default wa " # ╔═╡ 7e48a4c5-cd30-4377-8a98-e0c23f2dc31e -truth_data = random_epidemic.y_t +truth_data = generated_obs # ╔═╡ 272e6798-1151-486f-9667-924dbc63bd69 inference_mdl = fix( @@ -337,7 +340,6 @@ let size = (700, 400)) end - # ╔═╡ 96df9c68-b2e2-4669-b420-5ef23c77aee7 md" As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations. @@ -422,6 +424,7 @@ end # ╠═5b86a63b-677c-4125-b0be-1527b73b91bd # ╠═4a1fc7bb-82a0-4643-a18b-d331a31c1390 # ╠═e571e7b6-0e26-4855-ae90-05a18be6ff38 +# ╠═62092c7f-ebe7-428e-baaa-65c34be52371 # ╟─88e8fb2c-38ce-4c68-88b9-c42f3fa6de13 # ╟─2f90bee6-067d-4267-beb9-356e4a4d714c # ╠═7e48a4c5-cd30-4377-8a98-e0c23f2dc31e From 9b2b509339dbd479fccfb749ddab9f6b4cbd1871 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 7 Mar 2024 11:43:48 +0000 Subject: [PATCH 35/40] Update EpiAware/docs/src/examples/getting_started.jl Co-authored-by: Sam Abbott --- EpiAware/docs/src/examples/getting_started.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 58842531a..cb0a90345 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -261,7 +261,7 @@ md" It is possible for the default warm-up process for NUTS to get stuck in low probability or otherwise degenerate regions of parameter space. -To make NUTS more robust we provide `manypathfinder`, which is built on pathfinder variational inference through the [Pathfinder.jl](https://mlcolab.github.io/Pathfinder.jl/stable/). `manypathfinder` runs `nruns` pathfinder processes on the inference problem and returns the pathfinder run with maximum estimated ELBO. +To make NUTS more robust we provide `manypathfinder`, which is built on pathfinder variational inference from [Pathfinder.jl](https://mlcolab.github.io/Pathfinder.jl/stable/). `manypathfinder` runs `nruns` pathfinder processes on the inference problem and returns the pathfinder run with maximum estimated ELBO. `manypathfinder` differs from `Pathfinder.multipathfinder`; `multipathfinder` is aimed at sampling from a potentially non-Gaussian target distribution which is first approximated as a uniformly weighted collection of normal approximations from pathfinder runs. `manypathfinder` is aimed at moving rapidly to a 'good' part of parameter space, and is robust to runs that fail. " From a8ca3e81fb75e0af5b4cbd50ac913d49ec3000c9 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 11:56:38 +0000 Subject: [PATCH 36/40] fix rw_init --- EpiAware/docs/src/examples/getting_started.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index 58842531a..adb09b0dd 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -252,7 +252,7 @@ inference_mdl = fix( epi_model = epi_model, latent_model = rwp, observation_model = obs_model), - (init_rw = 0.0,) + (rw_init = 0.0,) ) # ╔═╡ 4298f0ec-f6df-42ee-aa28-f7ed60f1e530 @@ -308,7 +308,7 @@ Because we are using synthetic data we can also plot the model predictions for t # ╔═╡ e74fc652-cd5f-4764-a416-caa8bab0bf0c let - post_check_mdl = fix(full_epi_aware_mdl, (init_rw = 0.0,)) + post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,)) post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen gen.generated_y_t end From 42b0b1bf051cae710908f4757b800fc8bae40820 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 12:16:04 +0000 Subject: [PATCH 37/40] Add warning about fail of all initial pathfinder runs --- EpiAware/src/inference-methods.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index 516b650c3..9b47397d9 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -1,4 +1,5 @@ """ + Run pathfinder multiple times and store the results in an array. Fails safely. # Arguments @@ -39,6 +40,9 @@ of tries is reached. """ function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, kwargs...) tryiter = 1 + if all(pfs .== :fail) + @warn "All initial pathfinder runs failed, trying again for $max_tries tries." + end while all(pfs .== :fail) && tryiter <= max_tries new_pf = try pathfinder(mdl; kwargs...) @@ -49,7 +53,8 @@ function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, kwargs tryiter += 1 end if all(pfs .== :fail) - @warn "All pathfinder runs failed" + @warn "All pathfinder runs failed after $max_tries tries. Returning failed + pathfinder." end return pfs end From 902d206299870a4e1db13f732f6f28bdfcd92e62 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 12:54:59 +0000 Subject: [PATCH 38/40] Increased testing cases for manypathfinder --- EpiAware/test/test_inference-methods.jl | 112 ++++++++++++++---------- 1 file changed, 67 insertions(+), 45 deletions(-) diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl index b44f0e0fa..752facfc6 100644 --- a/EpiAware/test/test_inference-methods.jl +++ b/EpiAware/test/test_inference-methods.jl @@ -1,14 +1,14 @@ @testitem "Testing _run_manypathfinder function" begin using Turing, Pathfinder - @model function test_model() - x ~ Normal(0, 1) - y ~ Normal(x, 1) - end - mdl = test_model() + @testset "Test case: check runs" begin + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + + mdl = test_model() - # Test case 1 - @testset "Test case 1" begin nruns = 10 ndraws = 100 maxiters = 50 @@ -19,30 +19,31 @@ @test length(pfs) == nruns @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) end - - # Test case 2 - @testset "Test case 2" begin + @testset "Test case: check fail mode for bad model" begin + @model function bad_model() + x ~ Normal(0, 1) + return sqrt(x) #<-fails + end + badmdl = bad_model() nruns = 5 ndraws = 50 maxiters = 100 pfs = EpiAware._run_manypathfinder( - mdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) + badmdl; nruns = nruns, ndraws = ndraws, maxiters = maxiters) - @test length(pfs) == nruns - @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + @test all(pfs .== :fail) end end @testitem "Testing _continue_manypathfinder! function" begin using Turing, Pathfinder - @testset "Test case 1" begin - @model function test_model() + @testset "Check that it only adds one more for easy model" begin + @model function easy_model() x ~ Normal(0, 1) - y ~ Normal(x, 1) end - mdl = test_model() + easymdl = easy_model() pfs = Vector{Union{PathfinderResult, Symbol}}([:fail, :fail, :fail]) max_tries = 3 @@ -51,18 +52,17 @@ end maxiters = 50 pfs = EpiAware._continue_manypathfinder!( - pfs, mdl; max_tries, nruns, ndraws, maxiters) + pfs, easymdl; max_tries, nruns, ndraws, maxiters) - @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + @test pfs[end] isa PathfinderResult end - @testset "Test case 2" begin - @model function test_model() + @testset "Check always fails for bad models and throws correct Exception" begin + @model function bad_model() x ~ Normal(0, 1) - y ~ Normal(x, 1) + return sqrt(x) #<-fails end - - mdl = test_model() + badmdl = bad_model() pfs = Vector{Union{PathfinderResult, Symbol}}([:fail, :fail, :fail]) max_tries = 3 @@ -70,10 +70,10 @@ end ndraws = 100 maxiters = 50 - pfs = EpiAware._continue_manypathfinder!( - pfs, mdl; max_tries, nruns, ndraws, maxiters) - - @test all(p -> p isa Union{PathfinderResult, Symbol}, pfs) + @test_throws "All pathfinder runs failed after $max_tries tries." begin + pfs = EpiAware._continue_manypathfinder!( + pfs, badmdl; max_tries, nruns, ndraws, maxiters) + end end end @testitem "Testing _get_best_elbo_pathfinder function" begin @@ -96,17 +96,15 @@ end @test best_pf isa PathfinderResult end @testitem "Testing manypathfinder function" begin - using Turing, Pathfinder - - @model function test_model() - x ~ Normal(0, 1) - y ~ Normal(x, 1) - end + using Turing, Pathfinder, HypothesisTests + @testset "Test model works" begin + @model function test_model() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end - mdl = test_model() + mdl = test_model() - # Test case 1 - @testset "Test case 1" begin nruns = 4 ndraws = 10 nchains = 4 @@ -119,17 +117,41 @@ end @test best_pf isa PathfinderResult end - # Test case 2 - @testset "Test case 2" begin - nruns = 2 - ndraws = 5 - nchains = 2 - maxiters = 30 - max_tries = 50 + @testset "Does good job finding simple distribution" begin + @model function basic_normal() + x ~ Normal(0, 1) + end + mdl = basic_normal() + nruns = 4 + ndraws = 2000 + nchains = 4 + maxiters = 50 + max_tries = 10 best_pf = manypathfinder(mdl; nruns = nruns, ndraws = ndraws, nchains = nchains, maxiters = maxiters, max_tries = max_tries) - @test best_pf isa PathfinderResult + pathfinder_samples = best_pf.draws |> vec + ks_test_pval = ExactOneSampleKSTest(pathfinder_samples, Normal(0.0, 1)) |> pvalue + @test ks_test_pval > 1e-6 + end + + @testset "Check always fails for bad models and throws correct Exception" begin + @model function bad_model() + x ~ Normal(0, 1) + return sqrt(x) #<-fails + end + badmdl = bad_model() + + max_tries = 3 + nruns = 10 + ndraws = 100 + maxiters = 50 + nchains = 4 + + @test_throws "All pathfinder runs failed after $max_tries tries." begin + manypathfinder(badmdl; nruns = nruns, ndraws = ndraws, nchains = nchains, + maxiters = maxiters, max_tries = max_tries) + end end end From 67427fbeb5a06063c623fd006946ac18154f8cac Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 7 Mar 2024 12:55:21 +0000 Subject: [PATCH 39/40] Make total fail a verbose exception error --- EpiAware/src/inference-methods.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index 9b47397d9..4640f5b17 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -32,13 +32,14 @@ of tries is reached. - `mdl::DynamicPPL.Model`: The model to perform inference on. - `max_tries`: The maximum number of tries to run the pathfinder algorithm. Default is `Inf`. +- `nruns`: The number of times to run the `pathfinder` function. - `kwargs...`: Additional keyword arguments passed to `pathfinder`. # Returns - `pfs`: The updated array of pathfinder objects. """ -function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, kwargs...) +function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, nruns, kwargs...) tryiter = 1 if all(pfs .== :fail) @warn "All initial pathfinder runs failed, trying again for $max_tries tries." @@ -53,8 +54,8 @@ function _continue_manypathfinder!(pfs, mdl::DynamicPPL.Model; max_tries, kwargs tryiter += 1 end if all(pfs .== :fail) - @warn "All pathfinder runs failed after $max_tries tries. Returning failed - pathfinder." + e = ErrorException("All pathfinder runs failed after $max_tries tries.") + throw(e) end return pfs end @@ -97,6 +98,6 @@ function manypathfinder(mdl::DynamicPPL.Model; nruns = 4, ndraws = 10, nchains = 4, maxiters = 50, max_tries = 100, kwargs...) ndraws = max(ndraws, nchains) _run_manypathfinder(mdl; nruns, ndraws, maxiters, kwargs...) |> - pfs -> _continue_manypathfinder!(pfs, mdl; max_tries, kwargs...) |> + pfs -> _continue_manypathfinder!(pfs, mdl; max_tries, nruns, kwargs...) |> pfs -> _get_best_elbo_pathfinder(pfs) end From 6deb30a67f770b4f2e70e4c1a1e19856d51534b2 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Fri, 8 Mar 2024 11:49:11 +0000 Subject: [PATCH 40/40] Change `manypathfinder` to have some API as `multipathfinder` --- EpiAware/docs/src/examples/getting_started.jl | 2 +- EpiAware/src/inference-methods.jl | 2 +- EpiAware/test/test_inference-methods.jl | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/EpiAware/docs/src/examples/getting_started.jl b/EpiAware/docs/src/examples/getting_started.jl index f79eb29a3..d9e49bb6c 100644 --- a/EpiAware/docs/src/examples/getting_started.jl +++ b/EpiAware/docs/src/examples/getting_started.jl @@ -267,7 +267,7 @@ To make NUTS more robust we provide `manypathfinder`, which is built on pathfind " # ╔═╡ 40ebd47a-4a08-4a46-a727-26347d3fca51 -best_pf = manypathfinder(inference_mdl; nruns = 20); +best_pf = manypathfinder(inference_mdl, 10; nruns = 20, executor = Transducers.ThreadedEx()); # ╔═╡ b7d9a56a-b2d5-4595-a6b9-9cd5fa6b1445 md" diff --git a/EpiAware/src/inference-methods.jl b/EpiAware/src/inference-methods.jl index 4640f5b17..e105390cf 100644 --- a/EpiAware/src/inference-methods.jl +++ b/EpiAware/src/inference-methods.jl @@ -94,7 +94,7 @@ largest ELBO estimate. ## Returns - `best_pfs::PathfinderResult`: Best pathfinder result by estimated ELBO. """ -function manypathfinder(mdl::DynamicPPL.Model; nruns = 4, ndraws = 10, +function manypathfinder(mdl::DynamicPPL.Model, ndraws; nruns = 4, nchains = 4, maxiters = 50, max_tries = 100, kwargs...) ndraws = max(ndraws, nchains) _run_manypathfinder(mdl; nruns, ndraws, maxiters, kwargs...) |> diff --git a/EpiAware/test/test_inference-methods.jl b/EpiAware/test/test_inference-methods.jl index 752facfc6..c7f10e186 100644 --- a/EpiAware/test/test_inference-methods.jl +++ b/EpiAware/test/test_inference-methods.jl @@ -111,7 +111,7 @@ end maxiters = 50 max_tries = 100 - best_pf = manypathfinder(mdl; nruns = nruns, ndraws = ndraws, nchains = nchains, + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, maxiters = maxiters, max_tries = max_tries) @test best_pf isa PathfinderResult @@ -128,7 +128,7 @@ end maxiters = 50 max_tries = 10 - best_pf = manypathfinder(mdl; nruns = nruns, ndraws = ndraws, nchains = nchains, + best_pf = manypathfinder(mdl, ndraws; nruns = nruns, nchains = nchains, maxiters = maxiters, max_tries = max_tries) pathfinder_samples = best_pf.draws |> vec @@ -150,7 +150,7 @@ end nchains = 4 @test_throws "All pathfinder runs failed after $max_tries tries." begin - manypathfinder(badmdl; nruns = nruns, ndraws = ndraws, nchains = nchains, + manypathfinder(badmdl, ndraws; nruns = nruns, nchains = nchains, maxiters = maxiters, max_tries = max_tries) end end