Skip to content

Commit

Permalink
New latent models (#203)
Browse files Browse the repository at this point in the history
* Add intercept struct and method

* add a basic test

* add more tests for intercept

* add more tests for intercept

* fix spelling

* rename broadcast helpers

* first pass at CombineeLatentModels

* spin out TransformLatentModels

* fix TransformLatentModel

* add tests for TransformLatentModel and fix Intercept tests for new output structure

* add basic tests for CombineLatentModels

* add FixedIntercpt version of Intercept

* fix FixedIntecept tests and drop uses of scale where we can

* fix ascertainment test

* add HierarchicalNormal

* add abstract method for intercepts

* Add intercept struct and method

* add a basic test

* add more tests for intercept

* add more tests for intercept

* fix spelling

* rename broadcast helpers

* first pass at CombineeLatentModels

* spin out TransformLatentModels

* fix TransformLatentModel

* add tests for TransformLatentModel and fix Intercept tests for new output structure

* add basic tests for CombineLatentModels

* add FixedIntercpt version of Intercept

* fix FixedIntecept tests and drop uses of scale where we can

* fix ascertainment test

* add HierarchicalNormal

* add abstract method for intercepts

* Test conditioning on CombineLatentModels and sampling using NUTS

* modify test to do KS test against theoretical posterior distribution

---------

Co-authored-by: Samuel Brand <[email protected]>
  • Loading branch information
seabbs and SamuelBrand1 authored May 10, 2024
1 parent 4ab6cc2 commit 4262a9b
Show file tree
Hide file tree
Showing 18 changed files with 454 additions and 60 deletions.
3 changes: 2 additions & 1 deletion EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ using DocStringExtensions
export AbstractModel, AbstractEpiModel, AbstractLatentModel, AbstractObservationModel

# Export Turing-based models
export AbstractTuringEpiModel, AbstractTuringLatentModel, AbstractTuringObservationModel
export AbstractTuringEpiModel, AbstractTuringLatentModel, AbstractTuringIntercept,
AbstractTuringObservationModel

# Export support types
export AbstractBroadcastRule
Expand Down
5 changes: 5 additions & 0 deletions EpiAware/src/EpiAwareBase/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ A abstract type representing a Turing-based Latent model.
"""
abstract type AbstractTuringLatentModel <: AbstractLatentModel end

"""
A abstract type used to define the common interface for intercept models.
"""
abstract type AbstractTuringIntercept <: AbstractTuringLatentModel end

"""
An abstract type representing a broadcast rule.
"""
Expand Down
62 changes: 62 additions & 0 deletions EpiAware/src/EpiLatentModels/CombineLatentModels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
@doc raw"
The `CombineLatentModels` struct.
This struct is used to combine multiple latent models into a single latent model.
# Constructors
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
- `CombineLatentModels(; models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
# Examples
```julia
using EpiAware, Distributions
combined_model = CombineLatentModels([Intercept(Normal(2, 0.2)), AR()])
latent_model = generate_latent(combined_model, 10)
latent_model()
```
"
@kwdef struct CombineLatentModels{M <: AbstractVector{<:AbstractTuringLatentModel}} <:
AbstractTuringLatentModel
"A vector of latent models"
models::M

function CombineLatentModels(models::M) where {M <:
AbstractVector{<:AbstractTuringLatentModel}}
@assert length(models)>1 "At least two models are required"
return new{AbstractVector{<:AbstractTuringLatentModel}}(models)
end
end

@doc raw"
Generate latent variables using a combination of multiple latent models.
# Arguments
- `latent_models::CombineLatentModels`: An instance of the `CombineLatentModels` type representing the collection of latent models.
- `n`: The number of latent variables to generate.
# Returns
- `combined_latents`: The combined latent variables generated from all the models.
- `latent_aux`: A tuple containing the auxiliary latent variables generated from each individual model.
# Example
"
@model function EpiAwareBase.generate_latent(latent_models::CombineLatentModels, n)
@submodel final_latent, latent_aux = _accumulate_latents(
latent_models.models, 1, fill(0.0, n), [], n, length(latent_models.models))

return final_latent, (; latent_aux...)
end

@model function _accumulate_latents(models, index, acc_latent, acc_aux, n, n_models)
if index > n_models
return acc_latent, (; acc_aux...)
else
@submodel latent, new_aux = generate_latent(models[index], n)
@submodel updated_latent, updated_aux = _accumulate_latents(
models, index + 1, acc_latent .+ latent,
(; acc_aux..., new_aux...), n, n_models)
return updated_latent, (; updated_aux...)
end
end
11 changes: 9 additions & 2 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@ using ..EpiAwareUtils: HalfNormal
using Turing, Distributions, DocStringExtensions, LinearAlgebra

#Export models
export RandomWalk, AR, DiffLatentModel, BroadcastLatentModel
export FixedIntercept, Intercept, RandomWalk, AR, HierarchicalNormal

# Export tools for manipulating latent models
export CombineLatentModels, TransformLatentModel, DiffLatentModel, BroadcastLatentModel

# Export broadcast rules
export RepeatEach, RepeatBlock

# Export helper functions
export dayofweek, weekly
export broadcast_dayofweek, broadcast_weekly

include("docstrings.jl")
include("Intercept.jl")
include("RandomWalk.jl")
include("AR.jl")
include("HierarchicalNormal.jl")
include("CombineLatentModels.jl")
include("TransformLatentModel.jl")
include("DiffLatentModel.jl")
include("broadcast/LatentModel.jl")
include("broadcast/rules.jl")
Expand Down
41 changes: 41 additions & 0 deletions EpiAware/src/EpiLatentModels/HierarchicalNormal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@doc raw"
The `HierarchicalNormal` struct represents a non-centered hierarchical normal distribution.
## Constructors
- `HierarchicalNormal(mean, std_prior)`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior.
- `HierarchicalNormal(; mean = 0.0, std_prior = truncated(Normal(0,1), 0, Inf))`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior using named arguments and with default values.
## Examples
```julia
using Distributions, EpiAware
hnorm = HierarchicalNormal(0.0, truncated(Normal(0, 1), 0, Inf))
hnorm_model = generate_latent(hnorm, 10)
hnorm_model()
```
"
@kwdef struct HierarchicalNormal{R <: Real, D <: Sampleable} <: AbstractTuringLatentModel
mean::R = 0.0
std_prior::D = truncated(Normal(0, 1), 0, Inf)
end

@doc raw"
function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n)
Generate latent variables from the hierarchical normal distribution.
# Arguments
- `obs_model::HierarchicalNormal`: The hierarchical normal distribution model.
- `n`: Number of latent variables to generate.
# Returns
- `η_t`: Generated latent variables.
- `std`: Standard deviation used in the generation.
"
@model function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n)
std ~ obs_model.std_prior
ϵ_t ~ MvNormal(I(n))
η_t = obs_model.mean .+ std .* ϵ_t
return η_t, (; std = std)
end
78 changes: 78 additions & 0 deletions EpiAware/src/EpiLatentModels/Intercept.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
@doc raw"
The `Intercept` struct is used to model the intercept of a latent process. It
broadcasts a single intercept value to a length `n` latent process.
## Constructors
- Intercept(intercept_prior)
- Intercept(; intercept_prior)
## Examples
```julia
using Distributions, Turing, EpiAware
int = Intercept(Normal(0, 1))
int_model = generate_latent(int, 10)
rand(int_model)
int_model()
```
"
@kwdef struct Intercept{D <: Sampleable} <: AbstractTuringIntercept
"Prior distribution for the intercept."
intercept_prior::D
end

@doc raw"
Generate a latent intercept series.
# Arguments
- `latent_model::Intercept`: The intercept model.
- `n::Int`: The length of the intercept series.
# Returns
- `intercept::Vector{Float64}`: The generated intercept series.
- `metadata::NamedTuple`: A named tuple containing the intercept value.
"
@model function EpiAwareBase.generate_latent(latent_model::Intercept, n)
intercept ~ latent_model.intercept_prior
return fill(intercept, n), (; intercept = intercept)
end

@doc raw"
A variant of the `Intercept` struct that represents a fixed intercept value for a latent model.
# Constructors
- `FixedIntercept(intercept)` : Constructs a `FixedIntercept` instance with the specified intercept value.
- `FixedIntercept(; intercept)` : Constructs a `FixedIntercept` instance with the specified intercept value using named arguments.
# Examples
```julia
using EpiAware
fi = FixedIntercept(2.0)
fi_model = generate_latent(fi, 10)
fi_model()
```
"
@kwdef struct FixedIntercept{F <: Real} <: AbstractTuringIntercept
intercept::F
end

@doc raw"
Generate a latent intercept series with a fixed intercept value.
# Arguments
- `latent_model::FixedIntercept`: The fixed intercept latent model.
- `n`: The number of latent variables to generate.
# Returns
- `latent_vars`: An array of length `n` filled with the fixed intercept value.
- `metadata`: A named tuple containing the intercept value.
"
@model function EpiAwareBase.generate_latent(latent_model::FixedIntercept, n)
return fill(latent_model.intercept, n), (; intercept = latent_model.intercept)
end
44 changes: 44 additions & 0 deletions EpiAware/src/EpiLatentModels/TransformLatentModel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
@doc raw"
The `TransformLatentModel` struct represents a latent model that applies a transformation function to the latent variables generated by another latent model.
## Constructors
- `TransformLatentModel(model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function.
- `TransformLatentModel(; model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function using named arguments.
## Example
```julia
using EpiAware, Distributions
trans = TransformLatentModel(Intercept(Normal(2, 0.2)), x -> x .|> exp)
trans_model = generate_latent(trans, 5)
trans_model()
```
"
@kwdef struct TransformLatentModel{M <: AbstractTuringLatentModel, F <: Function} <:
AbstractTuringLatentModel
"The latent model to transform."
model::M
"The transformation function."
trans_function::F
end

"""
generate_latent(model::TransformLatentModel, n)
Generate latent variables using the specified `TransformLatentModel`.
# Arguments
- `model::TransformLatentModel`: The `TransformLatentModel` to generate latent variables from.
- `n`: The number of latent variables to generate.
# Returns
- `transformed`: The transformed latent variables.
- `latent_aux`: Additional auxiliary variables generated by the underlying latent model.
"""
@model function EpiAwareBase.generate_latent(model::TransformLatentModel, n)
@submodel untransformed, latent_aux = generate_latent(model.model, n)
latent = model.trans_function(untransformed)
return latent, (; latent_aux)
end
4 changes: 2 additions & 2 deletions EpiAware/src/EpiLatentModels/broadcast/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Constructs a `BroadcastLatentModel` appropriate for modelling the day of the wee
# Returns
- `BroadcastLatentModel`: The broadcast latent model.
"
function dayofweek(model::AbstractTuringLatentModel)
function broadcast_dayofweek(model::AbstractTuringLatentModel)
return BroadcastLatentModel(model, 7, RepeatEach())
end

Expand All @@ -20,6 +20,6 @@ Constructs a `BroadcastLatentModel` appropriate for modelling piecewise constant
# Returns
- `BroadcastLatentModel`: The broadcast latent model.
"
function weekly(model::AbstractTuringLatentModel)
function broadcast_weekly(model::AbstractTuringLatentModel)
return BroadcastLatentModel(model, 7, RepeatBlock())
end
2 changes: 1 addition & 1 deletion EpiAware/src/EpiLatentModels/broadcast/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end
@doc raw"
`RepeatBlock` is a struct that represents a broadcasting rule. It is a subtype of `AbstractBroadcastRule`.
It repeats the latent process in blocks of size `period`. An example of this rule is to repeat the latent process in blocks of size 7 to model a weekly process (though for this we also provide the `weekly` helper function).
It repeats the latent process in blocks of size `period`. An example of this rule is to repeat the latent process in blocks of size 7 to model a weekly process (though for this we also provide the `broadcast_weekly` helper function).
## Examples
```julia
Expand Down
12 changes: 1 addition & 11 deletions EpiAware/src/EpiObsModels/Ascertainment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,7 @@ The `Ascertainment` struct represents an observation model that incorporates asc
# Examples
```julia
using EpiAware, Turing
struct Scale <: AbstractTuringLatentModel
end
@model function EpiAware.generate_latent(model::Scale, n::Int)
scale = 0.1
scale_vect = fill(scale, n)
return scale_vect, (; scale = scale)
end
obs = Ascertainment(NegativeBinomialError(), Scale(), x -> x)
obs = Ascertainment(NegativeBinomialError(), FixedIntercept(0.1), x -> x)
gen_obs = generate_observations(obs, missing, fill(100, 10))
rand(gen_obs)
```
Expand Down
39 changes: 0 additions & 39 deletions EpiAware/test/EpiLatentModels/Ascertainment.jl

This file was deleted.

Loading

0 comments on commit 4262a9b

Please sign in to comment.