Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 340: Simplify LatentDelay #388

Merged
merged 3 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ module EpiObsModels

using ..EpiAwareBase

using ..EpiAwareUtils: censored_pmf, HalfNormal, prefix_submodel
using ..EpiAwareUtils

using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek, PrefixLatentModel

using Turing, Distributions, DocStringExtensions, SparseArrays
using Turing, Distributions, DocStringExtensions, SparseArrays, LinearAlgebra

# Observation error models
export PoissonError, NegativeBinomialError
Expand Down
47 changes: 38 additions & 9 deletions EpiAware/src/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ observed data.

## Fields
- `model::M`: The underlying observation model.
- `pmf::T`: The probability mass function (PMF) representing the delay distribution.
- `rev_pmf::T`: The probability mass function (PMF) representing the delay distribution reversed.

## Constructors
- `LatentDelay(model::M, distribution::C; D = nothing, Δd = 1.0)
Expand All @@ -23,14 +23,14 @@ observed data.
```julia
using Distributions, Turing, EpiAware
obs = LatentDelay(NegativeBinomialError(), truncated(Normal(5.0, 2.0), 0.0, Inf))
obs_model = generate_observations(obs, missing, fill(10, 10))
rand(obs_model)
obs_model = generate_observations(obs, missing, fill(10, 30))
obs_model()
```
"
struct LatentDelay{M <: AbstractTuringObservationModel, T <: AbstractVector{<:Real}} <:
AbstractTuringObservationModel
model::M
pmf::T
rev_pmf::T

function LatentDelay(model::M, distribution::C; D = nothing,
Δd = 1.0) where {
Expand All @@ -43,7 +43,8 @@ struct LatentDelay{M <: AbstractTuringObservationModel, T <: AbstractVector{<:Re
pmf::T) where {M <: AbstractTuringObservationModel, T <: AbstractVector{<:Real}}
@assert all(pmf .>= 0) "Delay interval must be non-negative"
@assert isapprox(sum(pmf), 1) "Delay interval must sum to 1"
new{typeof(model), typeof(pmf)}(model, pmf)
rev_pmf = reverse(pmf)
new{typeof(model), typeof(rev_pmf)}(model, rev_pmf)
end
end

Expand All @@ -62,14 +63,42 @@ Generates observations based on the `LatentDelay` observation model.
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
first_Y_t = findfirst(!ismissing, Y_t)
trunc_Y_t = Y_t[first_Y_t:end]
@assert length(obs_model.pmf)<=length(trunc_Y_t) "The delay PMF must be shorter than or equal to the observation vector"
pmf_length = length(obs_model.rev_pmf)
@assert pmf_length<=length(trunc_Y_t) "The delay PMF must be shorter than or equal to the observation vector"

kernel = generate_observation_kernel(obs_model.pmf, length(trunc_Y_t), partial = false)
expected_obs = kernel * trunc_Y_t
complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 2), expected_obs)
expected_obs = accumulate_scan(
LDStep(obs_model.rev_pmf),
(; val = 0, current = trunc_Y_t[1:(pmf_length)]),
vcat(trunc_Y_t[(pmf_length + 1):end], 0.0)
)

complete_obs = vcat(fill(missing, pmf_length + first_Y_t - 2), expected_obs)

@submodel y_t = generate_observations(
obs_model.model, y_t, complete_obs)

return y_t
end

@doc raw"
The LatentDelay step function struct
"
struct LDStep{D <: AbstractVector{<:Real}} <: AbstractAccumulationStep
rev_pmf::D
end

@doc raw"
The LatentDelay step function method for `accumulate_scan`.
"
function (ld::LDStep)(state, ϵ)
val = dot(ld.rev_pmf, state.current)
current = vcat(state.current[2:end], ϵ)
return (; val, current)
end

@doc raw"
The LatentDelay step function method for get_state.
"
function EpiAwareUtils.get_state(acc_step::LDStep, initial_state, state)
return state .|> x -> x.val
end
8 changes: 4 additions & 4 deletions EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
obs_model = LatentDelay(dummy_model, delay_int)

@test obs_model.model == dummy_model
@test obs_model.pmf == delay_int
@test obs_model.rev_pmf == reverse(delay_int)

# Test case 2
delay_distribution = Uniform(0.0, 20.0)
Expand All @@ -20,7 +20,7 @@
obs_model = LatentDelay(dummy_model, delay_distribution, D = D_delay, Δd = Δd)

@test obs_model.model == dummy_model
@test length(obs_model.pmf) == D_delay
@test length(obs_model.rev_pmf) == D_delay

# Test case 3: check default right truncation
delay_distribution = Gamma(3, 15 / 3)
Expand All @@ -30,7 +30,7 @@
obs_model = LatentDelay(dummy_model, delay_distribution, D = D_delay, Δd = Δd)

nn_perc_rounded = invlogcdf(delay_distribution, log(0.99)) |> x -> round(Int64, x)
@test length(obs_model.pmf) == nn_perc_rounded
@test length(obs_model.rev_pmf) == nn_perc_rounded
end

@testitem "Testing delay obs against theoretical properties" begin
Expand Down Expand Up @@ -129,7 +129,7 @@ end
obs_model = LatentDelay(TestObs(), delay_int)

I_t = [10.0, 20.0, 30.0, 40.0, 50.0]
expected_obs = [missing, missing, 23.0, 33.0, 43.0]
expected_obs = [missing, missing, 17.0, 27.0, 37.0]

@testset "Test with entirely missing data" begin
mdl = generate_observations(obs_model, missing, I_t)
Expand Down
Loading