-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Issue 160: Add ObservationErrorModel type and functions (#268)
* add to package and reorg * Revert "add to package and reorg" This reverts commit 31b89d3. * add ObservationErrorModel type and functions * move error models and make Poisson use new abstract type * tests passing * update handling of missingness to maintain expected_obs length * update handling of missingness * change type of some tests * get tests passing * add docs for abstract methods * add docs for new specific methods * add unit tests * move abstract type * remove numerical pading as an option and hard code * correct type specification * fix constructor tests * add a test for new Y_t == y_t check
- Loading branch information
Showing
14 changed files
with
180 additions
and
156 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
39 changes: 39 additions & 0 deletions
39
EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
@doc raw" | ||
The `NegativeBinomialError` struct represents an observation model for negative binomial errors. It is a subtype of `AbstractTuringObservationModel`. | ||
## Constructors | ||
- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1))`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior. | ||
- `NegativeBinomialError(cluster_factor_prior::Distribution)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior. | ||
## Examples | ||
```julia | ||
using Distributions, Turing, EpiAware | ||
nb = NegativeBinomialError() | ||
nb_model = generate_observations(nb, missing, fill(10, 10)) | ||
rand(nb_model) | ||
``` | ||
" | ||
@kwdef struct NegativeBinomialError{S <: Sampleable} <: | ||
AbstractTuringObservationErrorModel | ||
"The prior distribution for the cluster factor." | ||
cluster_factor_prior::S = HalfNormal(0.01) | ||
end | ||
|
||
@doc raw" | ||
Generates observation error priors based on the `NegativeBinomialError` observation model. This function generates the cluster factor prior for the negative binomial error model. | ||
" | ||
@model function generate_observation_error_priors( | ||
obs_model::NegativeBinomialError, Y_t, y_t) | ||
cluster_factor ~ obs_model.cluster_factor_prior | ||
sq_cluster_factor = cluster_factor^2 | ||
return (; sq_cluster_factor) | ||
end | ||
|
||
@doc raw" | ||
This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution. | ||
" | ||
function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor) | ||
return NegativeBinomialMeanClust(Y_t, | ||
sq_cluster_factor) | ||
end |
25 changes: 25 additions & 0 deletions
25
EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
@doc raw" | ||
The `PoissonError` struct represents an observation model for Poisson errors. It | ||
is a subtype of `AbstractTuringObservationErrorModel`. | ||
## Constructors | ||
- `PoissonError()`: Constructs a `PoissonError` object. | ||
## Examples | ||
```julia | ||
using Distributions, Turing, EpiAware | ||
poi = PoissonError() | ||
poi_model = generate_observations(poi, missing, fill(10, 10)) | ||
rand(poi_model) | ||
``` | ||
" | ||
struct PoissonError <: AbstractTuringObservationErrorModel | ||
end | ||
|
||
@doc raw" | ||
The observation error model for Poisson errors. This function generates the | ||
observation error model based on the Poisson error model. | ||
" | ||
function observation_error(obs_model::PoissonError, Y_t) | ||
return Poisson(Y_t) | ||
end |
41 changes: 41 additions & 0 deletions
41
EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
@doc raw" | ||
Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues. | ||
It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. | ||
" | ||
@model function EpiAwareBase.generate_observations( | ||
obs_model::AbstractTuringObservationErrorModel, | ||
y_t, | ||
Y_t) | ||
@submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t) | ||
|
||
if ismissing(y_t) | ||
y_t = Vector{Union{Real, Missing}}(missing, length(Y_t)) | ||
else | ||
@assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length." | ||
end | ||
|
||
pad_Y_t = Y_t .+ 1e-6 | ||
|
||
for i in findfirst(!ismissing, Y_t):length(Y_t) | ||
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...) | ||
end | ||
|
||
return y_t, priors | ||
end | ||
|
||
@doc raw" | ||
Generates priors for the observation error model. This should return a named tuple containing the priors required for generating the observation error distribution. | ||
" | ||
@model function generate_observation_error_priors( | ||
obs_model::AbstractTuringObservationErrorModel, y_t, Y_t) | ||
return NamedTuple() | ||
end | ||
|
||
@doc raw" | ||
The observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`. | ||
" | ||
function observation_error(obs_model::AbstractTuringObservationErrorModel, Y_t) | ||
@info "No concrete implementation for `observation_error` is defined." | ||
return nothing | ||
end |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.