Skip to content

Commit

Permalink
add poisson errors (#182)
Browse files Browse the repository at this point in the history
* add poisson errors

* format fixes
  • Loading branch information
SamuelBrand1 authored Apr 18, 2024
1 parent ca38828 commit 9a3a519
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 1 deletion.
3 changes: 2 additions & 1 deletion EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ using ..EpiAwareUtils: censored_pmf, HalfNormal

using Turing, Distributions, DocStringExtensions, SparseArrays

export NegativeBinomialError, LatentDelay, Ascertainment
export PoissonError, NegativeBinomialError, LatentDelay, Ascertainment

include("docstrings.jl")
include("LatentDelay.jl")
include("Ascertainment.jl")
include("PoissonError.jl")
include("NegativeBinomialError.jl")
include("utils.jl")

Expand Down
49 changes: 49 additions & 0 deletions EpiAware/src/EpiObsModels/PoissonError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
@doc raw"
The `PoissonError` struct represents an observation model for Poisson errors. It
is a subtype of `AbstractTuringObservationModel`.
## Constructors
- `PoissonError(; pos_shift::AbstractFloat = 0.)`: Constructs a `PoissonError`
object with default values for the cluster factor prior and positive shift.
## Examples
```julia
using Distributions, Turing, EpiAware
poi = PoissonError()
poi_model = generate_observations(poi, missing, fill(10, 10))
rand(poi_model)
```
"
struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationModel
"The positive shift value."
pos_shift::T

function PoissonError(; pos_shift::AbstractFloat = 0.0)
@assert pos_shift>=0.0 "The positive shift value must be non-negative."
new{typeof(pos_shift)}(pos_shift)
end
end

@doc raw"
Generate observations using the `PoissonError` observation model.
# Arguments
- `obs_model::PoissonError`: The observation model.
- `y_t`: The observed values.
- `Y_t`: The true values.
# Returns
- `y_t`: The generated observations.
- An empty named tuple.
"
@model function EpiAwareBase.generate_observations(obs_model::PoissonError, y_t, Y_t)
if ismissing(y_t)
y_t = Vector{Int}(undef, length(Y_t))
end

for i in eachindex(y_t)
y_t[i] ~ Poisson(Y_t[i] + obs_model.pos_shift)
end

return y_t, NamedTuple()
end
35 changes: 35 additions & 0 deletions EpiAware/test/EpiObsModels/PoissonError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@testitem "PoissonErrorConstructor" begin
using Distributions
# Test default constructor
poi = PoissonError()
@test poi.pos_shift zero(Float64)
poi_float = PoissonError(; pos_shift = 0.0f0)
@test poi_float.pos_shift zero(Float32)

# Test constructor with pos_shift
poi2 = PoissonError(; pos_shift = 1e-3)
@test poi2.pos_shift 1e-3
end

@testitem "Testing PoissonError against theoretical properties" begin
using Distributions, Turing, HypothesisTests, DynamicPPL

# Set up test parameters
n = 100 # Number of observations
μ = 10.0 # Mean of the poisson distribution

# Define the observation model
poi_obs_model = PoissonError(pos_shift = 0.0)

# Generate observations from the model
Y_t = fill(μ, n) # True values
model = generate_observations(poi_obs_model, missing, Y_t)
samples = sample(model, Prior(), 1000; progress = false)

obs_samples = samples |>
chn -> mapreduce(vcat, generated_quantities(model, chn)) do gen
gen[1]
end

@test isapprox(mean(obs_samples), μ, atol = 0.1) # Test the mean
end

0 comments on commit 9a3a519

Please sign in to comment.