Skip to content

Commit

Permalink
Merge pull request #42 from CDCgov/40-initialisation-infection-genera…
Browse files Browse the repository at this point in the history
…tion-processes

Initialisation for epidemic inference
  • Loading branch information
seabbs authored Feb 19, 2024
2 parents 22a5f5d + e503a62 commit 399c766
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 40 deletions.
9 changes: 6 additions & 3 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ export scan,
create_discrete_pmf,
growth_rate_to_reproductive_ratio,
generate_observation_kernel,
default_rw_priors
default_rw_priors,
neg_MGF,
dneg_MGF_dr,
R_to_r

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections
export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel

# Exported Turing model constructors
export make_epi_inference_model, random_walk

include("utilities.jl")
include("epimodel.jl")
include("utilities.jl")
include("models.jl")
include("latent-processes.jl")

Expand Down
26 changes: 18 additions & 8 deletions EpiAware/src/epimodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,34 @@ struct DirectInfections <: AbstractEpiModel
data::EpiData
end

function (epi_model::DirectInfections)(recent_incidence, unc_I_t)
nothing, epi_model.data.transformation(unc_I_t)
function (epimodel::DirectInfections)(_It, latent_process_aux)
epimodel.data.transformation.(_It)
end

struct ExpGrowthRate <: AbstractEpiModel
data::EpiData
end

function (epi_model::ExpGrowthRate)(unc_recent_incidence, rt)
new_unc_recent_incidence = unc_recent_incidence + rt
new_unc_recent_incidence, epi_model.data.transformation(new_unc_recent_incidence)
function (epimodel::ExpGrowthRate)(rt, latent_process_aux)
latent_process_aux.init .+ cumsum(rt) .|> exp
end

struct Renewal <: AbstractEpiModel
data::EpiData
end

function (epi_model::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int)
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int-1)]], new_incidence
function (epimodel::Renewal)(_Rt, latent_process_aux)
I₀ = epimodel.data.transformation(latent_process_aux.init)
Rt = epimodel.data.transformation.(_Rt)

r_approx = R_to_r(Rt[1], epimodel)
init = I₀ * [exp(-r_approx * t) for t = 0:(epimodel.data.len_gen_int-1)]

function generate_infs(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int)
[new_incidence; recent_incidence[1:(epimodel.data.len_gen_int-1)]], new_incidence
end

I_t, _ = scan(generate_infs, init, Rt)
return I_t
end
6 changes: 3 additions & 3 deletions EpiAware/src/latent-processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ end
rw = Vector{T}(undef, n)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_process_priors.var_RW_dist
init_rw_value ~ latent_process_priors.init_rw_value_dist
init ~ latent_process_priors.init_rw_value_dist
σ_RW = sqrt(σ²_RW)

rw[1] = init_rw_value + σ_RW * ϵ_t[1]
rw[1] = σ_RW * ϵ_t[1]
for t = 2:n
rw[t] = rw[t-1] + σ_RW * ϵ_t[t]
end
return rw, (; σ_RW, init_rw_value, init = rw[1])
return rw, (; σ_RW, init)
end
6 changes: 3 additions & 3 deletions EpiAware/src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

#Latent process
time_steps = epimodel.data.time_horizon
@submodel latent_process, latent_process_parameters =
@submodel latent_process, latent_process_aux =
latent_process(time_steps; latent_process_priors = latent_process_priors)

#Transform into infections
I_t, _ = scan(epimodel, latent_process_parameters.init, latent_process)
I_t = epimodel(latent_process, latent_process_aux)

#Predictive distribution
case_pred_dists =
Expand All @@ -27,5 +27,5 @@
y_t ~ arraydist(case_pred_dists)

