diff --git a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl index bcabad2d0..af0bab2f4 100644 --- a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl +++ b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl @@ -10,11 +10,13 @@ using DynamicPPL: Model, fix, condition, @submodel, @model using MCMCChains: Chains using Random: AbstractRNG, randexp using Tables: rowtable +import Base: eltype using Distributions, DocStringExtensions, QuadGK, Statistics, Turing #Export Structures -export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial +export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, RealValued, + RealUnivariateDistribution #Export functions export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F @@ -32,6 +34,7 @@ include("turing-methods.jl") include("DirectSample.jl") include("post-inference.jl") include("get_param_array.jl") +include("RealValued.jl") include("SafePoisson.jl") include("SafeNegativeBinomial.jl") diff --git a/EpiAware/src/EpiAwareUtils/RealValued.jl b/EpiAware/src/EpiAwareUtils/RealValued.jl new file mode 100644 index 000000000..9b5c4c462 --- /dev/null +++ b/EpiAware/src/EpiAwareUtils/RealValued.jl @@ -0,0 +1,12 @@ +""" +A type to represent real-valued distributions, the purpose of this type is to avoid problems +with the `eltype` function when having `rand` calls in the model. +""" +struct RealValued <: Distributions.ValueSupport end +Base.eltype(::Type{<:Distributions.Sampleable{F, RealValued}}) where {F} = Real + +""" +A constant alias for `Distribution{Univariate, RealValued}`. This type represents a univariate distribution with real-valued outcomes. +""" +const RealUnivariateDistribution = Distributions.Distribution{ + Distributions.Univariate, RealValued} diff --git a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl index 4341bcc28..010cdb0ac 100644 --- a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl @@ -65,7 +65,7 @@ var(d) 2.4617291430060293e40 ``` " -struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution +struct SafeNegativeBinomial{T <: Real} <: RealUnivariateDistribution r::T p::T diff --git a/EpiAware/src/EpiAwareUtils/SafePoisson.jl b/EpiAware/src/EpiAwareUtils/SafePoisson.jl index 46e5201ad..0ec1122b9 100644 --- a/EpiAware/src/EpiAwareUtils/SafePoisson.jl +++ b/EpiAware/src/EpiAwareUtils/SafePoisson.jl @@ -45,7 +45,7 @@ var(d) 7.016735912097631e20 ``` " -struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution +struct SafePoisson{T <: Real} <: RealUnivariateDistribution λ::T SafePoisson{T}(λ::Real) where {T <: Real} = new{T}(λ) @@ -86,7 +86,7 @@ Distributions.rate(d::SafePoisson) = d.λ ### Statistics Distributions.mean(d::SafePoisson) = d.λ -Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ) +Distributions.mode(d::SafePoisson) = floor(d.λ) Distributions.var(d::SafePoisson) = d.λ Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ) Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ @@ -142,12 +142,12 @@ ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ) function ad_rand(rng::AbstractRNG, λ) s = sqrt(λ) d = 6.0 * λ^2 - L = _safe_int_floor(λ - 1.1484) + L = floor(λ - 1.1484) # Step N G = λ + s * randn(rng) if G >= 0.0 - K = _safe_int_floor(G) + K = floor(G) # Step I if K >= L return K @@ -177,7 +177,7 @@ function ad_rand(rng::AbstractRNG, λ) continue end - K = _safe_int_floor(λ + s * T) + K = floor(λ + s * T) px, py, fx, fy = procf(λ, K, s) c = 0.1069 / λ @@ -229,7 +229,7 @@ function log1pmx(x::Float64) end # Procedure F -function procf(λ, K::Int, s::Float64) +function procf(λ, K, s::Float64) # can be pre-computed, but does not seem to affect performance ω = 0.3989422804014327 / s b1 = 0.041666666666666664 / λ @@ -241,7 +241,7 @@ function procf(λ, K::Int, s::Float64) if K < 10 px = -float(λ) - py = λ^K / factorial(K) + py = λ^K / factorial(floor(Int, K)) else δ = 0.08333333333333333 / K δ -= 4.8 * δ^3 diff --git a/EpiAware/test/EpiAwareUtils/RealValued.jl b/EpiAware/test/EpiAwareUtils/RealValued.jl new file mode 100644 index 000000000..13205084d --- /dev/null +++ b/EpiAware/test/EpiAwareUtils/RealValued.jl @@ -0,0 +1,8 @@ +@testitem "RealValued Type Tests" begin + using Distributions + struct DummySampleable <: Sampleable{Univariate, RealValued} end + + @test RealValued <: Distributions.ValueSupport + @test eltype(DummySampleable) == Real + @test RealUnivariateDistribution == Distribution{Univariate, RealValued} +end diff --git a/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl index 8c828e49d..7e5bb98cd 100644 --- a/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl @@ -70,7 +70,7 @@ end dist = SafeNegativeBinomial(r, p) @testset "Large value of mean samples a BigInt with SafePoisson" begin - @test rand(dist) isa BigInt + @test rand(dist) isa Real end @testset "Large value of mean sample failure with Poisson" begin _dist = EpiAware.EpiAwareUtils._negbin(dist) diff --git a/EpiAware/test/EpiAwareUtils/SafePoisson.jl b/EpiAware/test/EpiAwareUtils/SafePoisson.jl index 6630e066b..e4f889254 100644 --- a/EpiAware/test/EpiAwareUtils/SafePoisson.jl +++ b/EpiAware/test/EpiAwareUtils/SafePoisson.jl @@ -2,9 +2,9 @@ λ = 10.0 dist = SafePoisson(λ) @test typeof(dist) <: SafePoisson - @test rand(dist) isa Int - @test rand(dist, 10) isa Vector{Int} - @test rand(dist, 10, 10) isa Array{Int} + @test rand(dist) isa Real + @test rand(dist, 10) isa Vector{Real} + @test rand(dist, 10, 10) isa Array{Real} end @testitem "Check distribution properties of SafePoisson" begin @@ -54,7 +54,7 @@ end bigλ = exp(48.0) #Large value of λ dist = SafePoisson(bigλ) @testset "Large value of mean samples a BigInt with SafePoisson" begin - @test rand(dist) isa BigInt + @test rand(dist) isa Real end @testset "Large value of mean sample failure with Poisson" begin _dist = Poisson(dist.λ) diff --git a/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl b/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl index b46ba54c4..346d2b9ba 100644 --- a/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl +++ b/EpiAware/test/EpiObsModels/modifiers/PrefixObservationModel.jl @@ -10,5 +10,5 @@ end mdl = generate_observations(model, missing, 10) draw = rand(mdl) - @test typeof(draw[:var"Test.y_t[1]"]) <: Int + @test typeof(draw[:var"Test.y_t[1]"]) <: Real end