diff --git a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl index 22aceaa91..9635f5ef1 100644 --- a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl +++ b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl @@ -19,7 +19,7 @@ export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValue SafeDiscreteUnivariateDistribution #Export functions -export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F +export spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F # Export accumulate tools export get_state, accumulate_scan @@ -27,7 +27,6 @@ export get_state, accumulate_scan include("docstrings.jl") include("censored_pmf.jl") include("HalfNormal.jl") -include("scan.jl") include("accumulate_scan.jl") include("prefix_submodel.jl") include("turing-methods.jl") diff --git a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl index f141215c6..60a29c76d 100644 --- a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl +++ b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl @@ -84,7 +84,6 @@ struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution end end -#Outer constructors make AD work function SafeNegativeBinomial(r::T, p::T) where {T <: Real} return SafeNegativeBinomial{T}(r, p) end diff --git a/EpiAware/src/EpiAwareUtils/scan.jl b/EpiAware/src/EpiAwareUtils/scan.jl deleted file mode 100644 index fcf3396cb..000000000 --- a/EpiAware/src/EpiAwareUtils/scan.jl +++ /dev/null @@ -1,47 +0,0 @@ -""" -Apply `f` to each element of `xs` and accumulate the results. - -`f` must be a [callable](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects) - on a sub-type of `AbstractModel`. - -### Design note -`scan` is being restricted to `AbstractModel` sub-types to ensure: - 1. That compiler specialization is [activated](https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing) - 2. Also avoids potential compiler [overhead](https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues) - from specialisation on `f<: Function`. - - - -# Arguments -- `f`: A callable/functor that takes two arguments, `carry` and `x`, and returns a new - `carry` and a result `y`. -- `init`: The initial value for the `carry` variable. -- `xs`: An iterable collection of elements. - -# Returns -- `ys`: An array containing the results of applying `f` to each element of `xs`. -- `carry`: The final value of the `carry` variable after processing all elements of `xs`. - -# Examples - -```jldoctest -using EpiAware - -struct Adder <: EpiAwareBase.AbstractModel end -function (a::Adder)(carry, x) - carry + x, carry + x -end - -scan(Adder(), 0, 1:5) -#output -([1, 3, 6, 10, 15], 15) -""" -function scan(f::F, init, xs) where {F <: EpiAwareBase.AbstractModel} - carry = init - ys = similar(xs) - for (i, x) in enumerate(xs) - carry, y = f(carry, x) - ys[i] = y - end - return ys, carry -end diff --git a/EpiAware/src/EpiInfModels/EpiData.jl b/EpiAware/src/EpiInfModels/EpiData.jl index e88b0e58b..eb393bae3 100644 --- a/EpiAware/src/EpiInfModels/EpiData.jl +++ b/EpiAware/src/EpiInfModels/EpiData.jl @@ -50,16 +50,16 @@ struct EpiData{T <: Real, F <: Function} length(gen_int), transformation) end +end - function EpiData(; gen_distribution::ContinuousDistribution, - D_gen = nothing, - Δd = 1.0, - transformation::Function = exp) - gen_int = censored_pmf(gen_distribution, Δd = Δd, D = D_gen) |> - p -> p[2:end] ./ sum(p[2:end]) +function EpiData(; gen_distribution::ContinuousDistribution, + D_gen = nothing, + Δd = 1.0, + transformation::Function = exp) + gen_int = censored_pmf(gen_distribution, Δd = Δd, D = D_gen) |> + p -> p[2:end] ./ sum(p[2:end]) - return EpiData(gen_int, transformation) - end + return EpiData(gen_int, transformation) end @doc raw" diff --git a/EpiAware/src/EpiInfModels/Renewal.jl b/EpiAware/src/EpiInfModels/Renewal.jl index 557f94c5b..d276d2de8 100644 --- a/EpiAware/src/EpiInfModels/Renewal.jl +++ b/EpiAware/src/EpiInfModels/Renewal.jl @@ -86,18 +86,6 @@ struct Renewal{E, S <: Sampleable, A} <: initialisation_prior::S recurrent_step::A - function Renewal(data::EpiData; initialisation_prior = Normal()) - rev_gen_int = reverse(data.gen_int) - recurrent_step = ConstantRenewalStep(rev_gen_int) - return Renewal(data, initialisation_prior, recurrent_step) - end - - function Renewal(; data::EpiData, initialisation_prior = Normal()) - rev_gen_int = reverse(data.gen_int) - recurrent_step = ConstantRenewalStep(rev_gen_int) - return Renewal(data, initialisation_prior, recurrent_step) - end - function Renewal(data::E, initialisation_prior::S, recurrent_step::A) where { @@ -106,6 +94,18 @@ struct Renewal{E, S <: Sampleable, A} <: end end +function Renewal(data::EpiData; initialisation_prior = Normal()) + rev_gen_int = reverse(data.gen_int) + recurrent_step = ConstantRenewalStep(rev_gen_int) + return Renewal(data, initialisation_prior, recurrent_step) +end + +function Renewal(; data::EpiData, initialisation_prior = Normal()) + rev_gen_int = reverse(data.gen_int) + recurrent_step = ConstantRenewalStep(rev_gen_int) + return Renewal(data, initialisation_prior, recurrent_step) +end + """ Create the initial state of the `Renewal` model. diff --git a/EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl b/EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl index 1976e936f..6737b648a 100644 --- a/EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl +++ b/EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl @@ -27,12 +27,6 @@ struct BroadcastLatentModel{ "The broadcast rule to be applied." broadcast_rule::B - function BroadcastLatentModel(model::M; period::Integer, - broadcast_rule::B) where { - M <: AbstractTuringLatentModel, B <: AbstractBroadcastRule} - BroadcastLatentModel(model, period, broadcast_rule) - end - function BroadcastLatentModel(model::M, period::Integer, broadcast_rule::B) where { M <: AbstractTuringLatentModel, B <: AbstractBroadcastRule} @@ -42,6 +36,12 @@ struct BroadcastLatentModel{ end end +function BroadcastLatentModel(model::M; period::Integer, + broadcast_rule::B) where { + M <: AbstractTuringLatentModel, B <: AbstractBroadcastRule} + BroadcastLatentModel(model, period, broadcast_rule) +end + @doc raw" Generates latent periods using the specified `model` and `n` number of samples. diff --git a/EpiAware/test/EpiAwareUtils/scan.jl b/EpiAware/test/EpiAwareUtils/scan.jl deleted file mode 100644 index 8e833e56d..000000000 --- a/EpiAware/test/EpiAwareUtils/scan.jl +++ /dev/null @@ -1,52 +0,0 @@ -@testitem "Testing scan function with addition" begin - # Test case 1: Testing with addition function - function add(a, b) - return a + b, a + b - end - - xs = [1, 2, 3, 4, 5] - expected_ys = [1, 3, 6, 10, 15] - expected_carry = 15 - - # Check that a generic function CAN'T be used - @test_throws MethodError scan(add, 0, xs) - - # Check that a callable subtype of `AbstractEpiModel` CAN be used - struct TestEpiModelAdd <: AbstractEpiModel - end - function (epi_model::TestEpiModelAdd)(a, b) - return a + b, a + b - end - - ys, carry = scan(TestEpiModelAdd(), 0, xs) - - @test ys == expected_ys - @test carry == expected_carry -end - -@testitem "Testing scan function with multiplication" begin - # Test case 2: Testing with multiplication function - function multiply(a, b) - return a * b, a * b - end - - xs = [1, 2, 3, 4, 5] - expected_ys = [1, 2, 6, 24, 120] - expected_carry = 120 - - # Check that a generic function CAN'T be used - @test_throws MethodError ys, carry=scan(multiply, 1, xs) - - # Check that a callable subtype of `AbstractEpiModel` CAN be used - struct TestEpiModelMult <: AbstractEpiModel - end - - function (epi_model::TestEpiModelMult)(a, b) - return a * b, a * b - end - - ys, carry = scan(TestEpiModelMult(), 1, xs) - - @test ys == expected_ys - @test carry == expected_carry -end