Skip to content

Commit

Permalink
Discrete valued dists expecting rand return of Union{int,BigInt}
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 authored and seabbs committed Oct 11, 2024
1 parent 1ff41c2 commit 7c51d1f
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 35 deletions.
6 changes: 3 additions & 3 deletions EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import Base: eltype
using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

#Export Structures
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, RealValued,
RealUnivariateDistribution
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValued, SafeInt,
SafeDiscreteUnivariateDistribution

#Export functions
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
Expand All @@ -34,7 +34,7 @@ include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
include("get_param_array.jl")
include("RealValued.jl")
include("SafeInt.jl")
include("SafePoisson.jl")
include("SafeNegativeBinomial.jl")

Expand Down
12 changes: 0 additions & 12 deletions EpiAware/src/EpiAwareUtils/RealValued.jl

This file was deleted.

16 changes: 16 additions & 0 deletions EpiAware/src/EpiAwareUtils/SafeInt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
const SafeInt = Union{Int, BigInt}

"""
A type to represent real-valued distributions, the purpose of this type is to avoid problems
with the `eltype` function when having `rand` calls in the model.
"""
struct SafeIntValued <: Distributions.ValueSupport end
function Base.eltype(::Type{<:Distributions.Sampleable{F, SafeIntValued}}) where {F}
Union{Int, BigInt}
end

"""
A constant alias for `Distribution{Univariate, RealValued}`. This type represents a univariate distribution with real-valued outcomes.
"""
const SafeDiscreteUnivariateDistribution = Distributions.Distribution{
Distributions.Univariate, SafeIntValued}
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var(d)
2.4617291430060293e40
```
"
struct SafeNegativeBinomial{T <: Real} <: RealUnivariateDistribution
struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution
r::T
p::T

Expand Down
12 changes: 6 additions & 6 deletions EpiAware/src/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var(d)
7.016735912097631e20
```
"
struct SafePoisson{T <: Real} <: RealUnivariateDistribution
struct SafePoisson{T <: Real} <: SafeDiscreteUnivariateDistribution
λ::T

SafePoisson{T}::Real) where {T <: Real} = new{T}(λ)
Expand Down Expand Up @@ -142,12 +142,12 @@ ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
function ad_rand(rng::AbstractRNG, λ)
s = sqrt(λ)
d = 6.0 * λ^2
L = floor- 1.1484)
L = _safe_int_floor- 1.1484)
# Step N
G = λ + s * randn(rng)

if G >= 0.0
K = floor(G)
K = _safe_int_floor(G)
# Step I
if K >= L
return K
Expand Down Expand Up @@ -177,7 +177,7 @@ function ad_rand(rng::AbstractRNG, λ)
continue
end

K = floor+ s * T)
K = _safe_int_floor+ s * T)
px, py, fx, fy = procf(λ, K, s)
c = 0.1069 / λ

Expand Down Expand Up @@ -229,7 +229,7 @@ function log1pmx(x::Float64)
end

# Procedure F
function procf(λ, K, s::Float64)
function procf(λ, K::SafeInt, s::Float64)
# can be pre-computed, but does not seem to affect performance
ω = 0.3989422804014327 / s
b1 = 0.041666666666666664 / λ
Expand All @@ -241,7 +241,7 @@ function procf(λ, K, s::Float64)

if K < 10
px = -float(λ)
py = λ^K / factorial(floor(Int, K))
py = λ^K / factorial(K)
else
δ = 0.08333333333333333 / K
δ -= 4.8 * δ^3
Expand Down
8 changes: 0 additions & 8 deletions EpiAware/test/EpiAwareUtils/RealValued.jl

This file was deleted.

8 changes: 8 additions & 0 deletions EpiAware/test/EpiAwareUtils/SafeInt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@testitem "SafeInt Type Tests" begin
using Distributions
struct DummySampleable <: Sampleable{Univariate, SafeIntValued} end

@test SafeIntValued <: Distributions.ValueSupport
@test eltype(DummySampleable) <: Union{Int, BigInt}
@test SafeDiscreteUnivariateDistribution == Distribution{Univariate, SafeIntValued}
end
2 changes: 1 addition & 1 deletion EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end

dist = SafeNegativeBinomial(r, p)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa Real
@test rand(dist) isa Union{Int, BigInt}
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = EpiAware.EpiAwareUtils._negbin(dist)
Expand Down
8 changes: 4 additions & 4 deletions EpiAware/test/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
λ = 10.0
dist = SafePoisson(λ)
@test typeof(dist) <: SafePoisson
@test rand(dist) isa Real
@test rand(dist, 10) isa Vector{Real}
@test rand(dist, 10, 10) isa Array{Real}
@test rand(dist) isa SafeInt
@test rand(dist, 10) isa Vector{SafeInt}
@test rand(dist, 10, 10) isa Array{SafeInt}
end

@testitem "Check distribution properties of SafePoisson" begin
Expand Down Expand Up @@ -54,7 +54,7 @@ end
bigλ = exp(48.0) #Large value of λ
dist = SafePoisson(bigλ)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa Real
@test rand(dist) isa SafeInt
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = Poisson(dist.λ)
Expand Down

0 comments on commit 7c51d1f

Please sign in to comment.