Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 160: Add ObservationErrorModel type and functions #268

Merged
merged 17 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ using DocStringExtensions
#Export models
export AbstractModel, AbstractEpiModel, AbstractLatentModel, AbstractObservationModel

# Export Turing-based models
export AbstractTuringEpiModel, AbstractTuringLatentModel, AbstractTuringIntercept,
AbstractTuringObservationModel, AbstractTuringRenewal
# Export Turing-based models EpiModels
export AbstractTuringEpiModel, AbstractTuringRenewal

# Export Turing-based latent models
export AbstractTuringLatentModel, AbstractTuringIntercept

# Export Turing-based observation models
export AbstractTuringObservationModel, AbstractTuringObservationErrorModel

# Export support types
export AbstractBroadcastRule
Expand Down
6 changes: 6 additions & 0 deletions EpiAware/src/EpiAwareBase/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ A abstract type representing a Turing-based observation model.
"""
abstract type AbstractTuringObservationModel <: AbstractObservationModel end

"""
The abstract supertype for all structs that defines a Turing-based model for
generating observation errors.
"""
abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationModel end

"""
Abstract supertype for all `EpiAware` problems.
"""
Expand Down
10 changes: 7 additions & 3 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek

using Turing, Distributions, DocStringExtensions, SparseArrays

# Observation models
# Observation error models
export PoissonError, NegativeBinomialError

# Observation error model functions
export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, StackObservationModels

Expand All @@ -25,8 +28,9 @@ include("LatentDelay.jl")
include("ascertainment/Ascertainment.jl")
include("ascertainment/helpers.jl")
include("StackObservationModels.jl")
include("PoissonError.jl")
include("NegativeBinomialError.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
include("ObservationErrorModels/PoissonError.jl")
include("utils.jl")

end
11 changes: 7 additions & 4 deletions EpiAware/src/EpiObsModels/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ Generates observations based on the `LatentDelay` observation model.

"
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
@assert length(obs_model.pmf)<=length(Y_t) "The delay PMF must be shorter than or equal to the observation vector"
first_Y_t = findfirst(!ismissing, Y_t)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
trunc_Y_t = Y_t[first_Y_t:end]
@assert length(obs_model.pmf)<=length(trunc_Y_t) "The delay PMF must be shorter than or equal to the observation vector"

kernel = generate_observation_kernel(obs_model.pmf, length(Y_t), partial = false)
expected_obs = kernel * Y_t
kernel = generate_observation_kernel(obs_model.pmf, length(trunc_Y_t), partial = false)
expected_obs = kernel * trunc_Y_t
complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 2), expected_obs)

@submodel y_t, obs_aux = generate_observations(
obs_model.model, y_t, expected_obs)
obs_model.model, y_t, complete_obs)

return y_t, (; obs_aux...)
end
68 changes: 0 additions & 68 deletions EpiAware/src/EpiObsModels/NegativeBinomialError.jl

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
@doc raw"

The `NegativeBinomialError` struct represents an observation model for negative binomial errors. It is a subtype of `AbstractTuringObservationModel`.

## Constructors
- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1))`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior.
- `NegativeBinomialError(cluster_factor_prior::Distribution)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior.

## Examples
```julia
using Distributions, Turing, EpiAware
nb = NegativeBinomialError()
nb_model = generate_observations(nb, missing, fill(10, 10))
rand(nb_model)
```
"
@kwdef struct NegativeBinomialError{S <: Sampleable} <:
AbstractTuringObservationErrorModel
"The prior distribution for the cluster factor."
cluster_factor_prior::S = HalfNormal(0.01)
end

@doc raw"
Generates observation error priors based on the `NegativeBinomialError` observation model. This function generates the cluster factor prior for the negative binomial error model.
"
@model function generate_observation_error_priors(
obs_model::NegativeBinomialError, Y_t, y_t)
cluster_factor ~ obs_model.cluster_factor_prior
sq_cluster_factor = cluster_factor^2
return (; sq_cluster_factor)
end

@doc raw"
This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution.
"
function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor)
return NegativeBinomialMeanClust(Y_t,
sq_cluster_factor)
end
25 changes: 25 additions & 0 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@doc raw"
The `PoissonError` struct represents an observation model for Poisson errors. It
is a subtype of `AbstractTuringObservationErrorModel`.

## Constructors
- `PoissonError()`: Constructs a `PoissonError` object.

## Examples
```julia
using Distributions, Turing, EpiAware
poi = PoissonError()
poi_model = generate_observations(poi, missing, fill(10, 10))
rand(poi_model)
```
"
struct PoissonError <: AbstractTuringObservationErrorModel
end

@doc raw"
The observation error model for Poisson errors. This function generates the
observation error model based on the Poisson error model.
"
function observation_error(obs_model::PoissonError, Y_t)
return Poisson(Y_t)
end
41 changes: 41 additions & 0 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@doc raw"
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues.

It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required.
"
@model function EpiAwareBase.generate_observations(
obs_model::AbstractTuringObservationErrorModel,
y_t,
Y_t)
@submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t)

if ismissing(y_t)
y_t = Vector{Union{Real, Missing}}(missing, length(Y_t))
else
@assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length."
end

pad_Y_t = Y_t .+ 1e-6

for i in findfirst(!ismissing, Y_t):length(Y_t)
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t, priors
end

@doc raw"
Generates priors for the observation error model. This should return a named tuple containing the priors required for generating the observation error distribution.
"
@model function generate_observation_error_priors(
obs_model::AbstractTuringObservationErrorModel, y_t, Y_t)
return NamedTuple()
end

@doc raw"
The observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`.
"
function observation_error(obs_model::AbstractTuringObservationErrorModel, Y_t)
@info "No concrete implementation for `observation_error` is defined."
return nothing
end
51 changes: 0 additions & 51 deletions EpiAware/src/EpiObsModels/PoissonError.jl

This file was deleted.

15 changes: 5 additions & 10 deletions EpiAware/test/EpiAwareUtils/generate_epiware.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

@testitem "`generate_epiaware` with direct infections and RW latent process runs" begin
using Distributions, Turing, DynamicPPL
# Define test inputs
y_t = missing # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6
time_horizon = 100

#Define the epi_model
Expand Down Expand Up @@ -32,7 +30,7 @@
gen = generated_quantities(test_mdl, rand(test_mdl))

#Check model sampled
@test eltype(gen.generated_y_t) <: Int
@test eltype(gen.generated_y_t) <: Union{Missing, Real}
@test eltype(gen.I_t) <: AbstractFloat
@test length(gen.I_t) == time_horizon
end
Expand All @@ -42,7 +40,6 @@ end
# Define test inputs
y_t = missing# rand(1:10, 365) # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6

#Define the epi_model
epi_model = ExpGrowthRate(data, Normal())
Expand All @@ -56,7 +53,7 @@ end
#Define the observation model - no delay model
time_horizon = 5
obs_model = NegativeBinomialError(
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); pos_shift
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
)

# Create full epi model and sample from it
Expand All @@ -70,7 +67,7 @@ end
gens = generated_quantities(test_mdl, chn)

#Check model sampled
@test eltype(gens[1].generated_y_t) <: Int
@test eltype(gens[1].generated_y_t) <: Union{Missing, Real}
@test eltype(gens[1].I_t) <: AbstractFloat
@test length(gens[1].I_t) == time_horizon
end
Expand All @@ -80,7 +77,6 @@ end
# Define test inputs
y_t = missing# rand(1:10, 365) # Data will be generated from the model
data = EpiData([0.2, 0.3, 0.5], exp)
pos_shift = 1e-6

#Define the epi_model
epi_model = Renewal(data, Normal())
Expand All @@ -94,8 +90,7 @@ end
#Define the observation model - no delay model
time_horizon = 5
obs_model = NegativeBinomialError(
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0);
pos_shift
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
)

# Create full epi model and sample from it
Expand All @@ -110,7 +105,7 @@ end
gens = generated_quantities(test_mdl, chn)

#Check model sampled
@test eltype(gens[1].generated_y_t) <: Int
@test eltype(gens[1].generated_y_t) <: Union{Missing, Real}
@test eltype(gens[1].I_t) <: AbstractFloat
@test length(gens[1].I_t) == time_horizon
end
Loading
Loading