Skip to content

Commit

Permalink
Merge pull request #245 from CDCgov/244-bounded-population-size-as-de…
Browse files Browse the repository at this point in the history
…fault

add `RenewalWithPopulation` infection generating process model
  • Loading branch information
SamuelBrand1 authored May 31, 2024
2 parents bf853a3 + ca64262 commit 599c32b
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 10 deletions.
6 changes: 3 additions & 3 deletions EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.19.41
# v0.19.42

using Markdown
using InteractiveUtils
Expand Down Expand Up @@ -240,7 +240,7 @@ R_1 = 1 \Big{/} \sum_{t\geq 1} e^{-rt} g_t
log_I0_prior = Normal(log(1.0), 1.0)

# ╔═╡ 8487835e-d430-4300-bd7c-e33f5769ee32
epi = Renewal(model_data, log_I0_prior)
epi = RenewalWithPopulation(model_data, log_I0_prior, 1e8)

# ╔═╡ 2119319f-a2ef-4c96-82c4-3c7eaf40d2e0
md"
Expand Down Expand Up @@ -387,7 +387,7 @@ num_threads = min(10, Threads.nthreads())
# ╔═╡ 88b43e23-1e06-4716-b284-76e8afc6171b
inference_method = EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)],
sampler = NUTSampler(adtype = AutoForwardDiff(),
sampler = NUTSampler(adtype = AutoReverseDiff(true),
ndraws = 2000,
nchains = num_threads,
mcmc_parallel = MCMCThreads())
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export AbstractModel, AbstractEpiModel, AbstractLatentModel, AbstractObservation

# Export Turing-based models
export AbstractTuringEpiModel, AbstractTuringLatentModel, AbstractTuringIntercept,
AbstractTuringObservationModel
AbstractTuringObservationModel, AbstractTuringRenewal

# Export support types
export AbstractBroadcastRule
Expand Down
5 changes: 5 additions & 0 deletions EpiAware/src/EpiAwareBase/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,8 @@ Abstract supertype for infence/generative methods that are based on sampling
from the posterior distribution, e.g. NUTS.
"""
abstract type AbstractEpiSamplingMethod <: AbstractEpiMethod end

"""
Abstract type for all Turing-based Renewal infection generating models.
"""
abstract type AbstractTuringRenewal <: AbstractTuringEpiModel end
2 changes: 1 addition & 1 deletion EpiAware/src/EpiInfModels/EpiInfModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using ..EpiAwareUtils: scan, censored_pmf
using Turing, Distributions, DocStringExtensions, LinearAlgebra

#Export models
export EpiData, DirectInfections, ExpGrowthRate, Renewal
export EpiData, DirectInfections, ExpGrowthRate, Renewal, RenewalWithPopulation

#Export functions
export R_to_r, r_to_R
Expand Down
170 changes: 165 additions & 5 deletions EpiAware/src/EpiInfModels/Renewal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ unobserved infections.
I_t = generated_quantities(latent_inf, θ)
```
"
@kwdef struct Renewal{S <: Sampleable} <: EpiAwareBase.AbstractTuringEpiModel
@kwdef struct Renewal{S <: Sampleable} <: EpiAwareBase.AbstractTuringRenewal
data::EpiData
initialisation_prior::S = Normal()
end
Expand Down Expand Up @@ -111,6 +111,166 @@ function (epi_model::Renewal)(recent_incidence, Rt)
new_incidence)
end

"""
Create the initial vector of infected individuals for a renewal model.
# Arguments
- `epi_model::Renewal`: The renewal model.
- `I₀`: The initial number of infected individuals.
- `Rt`: The time-varying reproduction number.
# Returns
The initial vector of infected individuals.
"""
function make_renewal_init(epi_model::Renewal, I₀, Rt)
r_approx = R_to_r(Rt[1], epi_model)
return I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)]
end

@doc raw"
Model unobserved/latent infections as due to time-varying Renewal model with reproduction
number ``\mathcal{R}_t`` which is generated by a latent process and a population size
of available people who can be infected `N`.
## Mathematical specification
If ``Z_t`` is a realisation of the latent model, then the unobserved/latent infections are
given by
```math
\begin{align}
\mathcal{R}_t &= g(Z_t),\\
S_t &= S_{t-1} - I_t,\\
I_t &= {S_{t-1} \over N}\mathcal{R}_t \sum_{i=1}^{n-1} I_{t-i} g_i, \qquad t \geq 1, \\
I_t &= g(\hat{I}_0) \exp(r(\mathcal{R}_1) t), \qquad t \leq 0.
\end{align}
```
where ``g`` is a transformation function and the unconstrained initial infections
``\hat{I}_0`` are sampled from a prior distribution. The discrete generation interval is
given by ``g_i``.
``r(\mathcal{R}_1)`` is the exponential growth rate implied by ``\mathcal{R}_1)``
using the implicit relationship between the exponential growth rate and the reproduction
number.
```math
\mathcal{R} \sum_{j \geq 1} g_j \exp(- r j)= 1.
```
`Renewal` are constructed by passing an `EpiData` object `data` and an
`initialisation_prior` for the prior distribution of ``\hat{I}_0``. The default
`initialisation_prior` is `Normal()`.
## Constructor
- `RenewalWithPopulation(; data, initialisation_prior, pop_size)`.
## Example usage with `generate_latent_infs`
`generate_latent_infs` can be used to construct a `Turing` model for the latent infections
conditional on the sample path of a latent process. In this example, we generate a sample
of a white noise latent process.
First, we construct an `Renewal` struct with an `EpiData` object, an initialisation
prior and a transformation function.
```julia
using Distributions, Turing, EpiAware
gen_int = [0.2, 0.3, 0.5]
g = exp
# Create an EpiData object
data = EpiData(gen_int, g)
# Create an Renewal model
renewal_model = RenewalWithPopulation(data = data, initialisation_prior = Normal(), pop_size = 1e6)
```
Then, we can use `generate_latent_infs` to construct a Turing model for the unobserved
infection generation model set by the type of `direct_inf_model`.
```julia
# Construct a Turing model
Z_t = randn(100) * 0.05
latent_inf = generate_latent_infs(renewal_model, Z_t)
```
Now we can use the `Turing` PPL API to sample underlying parameters and generate the
unobserved infections.
```julia
# Sample from the unobserved infections model
#Sample random parameters from prior
θ = rand(latent_inf)
#Get unobserved infections as a generated quantities from the model
I_t = generated_quantities(latent_inf, θ)
```
"
@kwdef struct RenewalWithPopulation{S <: Sampleable} <: EpiAwareBase.AbstractTuringRenewal
data::EpiData
initialisation_prior::S = Normal()
pop_size::Float64 = 1e6
end

@doc """
function (epi_model::RenewalWithPopulation)(recent_incidence_and_available_sus, Rt)
Callable on a `RenewalWithPopulation` struct for compute new incidence based on
recent incidence, Rt and depletion of susceptibles.
## Mathematical specification
The new incidence is given by
```math
I_t = {S_{t-1} / N} R_t \\sum_{i=1}^{n-1} I_{t-i} g_i
```
where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence
and `g_i` is the generation interval.
# Arguments
- `recent_incidence_and_available_sus`: A tuple with an array of recent incidence
values and the remaining susceptible/available individuals.
- `Rt`: Reproduction number.
# Returns
- Tuple containing the updated incidence array and the new `recent_incidence_and_available_sus`
value.
"""
function (epi_model::RenewalWithPopulation)(recent_incidence_and_available_sus, Rt)
recent_incidence, S = recent_incidence_and_available_sus
new_incidence = max(S / epi_model.pop_size, 0.0) * Rt *
dot(recent_incidence, epi_model.data.gen_int)
new_S = S - new_incidence
new_recent_incidence_and_available_sus = (
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]], new_S)

return (new_recent_incidence_and_available_sus, new_incidence)
end

"""
Constructs the initial conditions for a renewal model with population.
# Arguments
- `epi_model::RenewalWithPopulation`: The renewal model with population.
- `I₀`: The initial number of infected individuals.
- `Rt`: The time-varying reproduction number.
# Returns
- A tuple containing the initial number of infected individuals at each generation
interval and the population size of susceptible/available people.
"""
function make_renewal_init(epi_model::RenewalWithPopulation, I₀, Rt)
r_approx = R_to_r(Rt[1], epi_model)
return (I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)],
epi_model.pop_size)
end

@doc raw"
Implement the `generate_latent_infs` function for the `Renewal` model.
Expand Down Expand Up @@ -156,14 +316,14 @@ unobserved infections.
I_t = generated_quantities(latent_inf, θ)
```
"
@model function EpiAwareBase.generate_latent_infs(epi_model::Renewal, _Rt)
@model function EpiAwareBase.generate_latent_infs(
epi_model::EpiAwareBase.AbstractTuringRenewal, _Rt)
init_incidence ~ epi_model.initialisation_prior
I₀ = epi_model.data.transformation(init_incidence)
Rt = epi_model.data.transformation.(_Rt)

r_approx = R_to_r(Rt[1], epi_model)
init = I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)]

init = make_renewal_init(epi_model, I₀, Rt)
I_t, _ = scan(epi_model, init, Rt)

return I_t
end
76 changes: 76 additions & 0 deletions EpiAware/test/EpiInfModels/Renewal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,41 @@
@test generate_infs(recent_incidence, Rt) == expected_output
end

@testitem "RenewalWithPopulation function: internal generate infs" begin
using LinearAlgebra, Distributions
gen_int = [0.2, 0.3, 0.5]
delay_int = [0.1, 0.4, 0.5]
cluster_coeff = 0.8
time_horizon = 10
transformation = exp
pop_size = 1000.0

data = EpiData(gen_int, transformation)
epi_model = RenewalWithPopulation(data, Normal(), pop_size)

function generate_infs(recent_incidence_and_available_sus, Rt)
recent_incidence, S = recent_incidence_and_available_sus
new_incidence = max(S / epi_model.pop_size, 0.0) * Rt *
dot(recent_incidence, epi_model.data.gen_int)
new_S = S - new_incidence
new_recent_incidence_and_available_sus = (
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]], new_S)

return (new_recent_incidence_and_available_sus, new_incidence)
end

recent_incidence = [10, 20, 30]
Rt = 1.5

expected_new_incidence = Rt * dot(recent_incidence, [0.2, 0.3, 0.5])
expected_new_recent_incidence_and_available_sus = (
[expected_new_incidence; recent_incidence[1:2]], pop_size - expected_new_incidence)
expected_output = (
expected_new_recent_incidence_and_available_sus, expected_new_incidence)

@test generate_infs((recent_incidence, pop_size), Rt) == expected_output
end

@testitem "generate_latent_infs dispatched on Renewal" begin
using Distributions, Turing, HypothesisTests, DynamicPPL, LinearAlgebra
gen_int = [0.2, 0.3, 0.5]
Expand Down Expand Up @@ -61,3 +96,44 @@ end

@test mdl_incidence[1:3] [day1_incidence, day2_incidence, day3_incidence]
end

@testitem "generate_latent_infs dispatched on RenewalWithPopulation" begin
using Distributions, Turing, HypothesisTests, DynamicPPL, LinearAlgebra
gen_int = [0.2, 0.3, 0.5]
transformation = exp
pop_size = 1000.0

data = EpiData(gen_int, transformation)
log_init_incidence_prior = Normal()

renewal_model = RenewalWithPopulation(data, log_init_incidence_prior, pop_size)

#Actual Rt
Rt = [1.0, 1.2, 1.5, 1.5, 1.5]
log_Rt = log.(Rt)
initial_incidence = [1.0, 1.0, 1.0]#aligns with initial exp growth rate of 0.

#Check log_init is sampled from the correct distribution
@time sample_init_inc = sample(generate_latent_infs(renewal_model, log_Rt),
Prior(), 1000; progress = false) |>
chn -> chn[:init_incidence] |>
Array |>
vec

ks_test_pval = ExactOneSampleKSTest(sample_init_inc, log_init_incidence_prior) |> pvalue
@test ks_test_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented

#Check that the generated incidence is correct given correct initialisation
#Check first three days "by hand"
mdl_incidence = generated_quantities(
generate_latent_infs(renewal_model,
log_Rt), (init_incidence = 0.0,))

day1_incidence = dot(initial_incidence, gen_int) * Rt[1]
day2_incidence = ((pop_size - day1_incidence) / pop_size) *
dot(initial_incidence, gen_int) * Rt[2]
day3_incidence = ((pop_size - day1_incidence - day2_incidence) / pop_size) *
dot([day2_incidence, 1.0, 1.0], gen_int) * Rt[3]

@test mdl_incidence[1:3] [day1_incidence, day2_incidence, day3_incidence]
end

0 comments on commit 599c32b

Please sign in to comment.