Skip to content

Commit

Permalink
Overflow safety for negative-binomial (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 authored Mar 27, 2024
1 parent b12e482 commit 7943392
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
10 changes: 7 additions & 3 deletions EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ begin
using LinearAlgebra
using Transducers
using ReverseDiff

Random.seed!(1)
end

# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61
Expand Down Expand Up @@ -157,9 +159,11 @@ The observation model is a negative binomial distribution parameterised with mea
```math
\sigma_{\text{rel}} =(1/\sqrt{\mu}) + (1 / \sqrt{r}).
```
It is standard to use a half-t distribution for standard deviation priors (e.g. as argued in this [paper](http://www.stat.columbia.edu/~gelman/research/published/taumain.pdf)); we specialise this to a Half-Normal prior and use an _a priori_ assumption that a typical observation fluctuation around the mean (when the mean is $\sim\mathcal{O}(10^2)$) would be 10%. This implies a standard deviation prior,
It is standard to use a half-t distribution for standard deviation priors (e.g. as argued in this [paper](http://www.stat.columbia.edu/~gelman/research/published/taumain.pdf)); we specialise this to a Half-Normal prior and use an _a priori_ assumption that a typical observation fluctuation around the mean (when the mean is $\sim\mathcal{O}(10^2)$) would be 1%, which is close to Poisson noise.
This implies a standard deviation prior,
```math
1 / \sqrt{r} \sim \text{HalfNormal}\Big(0.1 ~\sqrt{{\pi \over 2}}\Big).
1 / \sqrt{r} \sim \text{HalfNormal}\Big(0.01 ~\sqrt{{\pi \over 2}}\Big).
```
The $\sqrt{{\pi \over 2}}$ factor ensures the correct prior mean (see [here](https://en.wikipedia.org/wiki/Half-normal_distribution)).
Expand All @@ -185,7 +189,7 @@ We choose a simple observation model where infections are observed 0, 1, 2, 3 da

# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b
obs_model = LatentDelay(
NegativeBinomialError(),
NegativeBinomialError(cluster_factor_prior = HalfNormal(0.01)),
fill(0.25, 4)
)

Expand Down
8 changes: 5 additions & 3 deletions EpiAware/src/EpiObsModels/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ Compute the mean-cluster factor negative binomial distribution.
A `NegativeBinomial` distribution object.
"""
function NegativeBinomialMeanClust(μ, α)
ex_σ² =* μ^2) + 1e-6
p = μ /+ ex_σ² + 1e-6)
r = μ^2 / ex_σ²
= clamp(μ, 1e-6, 1e17)
= clamp(α, 1e-6, Inf)
ex_σ² = (_α *^2)
p = clamp(_μ / (_μ + ex_σ²), 1e-17, 1 - 1e-17)
r = clamp(_μ^2 / ex_σ², 1e-17, 1e17)
return NegativeBinomial(r, p)
end
20 changes: 20 additions & 0 deletions EpiAware/test/EpiObsModels/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,23 @@
@test K == expected_K
end
end

@testitem "Check overflow safety of Negative Binomial sampling" begin
using Distributions
big_mu = 1e30
alpha = 0.5
big_alpha = 1e30

ex_σ² = (alpha * big_mu^2)
p = big_mu / (big_mu + ex_σ²)
r = big_mu^2 / ex_σ²

#Direct definition
nb = NegativeBinomial(r, p)

@test_throws InexactError rand(nb) #Throws error due to overflow

#Safe versions
@test rand(EpiAware.EpiObsModels.NegativeBinomialMeanClust(big_mu, alpha)) isa Int
@test rand(EpiAware.EpiObsModels.NegativeBinomialMeanClust(big_mu, big_alpha)) isa Int
end

0 comments on commit 7943392

Please sign in to comment.