From 986ef7e48d2e635c4390514e85991b81c8ef9674 Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 10 Oct 2024 22:29:05 +0100 Subject: [PATCH] Discrete valued dists expecting rand return of Union{int,BigInt} --- EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl | 6 +++--- EpiAware/src/EpiAwareUtils/RealValued.jl | 12 ------------ EpiAware/src/EpiAwareUtils/SafeInt.jl | 16 ++++++++++++++++ .../src/EpiAwareUtils/SafeNegativeBinomial.jl | 2 +- EpiAware/src/EpiAwareUtils/SafePoisson.jl | 12 ++++++------ EpiAware/test/EpiAwareUtils/RealValued.jl | 8 -------- EpiAware/test/EpiAwareUtils/SafeInt.jl | 8 ++++++++ .../test/EpiAwareUtils/SafeNegativeBinomial.jl | 2 +- EpiAware/test/EpiAwareUtils/SafePoisson.jl | 8 ++++---- 9 files changed, 39 insertions(+), 35 deletions(-) delete mode 100644 EpiAware/src/EpiAwareUtils/RealValued.jl create mode 100644 EpiAware/src/EpiAwareUtils/SafeInt.jl delete mode 100644 EpiAware/test/EpiAwareUtils/RealValued.jl create mode 100644 EpiAware/test/EpiAwareUtils/SafeInt.jl diff --git a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl index af0bab2f4..22aceaa91 100644 --- a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl +++ b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl @@ -15,8 +15,8 @@ import Base: eltype using Distributions, DocStringExtensions, QuadGK, Statistics, Turing #Export Structures -export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, RealValued, - RealUnivariateDistribution +export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValued, SafeInt, + SafeDiscreteUnivariateDistribution #Export functions export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F @@ -34,7 +34,7 @@ include("turing-methods.jl") include("DirectSample.jl") include("post-inference.jl") include("get_param_array.jl") -include("RealValued.jl") +include("SafeInt.jl") include("SafePoisson.jl") include("SafeNegativeBinomial.jl") diff --git a/EpiAware/src/EpiAwareUtils/RealValued.jl b/EpiAware/src/EpiAwareUtils/RealValued.jl deleted file mode 100644 index 9b5c4c462..000000000 --- a/EpiAware/src/EpiAwareUtils/RealValued.jl +++ /dev/null @@ -1,12 +0,0 @@ -""" -A type to represent real-valued distributions, the purpose of this type is to avoid problems -with the `eltype` function when having `rand` calls in the model. -""" -struct RealValued <: Distributions.ValueSupport end -Base.eltype(::Type{<:Distributions.Sampleable{F, RealValued}}) where {F} = Real - -""" -A constant alias for `Distribution{Univariate, RealValued}`. This type represents a univariate distribution with real-valued outcomes. -""" -const RealUnivariateDistribution = Distributions.Distribution{ - Distributions.Univariate, RealValued} diff --git a/EpiAware/src/EpiAwareUtils/SafeInt.jl b/EpiAware/src/EpiAwareUtils/SafeInt.jl new file mode 100644 index 000000000..a4d26fad4 --- /dev/null +++ b/EpiAware/src/EpiAwareUtils/SafeInt.jl @@ -0,0 +1,16 @@ +const SafeInt = Union{Int, BigInt} + +""" +A type to represent real-valued distributions, the purpose of this type is to avoid problems +with the `eltype` function when having `rand` calls in the model. +""" +struct SafeIntValued <: Distributions.ValueSupport end +function Base.eltype(::Type{<:Distributions.Sampleable{F, SafeIntValued}}) where {F} + Union{Int, BigInt} +end + +""" +A constant alias for `Distribution{Univariate, RealValued}`. This type represents a univariate distribution with real-valued outcomes. +""" +const SafeDiscreteUnivariateDistribution = Distributions.Distribution{ + Distributions.Univariate, SafeIntValued} diff --git a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl index 010cdb0ac..e57714b80 100644 --- a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl @@ -65,7 +65,7 @@ var(d) 2.4617291430060293e40 ``` " -struct SafeNegativeBinomial{T <: Real} <: RealUnivariateDistribution +struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution r::T p::T diff --git a/EpiAware/src/EpiAwareUtils/SafePoisson.jl b/EpiAware/src/EpiAwareUtils/SafePoisson.jl index 0ec1122b9..c85243ebc 100644 --- a/EpiAware/src/EpiAwareUtils/SafePoisson.jl +++ b/EpiAware/src/EpiAwareUtils/SafePoisson.jl @@ -45,7 +45,7 @@ var(d) 7.016735912097631e20 ``` " -struct SafePoisson{T <: Real} <: RealUnivariateDistribution +struct SafePoisson{T <: Real} <: SafeDiscreteUnivariateDistribution λ::T SafePoisson{T}(λ::Real) where {T <: Real} = new{T}(λ) @@ -142,12 +142,12 @@ ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ) function ad_rand(rng::AbstractRNG, λ) s = sqrt(λ) d = 6.0 * λ^2 - L = floor(λ - 1.1484) + L = _safe_int_floor(λ - 1.1484) # Step N G = λ + s * randn(rng) if G >= 0.0 - K = floor(G) + K = _safe_int_floor(G) # Step I if K >= L return K @@ -177,7 +177,7 @@ function ad_rand(rng::AbstractRNG, λ) continue end - K = floor(λ + s * T) + K = _safe_int_floor(λ + s * T) px, py, fx, fy = procf(λ, K, s) c = 0.1069 / λ @@ -229,7 +229,7 @@ function log1pmx(x::Float64) end # Procedure F -function procf(λ, K, s::Float64) +function procf(λ, K::SafeInt, s::Float64) # can be pre-computed, but does not seem to affect performance ω = 0.3989422804014327 / s b1 = 0.041666666666666664 / λ @@ -241,7 +241,7 @@ function procf(λ, K, s::Float64) if K < 10 px = -float(λ) - py = λ^K / factorial(floor(Int, K)) + py = λ^K / factorial(K) else δ = 0.08333333333333333 / K δ -= 4.8 * δ^3 diff --git a/EpiAware/test/EpiAwareUtils/RealValued.jl b/EpiAware/test/EpiAwareUtils/RealValued.jl deleted file mode 100644 index 13205084d..000000000 --- a/EpiAware/test/EpiAwareUtils/RealValued.jl +++ /dev/null @@ -1,8 +0,0 @@ -@testitem "RealValued Type Tests" begin - using Distributions - struct DummySampleable <: Sampleable{Univariate, RealValued} end - - @test RealValued <: Distributions.ValueSupport - @test eltype(DummySampleable) == Real - @test RealUnivariateDistribution == Distribution{Univariate, RealValued} -end diff --git a/EpiAware/test/EpiAwareUtils/SafeInt.jl b/EpiAware/test/EpiAwareUtils/SafeInt.jl new file mode 100644 index 000000000..d1a59dfab --- /dev/null +++ b/EpiAware/test/EpiAwareUtils/SafeInt.jl @@ -0,0 +1,8 @@ +@testitem "SafeInt Type Tests" begin + using Distributions + struct DummySampleable <: Sampleable{Univariate, SafeIntValued} end + + @test SafeIntValued <: Distributions.ValueSupport + @test eltype(DummySampleable) <: Union{Int, BigInt} + @test SafeDiscreteUnivariateDistribution == Distribution{Univariate, SafeIntValued} +end diff --git a/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl index 7e5bb98cd..4e1ee09db 100644 --- a/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl @@ -70,7 +70,7 @@ end dist = SafeNegativeBinomial(r, p) @testset "Large value of mean samples a BigInt with SafePoisson" begin - @test rand(dist) isa Real + @test rand(dist) isa Union{Int, BigInt} end @testset "Large value of mean sample failure with Poisson" begin _dist = EpiAware.EpiAwareUtils._negbin(dist) diff --git a/EpiAware/test/EpiAwareUtils/SafePoisson.jl b/EpiAware/test/EpiAwareUtils/SafePoisson.jl index e4f889254..874cf3423 100644 --- a/EpiAware/test/EpiAwareUtils/SafePoisson.jl +++ b/EpiAware/test/EpiAwareUtils/SafePoisson.jl @@ -2,9 +2,9 @@ λ = 10.0 dist = SafePoisson(λ) @test typeof(dist) <: SafePoisson - @test rand(dist) isa Real - @test rand(dist, 10) isa Vector{Real} - @test rand(dist, 10, 10) isa Array{Real} + @test rand(dist) isa SafeInt + @test rand(dist, 10) isa Vector{SafeInt} + @test rand(dist, 10, 10) isa Array{SafeInt} end @testitem "Check distribution properties of SafePoisson" begin @@ -54,7 +54,7 @@ end bigλ = exp(48.0) #Large value of λ dist = SafePoisson(bigλ) @testset "Large value of mean samples a BigInt with SafePoisson" begin - @test rand(dist) isa Real + @test rand(dist) isa SafeInt end @testset "Large value of mean sample failure with Poisson" begin _dist = Poisson(dist.λ)