Skip to content

Commit

Permalink
Issue 160: Add ObservationErrorModel type and functions (#268)
Browse files Browse the repository at this point in the history
* add to package and reorg

* Revert "add to package and reorg"

This reverts commit 31b89d3.

* add ObservationErrorModel type and functions

* move error models and make Poisson use new abstract type

* tests passing

* update handling of missingness to maintain expected_obs length

* update handling of missingness

* change type of some tests

* get tests passing

* add docs for abstract methods

* add docs for new specific methods

* add unit tests

* move abstract type

* remove numerical pading as an option and hard code

* correct type specification

* fix constructor tests

* add a test for new Y_t == y_t check
  • Loading branch information
seabbs authored Jun 11, 2024
1 parent 8e8bf26 commit 1b60c01
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 156 deletions.
11 changes: 8 additions & 3 deletions EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ using DocStringExtensions
#Export models
export AbstractModel, AbstractEpiModel, AbstractLatentModel, AbstractObservationModel

# Export Turing-based models
export AbstractTuringEpiModel, AbstractTuringLatentModel, AbstractTuringIntercept,
AbstractTuringObservationModel, AbstractTuringRenewal
# Export Turing-based models EpiModels
export AbstractTuringEpiModel, AbstractTuringRenewal

# Export Turing-based latent models
export AbstractTuringLatentModel, AbstractTuringIntercept

# Export Turing-based observation models
export AbstractTuringObservationModel, AbstractTuringObservationErrorModel

# Export support types
export AbstractBroadcastRule
Expand Down
6 changes: 6 additions & 0 deletions EpiAware/src/EpiAwareBase/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ A abstract type representing a Turing-based observation model.
"""
abstract type AbstractTuringObservationModel <: AbstractObservationModel end

"""
The abstract supertype for all structs that defines a Turing-based model for
generating observation errors.
"""
abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationModel end

"""
Abstract supertype for all `EpiAware` problems.
"""
Expand Down
10 changes: 7 additions & 3 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek

using Turing, Distributions, DocStringExtensions, SparseArrays

# Observation models
# Observation error models
export PoissonError, NegativeBinomialError

# Observation error model functions
export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, StackObservationModels

Expand All @@ -25,8 +28,9 @@ include("LatentDelay.jl")
include("ascertainment/Ascertainment.jl")
include("ascertainment/helpers.jl")
include("StackObservationModels.jl")
include("PoissonError.jl")
include("NegativeBinomialError.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
include("ObservationErrorModels/PoissonError.jl")
include("utils.jl")

end
11 changes: 7 additions & 4 deletions EpiAware/src/EpiObsModels/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ Generates observations based on the `LatentDelay` observation model.
"
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
@assert length(obs_model.pmf)<=length(Y_t) "The delay PMF must be shorter than or equal to the observation vector"
first_Y_t = findfirst(!ismissing, Y_t)
trunc_Y_t = Y_t[first_Y_t:end]
@assert length(obs_model.pmf)<=length(trunc_Y_t) "The delay PMF must be shorter than or equal to the observation vector"

kernel = generate_observation_kernel(obs_model.pmf, length(Y_t), partial = false)
expected_obs = kernel * Y_t
kernel = generate_observation_kernel(obs_model.pmf, length(trunc_Y_t), partial = false)
expected_obs = kernel * trunc_Y_t
complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 2), expected_obs)

@submodel y_t, obs_aux = generate_observations(
obs_model.model, y_t, expected_obs)
obs_model.model, y_t, complete_obs)

return y_t, (; obs_aux...)
end
68 changes: 0 additions & 68 deletions EpiAware/src/EpiObsModels/NegativeBinomialError.jl

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
@doc raw"
The `NegativeBinomialError` struct represents an observation model for negative binomial errors. It is a subtype of `AbstractTuringObservationModel`.
## Constructors
- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1))`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior.
- `NegativeBinomialError(cluster_factor_prior::Distribution)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior.
## Examples
```julia
using Distributions, Turing, EpiAware
nb = NegativeBinomialError()
nb_model = generate_observations(nb, missing, fill(10, 10))
rand(nb_model)
```
"
@kwdef struct NegativeBinomialError{S <: Sampleable} <:
AbstractTuringObservationErrorModel
"The prior distribution for the cluster factor."
cluster_factor_prior::S = HalfNormal(0.01)
end

@doc raw"
Generates observation error priors based on the `NegativeBinomialError` observation model. This function generates the cluster factor prior for the negative binomial error model.
"
@model function generate_observation_error_priors(
obs_model::NegativeBinomialError, Y_t, y_t)
cluster_factor ~ obs_model.cluster_factor_prior
sq_cluster_factor = cluster_factor^2
return (; sq_cluster_factor)
end

@doc raw"
This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution.
"
function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor)
return NegativeBinomialMeanClust(Y_t,
sq_cluster_factor)
end
25 changes: 25 additions & 0 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@doc raw"
The `PoissonError` struct represents an observation model for Poisson errors. It
is a subtype of `AbstractTuringObservationErrorModel`.
## Constructors
- `PoissonError()`: Constructs a `PoissonError` object.
## Examples
```julia
using Distributions, Turing, EpiAware
poi = PoissonError()
poi_model = generate_observations(poi, missing, fill(10, 10))
rand(poi_model)
```
"
struct PoissonError <: AbstractTuringObservationErrorModel
end

@doc raw"
The observation error model for Poisson errors. This function generates the
observation error model based on the Poisson error model.
"
function observation_error(obs_model::PoissonError, Y_t)
return Poisson(Y_t)
end
41 changes: 41 additions & 0 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@doc raw"
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues.
It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required.
"
@model function EpiAwareBase.generate_observations(
obs_model::AbstractTuringObservationErrorModel,
y_t,
Y_t)
@submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t)

if ismissing(y_t)
y_t = Vector{Union{Real, Missing}}(missing, length(Y_t))
else
@assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length."
end

pad_Y_t = Y_t .+ 1e-6

for i in findfirst(!ismissing, Y_t):length(Y_t)
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t, priors
end

@doc raw"
Generates priors for the observation error model. This should return a named tuple containing the priors required for generating the observation error distribution.
"
@model function generate_observation_error_priors(
obs_model::AbstractTuringObservationErrorModel, y_t, Y_t)
return NamedTuple()
end

@doc raw"
The observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`.
"
function observation_error(obs_model::AbstractTuringObservationErrorModel, Y_t)
@info "No concrete implementation for `observation_error` is defined."
return nothing
end
51 changes: 0 additions & 51 deletions EpiAware/src/EpiObsModels/PoissonError.jl

This file was deleted.

15 changes: 5 additions & 10 deletions EpiAware/test/EpiAwareUtils/generate_epiware.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

@testitem "`generate_epiaware` with direct infections and RW latent process runs" begin
using Distributions, Turing, DynamicPPL
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6
time_horizon = 100

#Define the epi_model
Expand Down Expand Up @@ -32,7 +30,7 @@
gen = generated_quantities(test_mdl, rand(test_mdl))

#Check model sampled
@test eltype(gen.generated_y_t) <: Int
@test eltype(gen.generated_y_t) <: Union{Missing, Real}
@test eltype(gen.I_t) <: AbstractFloat
@test length(gen.I_t) == time_horizon
end
Expand All @@ -42,7 +40,6 @@ end
# Define test inputs
y_t = missing# rand(1:10, 365) # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6

#Define the epi_model
epi_model = ExpGrowthRate(data, Normal())
Expand All @@ -56,7 +53,7 @@ end
#Define the observation model - no delay model
time_horizon = 5
obs_model = NegativeBinomialError(
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); pos_shift
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
)

# Create full epi model and sample from it
Expand All @@ -70,7 +67,7 @@ end
gens = generated_quantities(test_mdl, chn)

#Check model sampled
@test eltype(gens[1].generated_y_t) <: Int
@test eltype(gens[1].generated_y_t) <: Union{Missing, Real}
@test eltype(gens[1].I_t) <: AbstractFloat
@test length(gens[1].I_t) == time_horizon
end
Expand All @@ -80,7 +77,6 @@ end
# Define test inputs
y_t = missing# rand(1:10, 365) # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6

#Define the epi_model
epi_model = Renewal(data, Normal())
Expand All @@ -94,8 +90,7 @@ end
#Define the observation model - no delay model
time_horizon = 5
obs_model = NegativeBinomialError(
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0);
pos_shift
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
)

# Create full epi model and sample from it
Expand All @@ -110,7 +105,7 @@ end
gens = generated_quantities(test_mdl, chn)

#Check model sampled
@test eltype(gens[1].generated_y_t) <: Int
@test eltype(gens[1].generated_y_t) <: Union{Missing, Real}
@test eltype(gens[1].I_t) <: AbstractFloat
@test length(gens[1].I_t) == time_horizon
end
Loading

0 comments on commit 1b60c01

Please sign in to comment.