Skip to content

Commit

Permalink
Merge branch 'main' into fix-fig1
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Dec 16, 2024
2 parents 382ee8a + ab73d9e commit e7c981f
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 133 deletions.
3 changes: 1 addition & 2 deletions EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValue
SafeDiscreteUnivariateDistribution

#Export functions
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
export spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F

# Export accumulate tools
export get_state, accumulate_scan

include("docstrings.jl")
include("censored_pmf.jl")
include("HalfNormal.jl")
include("scan.jl")
include("accumulate_scan.jl")
include("prefix_submodel.jl")
include("turing-methods.jl")
Expand Down
1 change: 0 additions & 1 deletion EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution
end
end

#Outer constructors make AD work
function SafeNegativeBinomial(r::T, p::T) where {T <: Real}
return SafeNegativeBinomial{T}(r, p)
end
Expand Down
47 changes: 0 additions & 47 deletions EpiAware/src/EpiAwareUtils/scan.jl

This file was deleted.

16 changes: 8 additions & 8 deletions EpiAware/src/EpiInfModels/EpiData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ struct EpiData{T <: Real, F <: Function}
length(gen_int),
transformation)
end
end

function EpiData(; gen_distribution::ContinuousDistribution,
D_gen = nothing,
Δd = 1.0,
transformation::Function = exp)
gen_int = censored_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])
function EpiData(; gen_distribution::ContinuousDistribution,
D_gen = nothing,
Δd = 1.0,
transformation::Function = exp)
gen_int = censored_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])

return EpiData(gen_int, transformation)
end
return EpiData(gen_int, transformation)
end

@doc raw"
Expand Down
24 changes: 12 additions & 12 deletions EpiAware/src/EpiInfModels/Renewal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,6 @@ struct Renewal{E, S <: Sampleable, A} <:
initialisation_prior::S
recurrent_step::A

function Renewal(data::EpiData; initialisation_prior = Normal())
rev_gen_int = reverse(data.gen_int)
recurrent_step = ConstantRenewalStep(rev_gen_int)
return Renewal(data, initialisation_prior, recurrent_step)
end

function Renewal(; data::EpiData, initialisation_prior = Normal())
rev_gen_int = reverse(data.gen_int)
recurrent_step = ConstantRenewalStep(rev_gen_int)
return Renewal(data, initialisation_prior, recurrent_step)
end

function Renewal(data::E,
initialisation_prior::S,
recurrent_step::A) where {
Expand All @@ -106,6 +94,18 @@ struct Renewal{E, S <: Sampleable, A} <:
end
end

function Renewal(data::EpiData; initialisation_prior = Normal())
rev_gen_int = reverse(data.gen_int)
recurrent_step = ConstantRenewalStep(rev_gen_int)
return Renewal(data, initialisation_prior, recurrent_step)
end

function Renewal(; data::EpiData, initialisation_prior = Normal())
rev_gen_int = reverse(data.gen_int)
recurrent_step = ConstantRenewalStep(rev_gen_int)
return Renewal(data, initialisation_prior, recurrent_step)
end

"""
Create the initial state of the `Renewal` model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ struct BroadcastLatentModel{
"The broadcast rule to be applied."
broadcast_rule::B

function BroadcastLatentModel(model::M; period::Integer,
broadcast_rule::B) where {
M <: AbstractTuringLatentModel, B <: AbstractBroadcastRule}
BroadcastLatentModel(model, period, broadcast_rule)
end

function BroadcastLatentModel(model::M, period::Integer,
broadcast_rule::B) where {
M <: AbstractTuringLatentModel, B <: AbstractBroadcastRule}
Expand All @@ -42,6 +36,12 @@ struct BroadcastLatentModel{
end
end

function BroadcastLatentModel(model::M; period::Integer,
broadcast_rule::B) where {
M <: AbstractTuringLatentModel, B <: AbstractBroadcastRule}
BroadcastLatentModel(model, period, broadcast_rule)
end

@doc raw"
Generates latent periods using the specified `model` and `n` number of samples.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ The `TransformLatentModel` struct represents a latent model that applies a trans
## Constructors
- `TransformLatentModel(model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function.
- `TransformLatentModel(; model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function using named arguments.
- `TransformLatentModel(model, transform)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function.
- `TransformLatentModel(; model, transform)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function using named arguments.
## Example
Expand All @@ -20,7 +20,7 @@ trans_model()
"The latent model to transform."
model::M
"The transformation function."
trans_function::F
transform::F
end

"""
Expand All @@ -38,6 +38,6 @@ Generate latent variables using the specified `TransformLatentModel`.
"""
@model function EpiAwareBase.generate_latent(model::TransformLatentModel, n)
@submodel untransformed = generate_latent(model.model, n)
latent = model.trans_function(untransformed)
latent = model.transform(untransformed)
return latent
end
52 changes: 0 additions & 52 deletions EpiAware/test/EpiAwareUtils/scan.jl

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
trans = TransformLatentModel(Intercept(Normal(2, 0.2)), x -> x .|> exp)
@test typeof(trans) <: AbstractTuringLatentModel
@test trans.model == Intercept(Normal(2, 0.2))
@test trans.trans_function([1, 2, 3]) == [exp(1), exp(2), exp(3)]
@test trans.transform([1, 2, 3]) == [exp(1), exp(2), exp(3)]
end

@testitem "TransformLatentModel generate_latent method" begin
Expand Down

0 comments on commit e7c981f

Please sign in to comment.