diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 1d3c4162d..360c8118e 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -35,16 +35,19 @@ export scan, create_discrete_pmf, growth_rate_to_reproductive_ratio, generate_observation_kernel, - default_rw_priors + default_rw_priors, + neg_MGF, + dneg_MGF_dr, + R_to_r # Exported types -export EpiData, Renewal, ExpGrowthRate, DirectInfections +export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel # Exported Turing model constructors export make_epi_inference_model, random_walk -include("utilities.jl") include("epimodel.jl") +include("utilities.jl") include("models.jl") include("latent-processes.jl") diff --git a/EpiAware/src/epimodel.jl b/EpiAware/src/epimodel.jl index 4e099248c..052d9f6f4 100644 --- a/EpiAware/src/epimodel.jl +++ b/EpiAware/src/epimodel.jl @@ -60,24 +60,34 @@ struct DirectInfections <: AbstractEpiModel data::EpiData end -function (epi_model::DirectInfections)(recent_incidence, unc_I_t) - nothing, epi_model.data.transformation(unc_I_t) +function (epimodel::DirectInfections)(_It, latent_process_aux) + epimodel.data.transformation.(_It) end struct ExpGrowthRate <: AbstractEpiModel data::EpiData end -function (epi_model::ExpGrowthRate)(unc_recent_incidence, rt) - new_unc_recent_incidence = unc_recent_incidence + rt - new_unc_recent_incidence, epi_model.data.transformation(new_unc_recent_incidence) +function (epimodel::ExpGrowthRate)(rt, latent_process_aux) + latent_process_aux.init .+ cumsum(rt) .|> exp end struct Renewal <: AbstractEpiModel data::EpiData end -function (epi_model::Renewal)(recent_incidence, Rt) - new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int) - [new_incidence; recent_incidence[1:(epi_model.data.len_gen_int-1)]], new_incidence +function (epimodel::Renewal)(_Rt, latent_process_aux) + I₀ = epimodel.data.transformation(latent_process_aux.init) + Rt = epimodel.data.transformation.(_Rt) + + r_approx = R_to_r(Rt[1], epimodel) + init = I₀ * [exp(-r_approx * t) for t = 0:(epimodel.data.len_gen_int-1)] + + function generate_infs(recent_incidence, Rt) + new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int) + [new_incidence; recent_incidence[1:(epimodel.data.len_gen_int-1)]], new_incidence + end + + I_t, _ = scan(generate_infs, init, Rt) + return I_t end diff --git a/EpiAware/src/latent-processes.jl b/EpiAware/src/latent-processes.jl index 828dd9eff..5db208e11 100644 --- a/EpiAware/src/latent-processes.jl +++ b/EpiAware/src/latent-processes.jl @@ -14,12 +14,12 @@ end rw = Vector{T}(undef, n) ϵ_t ~ MvNormal(ones(n)) σ²_RW ~ latent_process_priors.var_RW_dist - init_rw_value ~ latent_process_priors.init_rw_value_dist + init ~ latent_process_priors.init_rw_value_dist σ_RW = sqrt(σ²_RW) - rw[1] = init_rw_value + σ_RW * ϵ_t[1] + rw[1] = σ_RW * ϵ_t[1] for t = 2:n rw[t] = rw[t-1] + σ_RW * ϵ_t[t] end - return rw, (; σ_RW, init_rw_value, init = rw[1]) + return rw, (; σ_RW, init) end diff --git a/EpiAware/src/models.jl b/EpiAware/src/models.jl index 948e50351..b127362c3 100644 --- a/EpiAware/src/models.jl +++ b/EpiAware/src/models.jl @@ -12,11 +12,11 @@ #Latent process time_steps = epimodel.data.time_horizon - @submodel latent_process, latent_process_parameters = + @submodel latent_process, latent_process_aux = latent_process(time_steps; latent_process_priors = latent_process_priors) #Transform into infections - I_t, _ = scan(epimodel, latent_process_parameters.init, latent_process) + I_t = epimodel(latent_process, latent_process_aux) #Predictive distribution case_pred_dists = @@ -27,5 +27,5 @@ y_t ~ arraydist(case_pred_dists) #Generate quantities - return (; I_t, latent_process_parameters) + return (; I_t, latent_process, latent_process_aux) end diff --git a/EpiAware/src/utilities.jl b/EpiAware/src/utilities.jl index ddef40845..9388b87e1 100644 --- a/EpiAware/src/utilities.jl +++ b/EpiAware/src/utilities.jl @@ -17,7 +17,7 @@ value. This is similar to the JAX function `jax.lax.scan`. - `ys`: An array containing the result values of applying `f` to each element of `xs`. - `carry`: The final accumulator value. """ -function scan(f, init, xs) +function scan(f::Function, init, xs::Vector{T}) where {T<:Union{Integer,AbstractFloat}} carry = init ys = similar(xs) for (i, x) in enumerate(xs) @@ -109,6 +109,69 @@ function create_discrete_pmf(dist::Distribution; Δd = 1.0, D) ts .|> (t -> ∫F(dist, t, Δd)) |> diff |> p -> p ./ sum(p) end +""" + neg_MGF(r, w::AbstractVector) + +Compute the negative moment generating function (MGF) for a given rate `r` and weights `w`. + +# Arguments +- `r`: The rate parameter. +- `w`: An abstract vector of weights. + +# Returns +The value of the negative MGF. + +""" +function neg_MGF(r, w::AbstractVector) + return sum([w[i] * exp(-r * i) for i = 1:length(w)]) +end + +function dneg_MGF_dr(r, w::AbstractVector) + return -sum([w[i] * i * exp(-r * i) for i = 1:length(w)]) +end + +""" + R_to_r(R₀, w::Vector{T}; newton_steps = 2, Δd = 1.0) + +This function computes an approximation to the exponential growth rate `r` +given the reproductive ratio `R₀` and the discretized generation interval `w` with +discretized interval width `Δd`. This is based on the implicit solution of + +```math +G(r) - {1 \\over R_0} = 0. +``` + +where + +```math +G(r) = \\sum_{i=1}^n w_i e^{-r i}. +``` + +is the negative moment generating function (MGF) of the generation interval distribution. + +The two step approximation is based on: + 1. Direct solution of implicit equation for a small `r` approximation. + 2. Improving the approximation using Newton's method for a fixed number of steps `newton_steps`. + +Returns: +- The approximate value of `r`. +""" +function R_to_r(R₀, w::Vector{T}; newton_steps = 2, Δd = 1.0) where {T<:AbstractFloat} + mean_gen_time = dot(w, 1:length(w)) * Δd + # Small r approximation as initial guess + r_approx = (R₀ - 1) / (R₀ * mean_gen_time) + # Newton's method + for _ = 1:newton_steps + r_approx -= (R₀ * neg_MGF(r_approx, w) - 1) / (R₀ * dneg_MGF_dr(r_approx, w)) + end + return r_approx +end + +function R_to_r(R₀, epimodel::AbstractEpiModel; newton_steps = 2, Δd = 1.0) + R_to_r(R₀, epimodel.data.gen_int; newton_steps = newton_steps, Δd = Δd) +end + + """ growth_rate_to_reproductive_ratio(r, w) @@ -123,7 +186,7 @@ Compute the reproductive ratio given exponential growth rate `r` - The reproductive ratio. """ function growth_rate_to_reproductive_ratio(r, w::AbstractVector) - return 1 / sum([w[i] * exp(-r * i) for i = 1:length(w)]) + return 1 / neg_MGF(r, w::AbstractVector) end diff --git a/EpiAware/test/Aqua.jl b/EpiAware/test/Aqua.jl index 13e56dcc7..7cfc9f488 100644 --- a/EpiAware/test/Aqua.jl +++ b/EpiAware/test/Aqua.jl @@ -1,6 +1,6 @@ @testitem "Aqua.jl" begin using Aqua - Aqua.test_all(EpiAware, ambiguities = false) + Aqua.test_all(EpiAware, ambiguities = false, persistent_tasks = false) Aqua.test_ambiguities(EpiAware) end diff --git a/EpiAware/test/predictive_checking/fast_approx_for_r.jl b/EpiAware/test/predictive_checking/fast_approx_for_r.jl new file mode 100644 index 000000000..f0c314a88 --- /dev/null +++ b/EpiAware/test/predictive_checking/fast_approx_for_r.jl @@ -0,0 +1,73 @@ +#= +# Fast approximation for `r` from `R₀` +I use the negative moment generating function (MGF). + +Let +```math +G(r) = \sum_{i=1}^{\infty} w_i e^{-r i}. +``` +and +```math +f(r, \mathcal{R}_0) = \mathcal{R}_0 G(r) - 1. +``` +then the connection between `R₀` and `r` is given by +```math +f(r, \mathcal{R}_0) = 0. +``` +Given an estimate of $\mathcal{R}_0$ we implicit solve for $r$ using a root +finder algorithm. In this note, I test a fast approximation for $r$ which +should have good autodifferentiation properties. The idea is to start from the +small $r$ approximation to the solution of $f(r, \mathcal{R}_0) = 0$ and then +apply one step of Newton's method. The small $r$ approximation is given by +```math +r_0 = { \mathcal{R}_0 - 1 \over \mathcal{R}_0 \langle W \rangle }. +``` +where $\langle W \rangle$ is the mean of the generation interval. + +To rapidly improve the estimate for `r` we use Newton steps given by +```math +r_{n+1} = r_n - {\mathcal{R}_0 G(r) - 1\over \mathcal{R}_0 G'(r)}. +``` + +=# + +using TestEnv +TestEnv.activate() +using EpiAware +using Distributions +using StatsPlots + +# Create a discrete probability mass function (PMF) for a negative binomial distribution +# with left truncation at 1. + +w = + create_discrete_pmf(NegativeBinomial(2, 0.5), D = 20.0) |> + p -> p[2:end] ./ sum(p[2:end]) + +## + +jitter = 1e-17 +idxs = 0:10 +doubling_times = [1.0, 3.5, 7.0, 14.0] + +errors = mapreduce(hcat, doubling_times) do T_2 + true_r = log(2) / T_2 # 7 day doubling time + R0 = growth_rate_to_reproductive_ratio(true_r, w) + + return map(idxs) do ns + @time r = R_to_r(R0, w, newton_steps = ns) + abs(r - true_r) + jitter + end +end + +plot( + idxs, + errors, + yscale = :log10, + xlabel = "Newton steps", + ylabel = "abs. Error", + title = "Fast approximation for r", + lab = ["T_2 = 1.0" "T_2 = 3.5" "T_2 = 7.0" "T_2 = 14.0"], + yticks = [0.0, 1e-15, 1e-10, 1e-5, 1e0] |> x -> (x .+ jitter, string.(x)), + xticks = 0:2:10, +) diff --git a/EpiAware/test/test_epimodel.jl b/EpiAware/test/test_epimodel.jl index 82e2c703e..e28a909d0 100644 --- a/EpiAware/test/test_epimodel.jl +++ b/EpiAware/test/test_epimodel.jl @@ -52,7 +52,7 @@ end end -@testitem "Renewal function" begin +@testitem "Renewal function: internal generate infs" begin using LinearAlgebra gen_int = [0.2, 0.3, 0.5] delay_int = [0.1, 0.4, 0.5] @@ -61,7 +61,12 @@ end transformation = exp data = EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation) - renewal_model = Renewal(data) + epimodel = Renewal(data) + + function generate_infs(recent_incidence, Rt) + new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int) + [new_incidence; recent_incidence[1:(epimodel.data.len_gen_int-1)]], new_incidence + end recent_incidence = [10, 20, 30] Rt = 1.5 @@ -71,7 +76,7 @@ end [expected_new_incidence; recent_incidence[1:2]], expected_new_incidence - @test renewal_model(recent_incidence, Rt) == expected_output + @test generate_infs(recent_incidence, Rt) == expected_output end @testitem "ExpGrowthRate function" begin @@ -84,15 +89,11 @@ end data = EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation) rt_model = ExpGrowthRate(data) - recent_incidence = [10, 20, 30] - rt = log(2) / 7.0 # doubling time of 7 days - - expected_new_incidence = recent_incidence[end] * exp(rt) - expected_output = log(expected_new_incidence), expected_new_incidence + recent_incidence = [10.0, 20.0, 30.0] + log_init = log(5.0) + rt = [log(recent_incidence[1]) - log_init; diff(log.(recent_incidence))] - - @test rt_model(log(recent_incidence[end]), rt)[1] ≈ expected_output[1] - @test rt_model(log(recent_incidence[end]), rt)[2] ≈ expected_output[2] + @test rt_model(rt, (init = log_init,)) ≈ recent_incidence end @testitem "DirectInfections function" begin @@ -105,11 +106,9 @@ end data = EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation) direct_inf_model = DirectInfections(data) - recent_log_incidence = [10, 20, 30] .|> log - - expected_new_incidence = exp(recent_log_incidence[end]) - expected_output = nothing, expected_new_incidence + log_incidence = [10, 20, 30] .|> log + expected_incidence = exp.(log_incidence) - @test direct_inf_model(nothing, recent_log_incidence[end]) == expected_output + @test direct_inf_model(log_incidence, nothing) ≈ expected_incidence end diff --git a/EpiAware/test/test_models.jl b/EpiAware/test/test_models.jl index 8aa75bc4c..38d7ca0be 100644 --- a/EpiAware/test/test_models.jl +++ b/EpiAware/test/test_models.jl @@ -28,10 +28,8 @@ # Underlying log-infections are const value 1 for all time steps and # any other unfixed parameters - fixed_test_mdl = fix( - test_mdl, - (init_rw_value = log(1.0), σ²_RW = 0.0, neg_bin_cluster_factor = 0.05), - ) + fixed_test_mdl = + fix(test_mdl, (init = log(1.0), σ²_RW = 0.0, neg_bin_cluster_factor = 0.05)) X = rand(fixed_test_mdl) expected_I_t = [1.0 for _ = 1:epimodel.data.time_horizon] gen = generated_quantities(fixed_test_mdl, rand(fixed_test_mdl)) @@ -39,3 +37,81 @@ # Perform tests @test gen.I_t ≈ expected_I_t end + +@testitem "exp growth with RW latent process" begin + using Distributions, Turing, DynamicPPL + # Define test inputs + y_t = missing # Data will be generated from the model + data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) + latent_process_priors = default_rw_priors() + transform_function = exp + n_generate_ahead = 0 + pos_shift = 1e-6 + neg_bin_cluster_factor = 0.5 + neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3) + + epimodel = ExpGrowthRate(data) + + # Call the function + test_mdl = make_epi_inference_model( + y_t, + epimodel, + random_walk; + latent_process_priors, + pos_shift, + neg_bin_cluster_factor, + neg_bin_cluster_factor_prior, + ) + + # Define expected outputs for a conditional model + # Underlying log-infections are const value 1 for all time steps and + # any other unfixed parameters + + fixed_test_mdl = + fix(test_mdl, (init = log(1.0), σ²_RW = 0.0, neg_bin_cluster_factor = 0.05)) + X = rand(fixed_test_mdl) + expected_I_t = [1.0 for _ = 1:epimodel.data.time_horizon] + gen = generated_quantities(fixed_test_mdl, rand(fixed_test_mdl)) + + # # Perform tests + @test gen.I_t ≈ expected_I_t +end + +@testitem "Renewal with RW latent process" begin + using Distributions, Turing, DynamicPPL + # Define test inputs + y_t = missing # Data will be generated from the model + data = EpiData([0.2, 0.3, 0.5], [0.1, 0.4, 0.5], 0.8, 10, exp) + latent_process_priors = default_rw_priors() + transform_function = exp + n_generate_ahead = 0 + pos_shift = 1e-6 + neg_bin_cluster_factor = 0.5 + neg_bin_cluster_factor_prior = Gamma(3, 0.05 / 3) + + epimodel = Renewal(data) + + # Call the function + test_mdl = make_epi_inference_model( + y_t, + epimodel, + random_walk; + latent_process_priors, + pos_shift, + neg_bin_cluster_factor, + neg_bin_cluster_factor_prior, + ) + + # Define expected outputs for a conditional model + # Underlying log-infections are const value 1 for all time steps and + # any other unfixed parameters + + fixed_test_mdl = + fix(test_mdl, (init = log(1.0), σ²_RW = 0.0, neg_bin_cluster_factor = 0.05)) + X = rand(fixed_test_mdl) + expected_I_t = [1.0 for _ = 1:epimodel.data.time_horizon] + gen = generated_quantities(fixed_test_mdl, rand(fixed_test_mdl)) + + # # Perform tests + @test gen.I_t ≈ expected_I_t +end diff --git a/EpiAware/test/test_utilities.jl b/EpiAware/test/test_utilities.jl index a4ff75329..e52fe9d56 100644 --- a/EpiAware/test/test_utilities.jl +++ b/EpiAware/test/test_utilities.jl @@ -119,3 +119,52 @@ end end end + +@testitem "Testing neg_MGF function" begin + # Test case 1: Testing with positive r and non-empty weight vector + @testset "Test case 1" begin + r = 0.5 + w = [0.2, 0.3, 0.5] + expected_result = 0.2 * exp(-0.5 * 1) + 0.3 * exp(-0.5 * 2) + 0.5 * exp(-0.5 * 3) + result = neg_MGF(r, w) + @test result ≈ expected_result atol = 1e-15 + end + + # Test case 2: Testing with zero r and non-empty weight vector + @testset "Test case 2" begin + r = 0 + w = [0.1, 0.2, 0.3, 0.4] + expected_result = + 0.1 * exp(-0 * 1) + 0.2 * exp(-0 * 2) + 0.3 * exp(-0 * 3) + 0.4 * exp(-0 * 4) + result = neg_MGF(r, w) + @test result ≈ expected_result atol = 1e-15 + end + +end + +@testitem "Testing dneg_MGF_dr function" begin + # Test case 1: Testing with positive r and non-empty weight vector + @testset "Test case 1" begin + r = 0.5 + w = [0.2, 0.3, 0.5] + expected_result = + -(0.2 * 1 * exp(-0.5 * 1) + 0.3 * 2 * exp(-0.5 * 2) + 0.5 * 3 * exp(-0.5 * 3)) + result = dneg_MGF_dr(r, w) + @test result ≈ expected_result atol = 1e-15 + end + + # Test case 2: Testing with zero r and non-empty weight vector + @testset "Test case 2" begin + r = 0 + w = [0.1, 0.2, 0.3, 0.4] + expected_result = -( + 0.1 * 1 * exp(-0 * 1) + + 0.2 * 2 * exp(-0 * 2) + + 0.3 * 3 * exp(-0 * 3) + + 0.4 * 4 * exp(-0 * 4) + ) + result = dneg_MGF_dr(r, w) + @test result ≈ expected_result atol = 1e-15 + end + +end