Skip to content

Commit

Permalink
Merge branch 'main' into issue408
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Oct 24, 2024
2 parents d442dd5 + 0b6a162 commit 54693fc
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 14 deletions.
5 changes: 4 additions & 1 deletion EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ using DynamicPPL: Model, fix, condition, @submodel, @model
using MCMCChains: Chains
using Random: AbstractRNG, randexp
using Tables: rowtable
import Base: eltype

using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

#Export Structures
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValued, SafeInt,
SafeDiscreteUnivariateDistribution

#Export functions
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
Expand All @@ -32,6 +34,7 @@ include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
include("get_param_array.jl")
include("SafeInt.jl")
include("SafePoisson.jl")
include("SafeNegativeBinomial.jl")

Expand Down
16 changes: 16 additions & 0 deletions EpiAware/src/EpiAwareUtils/SafeInt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
const SafeInt = Union{Int, BigInt}

"""
A type to represent real-valued distributions, the purpose of this type is to avoid problems
with the `eltype` function when having `rand` calls in the model.
"""
struct SafeIntValued <: Distributions.ValueSupport end
function Base.eltype(::Type{<:Distributions.Sampleable{F, SafeIntValued}}) where {F}
SafeInt
end

"""
A constant alias for `Distribution{Univariate, SafeIntValued}`. This type represents a univariate distribution with real-valued outcomes.
"""
const SafeDiscreteUnivariateDistribution = Distributions.Distribution{
Distributions.Univariate, SafeIntValued}
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var(d)
2.4617291430060293e40
```
"
struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution
struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution
r::T
p::T

Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var(d)
7.016735912097631e20
```
"
struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution
struct SafePoisson{T <: Real} <: SafeDiscreteUnivariateDistribution
λ::T

SafePoisson{T}::Real) where {T <: Real} = new{T}(λ)
Expand Down Expand Up @@ -86,7 +86,7 @@ Distributions.rate(d::SafePoisson) = d.λ
### Statistics

Distributions.mean(d::SafePoisson) = d.λ
Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ)
Distributions.mode(d::SafePoisson) = floor(d.λ)
Distributions.var(d::SafePoisson) = d.λ
Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ)
Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ
Expand Down Expand Up @@ -229,7 +229,7 @@ function log1pmx(x::Float64)
end

# Procedure F
function procf(λ, K::Int, s::Float64)
function procf(λ, K::SafeInt, s::Float64)
# can be pre-computed, but does not seem to affect performance
ω = 0.3989422804014327 / s
b1 = 0.041666666666666664 / λ
Expand Down
8 changes: 8 additions & 0 deletions EpiAware/test/EpiAwareUtils/SafeInt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@testitem "SafeInt Type Tests" begin
using Distributions
struct DummySampleable <: Sampleable{Univariate, SafeIntValued} end

@test SafeIntValued <: Distributions.ValueSupport
@test eltype(DummySampleable) <: Union{Int, BigInt}
@test SafeDiscreteUnivariateDistribution == Distribution{Univariate, SafeIntValued}
end
2 changes: 1 addition & 1 deletion EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end

dist = SafeNegativeBinomial(r, p)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa BigInt
@test rand(dist) isa Union{Int, BigInt}
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = EpiAware.EpiAwareUtils._negbin(dist)
Expand Down
8 changes: 4 additions & 4 deletions EpiAware/test/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
λ = 10.0
dist = SafePoisson(λ)
@test typeof(dist) <: SafePoisson
@test rand(dist) isa Int
@test rand(dist, 10) isa Vector{Int}
@test rand(dist, 10, 10) isa Array{Int}
@test rand(dist) isa SafeInt
@test rand(dist, 10) isa Vector{SafeInt}
@test rand(dist, 10, 10) isa Array{SafeInt}
end

@testitem "Check distribution properties of SafePoisson" begin
Expand Down Expand Up @@ -54,7 +54,7 @@ end
bigλ = exp(48.0) #Large value of λ
dist = SafePoisson(bigλ)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa BigInt
@test rand(dist) isa SafeInt
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = Poisson(dist.λ)
Expand Down
4 changes: 1 addition & 3 deletions EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ end
ExpGrowthRate,
Renewal] .|>
em_type -> em_type(
data = EpiData([0.2, 0.5, 0.3],
em_type == Renewal ? softplus : exp
),
data = EpiData([0.2, 0.5, 0.3], exp),
initialisation_prior = Normal(log(100.0), 0.01)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ end

mdl = generate_observations(model, missing, 10)
draw = rand(mdl)
@test typeof(draw[:var"Test.y_t[1]"]) <: Int
@test typeof(draw[:var"Test.y_t[1]"]) <: Real
end

0 comments on commit 54693fc

Please sign in to comment.