#Generate quantities
return (; I_t, latent_process_parameters)
return (; I_t, latent_process, latent_process_aux)
end
67 changes: 65 additions & 2 deletions EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ value. This is similar to the JAX function `jax.lax.scan`.
- `ys`: An array containing the result values of applying `f` to each element of `xs`.
- `carry`: The final accumulator value.
"""
function scan(f, init, xs)
function scan(f::Function, init, xs::Vector{T}) where {T<:Union{Integer,AbstractFloat}}
carry = init
ys = similar(xs)
for (i, x) in enumerate(xs)
Expand Down Expand Up @@ -109,6 +109,69 @@ function create_discrete_pmf(dist::Distribution; Δd = 1.0, D)
ts .|> (t -> ∫F(dist, t, Δd)) |> diff |> p -> p ./ sum(p)
end

"""
neg_MGF(r, w::AbstractVector)
Compute the negative moment generating function (MGF) for a given rate `r` and weights `w`.
# Arguments
- `r`: The rate parameter.
- `w`: An abstract vector of weights.
# Returns
The value of the negative MGF.
"""
function neg_MGF(r, w::AbstractVector)
return sum([w[i] * exp(-r * i) for i = 1:length(w)])
end

function dneg_MGF_dr(r, w::AbstractVector)
return -sum([w[i] * i * exp(-r * i) for i = 1:length(w)])
end

"""
R_to_r(R₀, w::Vector{T}; newton_steps = 2, Δd = 1.0)
This function computes an approximation to the exponential growth rate `r`
given the reproductive ratio `R₀` and the discretized generation interval `w` with
discretized interval width `Δd`. This is based on the implicit solution of
```math
G(r) - {1 \\over R_0} = 0.
```
where
```math
G(r) = \\sum_{i=1}^n w_i e^{-r i}.
```
is the negative moment generating function (MGF) of the generation interval distribution.
The two step approximation is based on:
1. Direct solution of implicit equation for a small `r` approximation.
2. Improving the approximation using Newton's method for a fixed number of steps `newton_steps`.
Returns:
- The approximate value of `r`.
"""
function R_to_r(R₀, w::Vector{T}; newton_steps = 2, Δd = 1.0) where {T<:AbstractFloat}
mean_gen_time = dot(w, 1:length(w)) * Δd
# Small r approximation as initial guess
r_approx = (R₀ - 1) / (R₀ * mean_gen_time)
# Newton's method
for _ = 1:newton_steps
r_approx -= (R₀ * neg_MGF(r_approx, w) - 1) / (R₀ * dneg_MGF_dr(r_approx, w))
end
return r_approx
end

function R_to_r(R₀, epimodel::AbstractEpiModel; newton_steps = 2, Δd = 1.0)
R_to_r(R₀, epimodel.data.gen_int; newton_steps = newton_steps, Δd = Δd)
end


"""
growth_rate_to_reproductive_ratio(r, w)
Expand All @@ -123,7 +186,7 @@ Compute the reproductive ratio given exponential growth rate `r`
- The reproductive ratio.
"""
function growth_rate_to_reproductive_ratio(r, w::AbstractVector)
return 1 / sum([w[i] * exp(-r * i) for i = 1:length(w)])
return 1 / neg_MGF(r, w::AbstractVector)
end


Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/Aqua.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

@testitem "Aqua.jl" begin
using Aqua
Aqua.test_all(EpiAware, ambiguities = false)
Aqua.test_all(EpiAware, ambiguities = false, persistent_tasks = false)
Aqua.test_ambiguities(EpiAware)
end
73 changes: 73 additions & 0 deletions EpiAware/test/predictive_checking/fast_approx_for_r.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#=
# Fast approximation for `r` from `R₀`
I use the negative moment generating function (MGF).
Let
```math
G(r) = \sum_{i=1}^{\infty} w_i e^{-r i}.
```
and
```math
f(r, \mathcal{R}_0) = \mathcal{R}_0 G(r) - 1.
```
then the connection between `R₀` and `r` is given by
```math
f(r, \mathcal{R}_0) = 0.
```
Given an estimate of $\mathcal{R}_0$ we implicit solve for $r$ using a root
finder algorithm. In this note, I test a fast approximation for $r$ which
should have good autodifferentiation properties. The idea is to start from the
small $r$ approximation to the solution of $f(r, \mathcal{R}_0) = 0$ and then
apply one step of Newton's method. The small $r$ approximation is given by
```math
r_0 = { \mathcal{R}_0 - 1 \over \mathcal{R}_0 \langle W \rangle }.
```
where $\langle W \rangle$ is the mean of the generation interval.
To rapidly improve the estimate for `r` we use Newton steps given by
```math
r_{n+1} = r_n - {\mathcal{R}_0 G(r) - 1\over \mathcal{R}_0 G'(r)}.
```
=#

using TestEnv
TestEnv.activate()
using EpiAware
using Distributions
using StatsPlots

# Create a discrete probability mass function (PMF) for a negative binomial distribution
# with left truncation at 1.

w =
create_discrete_pmf(NegativeBinomial(2, 0.5), D = 20.0) |>
p -> p[2:end] ./ sum(p[2:end])

##

jitter = 1e-17
idxs = 0:10
doubling_times = [1.0, 3.5, 7.0, 14.0]

errors = mapreduce(hcat, doubling_times) do T_2
true_r = log(2) / T_2 # 7 day doubling time
R0 = growth_rate_to_reproductive_ratio(true_r, w)

return map(idxs) do ns
@time r = R_to_r(R0, w, newton_steps = ns)
abs(r - true_r) + jitter
end
end

plot(
idxs,
errors,
yscale = :log10,
xlabel = "Newton steps",
ylabel = "abs. Error",
title = "Fast approximation for r",
lab = ["T_2 = 1.0" "T_2 = 3.5" "T_2 = 7.0" "T_2 = 14.0"],
yticks = [0.0, 1e-15, 1e-10, 1e-5, 1e0] |> x -> (x .+ jitter, string.(x)),
xticks = 0:2:10,
)
31 changes: 15 additions & 16 deletions EpiAware/test/test_epimodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end
end


@testitem "Renewal function" begin
@testitem "Renewal function: internal generate infs" begin
using LinearAlgebra
gen_int = [0.2, 0.3, 0.5]
delay_int = [0.1, 0.4, 0.5]
Expand All @@ -61,7 +61,12 @@ end
transformation = exp

data = EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
renewal_model = Renewal(data)
epimodel = Renewal(data)

function generate_infs(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int)
[new_incidence; recent_incidence[1:(epimodel.data.len_gen_int-1)]], new_incidence
end

recent_incidence = [10, 20, 30]
Rt = 1.5
Expand All @@ -71,7 +76,7 @@ end
[expected_new_incidence; recent_incidence[1:2]], expected_new_incidence


@test renewal_model(recent_incidence, Rt) == expected_output
@test generate_infs(recent_incidence, Rt) == expected_output
end

@testitem "ExpGrowthRate function" begin
Expand All @@ -84,15 +89,11 @@ end
data = EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
rt_model = ExpGrowthRate(data)

recent_incidence = [10, 20, 30]
rt = log(2) / 7.0 # doubling time of 7 days

expected_new_incidence = recent_incidence[end] * exp(rt)
expected_output = log(expected_new_incidence), expected_new_incidence
recent_incidence = [10.0, 20.0, 30.0]
log_init = log(5.0)
rt = [log(recent_incidence[1]) - log_init; diff(log.(recent_incidence))]


@test rt_model(log(recent_incidence[end]), rt)[1] expected_output[1]
@test rt_model(log(recent_incidence[end]), rt)[2] expected_output[2]
@test rt_model(rt, (init = log_init,)) recent_incidence
end

@testitem "DirectInfections function" begin
Expand All @@ -105,11 +106,9 @@ end
data = EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
direct_inf_model = DirectInfections(data)

recent_log_incidence = [10, 20, 30] .|> log

expected_new_incidence = exp(recent_log_incidence[end])
expected_output = nothing, expected_new_incidence
log_incidence = [10, 20, 30] .|> log

expected_incidence = exp.(log_incidence)

@test direct_inf_model(nothing, recent_log_incidence[end]) == expected_output
@test direct_inf_model(log_incidence, nothing) expected_incidence
end
Loading

0 comments on commit 399c766

Please sign in to comment.