Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 252: Composing complex models #296

Merged
merged 17 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module EpiAwareUtils
using ..EpiAwareBase

using DataFramesMeta: DataFrame, @rename!
using DynamicPPL: Model, fix, condition
using DynamicPPL: Model, fix, condition, @submodel, @model
using MCMCChains: Chains
using Random: AbstractRNG
using Tables: rowtable
Expand All @@ -17,12 +17,13 @@ using Distributions, DocStringExtensions, QuadGK, Statistics, Turing
export HalfNormal, DirectSample

#Export functions
export scan, spread_draws, censored_pmf, get_param_array
export scan, spread_draws, censored_pmf, get_param_array, prefix_submodel

include("docstrings.jl")
include("censored_pmf.jl")
include("HalfNormal.jl")
include("scan.jl")
include("prefix_submodel.jl")
include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
Expand Down
30 changes: 30 additions & 0 deletions EpiAware/src/EpiAwareUtils/prefix_submodel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@doc raw"
Generate a submodel with an optional prefix. A lightweight wrapper around the `@submodel` macro from DynamicPPL.jl.

# Arguments

- `model::AbstractModel`: The model to be used.
- `fn::Function`: The Turing @model function to be applied to the model.
- `prefix::String`: The prefix to be used. If the prefix is an empty string, the submodel is created without a prefix.

# Returns

- `submodel`: The returns from the submodel are passed through.

# Examples

```julia
using EpiAware
submodel = prefix_submodel(CombineLatentModels([FixedIntercept(0.1), AR()]), generate_latent, \"Test\", 10)
rand(submodel)
```
"
@model function prefix_submodel(
seabbs marked this conversation as resolved.
Show resolved Hide resolved
model::AbstractModel, fn::Function, prefix::String, kwargs...)
if prefix == ""
@submodel submodel = fn(model, kwargs...)
else
@submodel prefix=eval(prefix) submodel=fn(model, kwargs...)
end
return submodel
end
6 changes: 3 additions & 3 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module EpiLatentModels

using ..EpiAwareBase

using ..EpiAwareUtils: HalfNormal
using ..EpiAwareUtils: HalfNormal, prefix_submodel

using LogExpFunctions: softmax

Expand All @@ -26,7 +26,7 @@ export RepeatEach, RepeatBlock
export broadcast_dayofweek, broadcast_weekly, equal_dimensions

# Export tools for modifying latent models
export DiffLatentModel, TransformLatentModel
export DiffLatentModel, TransformLatentModel, PrefixLatentModel

include("docstrings.jl")
include("models/Intercept.jl")
Expand All @@ -35,9 +35,9 @@ include("models/AR.jl")
include("models/HierarchicalNormal.jl")
include("modifiers/DiffLatentModel.jl")
include("modifiers/TransformLatentModel.jl")
include("modifiers/PrefixLatentModel.jl")
include("manipulators/CombineLatentModels.jl")
include("manipulators/ConcatLatentModels.jl")

include("manipulators/broadcast/LatentModel.jl")
include("manipulators/broadcast/rules.jl")
include("manipulators/broadcast/helpers.jl")
Expand Down
37 changes: 28 additions & 9 deletions EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
@doc raw"
The `CombineLatentModels` struct.

This struct is used to combine multiple latent models into a single latent model.
This struct is used to combine multiple latent models into a single latent model. If a prefix is supplied wraps each model with `PrefixLatentModel`.

# Constructors

- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
- `CombineLatentModels(; models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models.
- `CombineLatentModels(models::M, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{<:String}}`: Constructs a `CombineLatentModels` instance with specified models and prefixes, ensuring that there are at least two models and the number of models and prefixes are equal.
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, automatically generating prefixes for each model. The
automatic prefixes are of the form `Combine.1`, `Combine.2`, etc.

# Examples

Expand All @@ -17,15 +17,33 @@ latent_model = generate_latent(combined_model, 10)
latent_model()
```
"
@kwdef struct CombineLatentModels{M <: AbstractVector{<:AbstractTuringLatentModel}} <:
@kwdef struct CombineLatentModels{
M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{<:String}} <:
AbstractTuringLatentModel
"A vector of latent models"
models::M
"A vector of prefixes for the latent models"
prefixes::P

function CombineLatentModels(models::M) where {M <:
AbstractVector{<:AbstractTuringLatentModel}}
function CombineLatentModels(models::M,
prefixes::P) where {
M <: AbstractVector{<:AbstractTuringLatentModel},
P <: AbstractVector{<:String}}
@assert length(models)>1 "At least two models are required"
return new{AbstractVector{<:AbstractTuringLatentModel}}(models)
@assert length(models)==length(prefixes) "The number of models and prefixes must be equal"
for i in eachindex(models)
if (prefixes[i] != "")
models[i] = PrefixLatentModel(models[i], prefixes[i])
end
end
return new{AbstractVector{<:AbstractTuringLatentModel}, AbstractVector{<:String}}(
models, prefixes)
end

function CombineLatentModels(models::M) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
prefixes = "Combine." .* string.(1:length(models))
return CombineLatentModels(models, prefixes)
end
end

Expand All @@ -49,7 +67,8 @@ Generate latent variables using a combination of multiple latent models.
return final_latent, (; latent_aux...)
end

@model function _accumulate_latents(models, index, acc_latent, acc_aux, n, n_models)
@model function _accumulate_latents(
models, index, acc_latent, acc_aux, n, n_models)
if index > n_models
return acc_latent, (; acc_aux...)
else
Expand Down
50 changes: 34 additions & 16 deletions EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ This struct is used to concatenate multiple latent models into a single latent m

# Constructors

- `ConcatLatentModels(models::M, no_models::Int, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, and dimension adaptor.
- `ConcatLatentModels(models::M, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor, ensuring that there are at least two models. The default dimension adaptor is `equal_dimensions`.
- `ConcatLatentModels(; models::M, dimension_adaptor::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor, ensuring that there are at least two models. The default dimension adaptor is `equal_dimensions`.
- `ConcatLatentModels(models::M, no_models::I, dimension_adaptor::F, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, dimension adaptor, and prefixes.
- `ConcatLatentModels(models::M, dimension_adaptor::F; prefixes::P = \"Concat.\" * string.(1:length(models))) where {M <: AbstractVector{<:AbstractTuringLatentModel}, F <: Function}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adaptor. The number of models is automatically determined as are the prefixes (of the form `Concat.1`, `Concat.2`, etc.) by default.
- `ConcatLatentModels(models::M; dimension_adaptor::Function, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, dimension adaptor, prefixes, and automatically determines the number of models.The default dimension adaptor is `equal_dimensions`. The default prefixes are of the form `Concat.1`, `Concat.2`, etc.
- `ConcatLatentModels(; models::M, dimension_adaptor::Function, prefixes::P) where {M <: AbstractVector{<:AbstractTuringLatentModel}, P <: AbstractVector{String}}`: Constructs a `ConcatLatentModels` instance with specified models, dimension adaptor, prefixes, and automatically determines the number of models. The default dimension adaptor is `equal_dimensions`. The default prefixes are of the form `Concat.1`, `Concat.2`, etc.

# Examples

Expand All @@ -19,46 +20,62 @@ latent_model()
```
"
struct ConcatLatentModels{
M <: AbstractVector{<:AbstractTuringLatentModel}, N <: Int, F <: Function} <:
M <: AbstractVector{<:AbstractTuringLatentModel}, N <: Int, F <: Function, P <:
AbstractVector{<:String}} <:
AbstractTuringLatentModel
"A vector of latent models"
models::M
"The number of models in the collection"
no_models::N
"The dimension function for the latent variables. By default this divides the number of latent variables by the number of models and returns a vector of dimensions rounding up the first element and rounding down the rest."
dimension_adaptor::F
"A vector of prefixes for the latent models"
prefixes::P

function ConcatLatentModels(models::M,
no_models::I,
dimension_adaptor::F) where {
dimension_adaptor::F, prefixes::P) where {
M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int,
F <: Function}
F <: Function, P <: AbstractVector{<:String}}
@assert length(models)>1 "At least two models are required"
@assert length(models)==no_models "no_models must be equal to the number of models"
# check all dimension functions take a single n and return an integer
check_dim = dimension_adaptor(no_models, no_models)
@assert typeof(check_dim)<:AbstractVector{Int} "Output of dimension_adaptor must be a vector of integers"
@assert length(check_dim)==no_models "The vector of dimensions must have the same length as the number of models"
return new{AbstractVector{<:AbstractTuringLatentModel}, Int, Function}(
models, no_models, dimension_adaptor)
@assert length(prefixes)==no_models "The number of models and prefixes must be equal"
for i in eachindex(models)
if (prefixes[i] != "")
models[i] = PrefixLatentModel(models[i], prefixes[i])
end
end
return new{
AbstractVector{<:AbstractTuringLatentModel}, Int, Function,
AbstractVector{<:String}}(
models, no_models, dimension_adaptor, prefixes)
end

function ConcatLatentModels(models::M,
dimension_adaptor::Function) where {
function ConcatLatentModels(models::M, dimension_adaptor::Function;
prefixes = nothing) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
return ConcatLatentModels(models, length(models), dimension_adaptor)
no_models = length(models)
if isnothing(prefixes)
prefixes = "Concat." .* string.(1:no_models)
end
return ConcatLatentModels(models, no_models, dimension_adaptor, prefixes)
end

function ConcatLatentModels(models::M;
dimension_adaptor::Function = equal_dimensions) where {
dimension_adaptor::Function = equal_dimensions,
prefixes = nothing) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
return ConcatLatentModels(models, dimension_adaptor)
return ConcatLatentModels(models, dimension_adaptor; prefixes = prefixes)
end

function ConcatLatentModels(; models::M,
dimension_adaptor::Function = equal_dimensions) where {
dimension_adaptor::Function = equal_dimensions, prefixes = nothing) where {
M <: AbstractVector{<:AbstractTuringLatentModel}}
return ConcatLatentModels(models, dimension_adaptor)
return ConcatLatentModels(models, dimension_adaptor; prefixes = prefixes)
end
end

Expand Down Expand Up @@ -102,7 +119,8 @@ Generate latent variables by concatenating multiple latent models.
end

@model function _concat_latents(
models, index::Int, acc_latent, acc_aux, dims::AbstractVector{<:Int}, n_models::Int)
models, index::Int, acc_latent, acc_aux,
dims::AbstractVector{<:Int}, n_models::Int)
if index > n_models
return acc_latent, (; acc_aux...)
else
Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ terms ``Z_1, \ldots, Z_d`` are inferred.

## Constructors

- `DiffLatentModel(latentmodel, init_prior_distribution::Distribution; d::Int)`
Constructs a `DiffLatentModel` for `d`-fold differencing with `latentmodel` as the
- `DiffLatentModel(latent_model, init_prior_distribution::Distribution; d::Int)`
Constructs a `DiffLatentModel` for `d`-fold differencing with `latent_model` as the
undifferenced latent process. All initial terms have common prior
`init_prior_distribution`.
- `DiffLatentModel(;model, init_priors::Vector{D} where {D <: Distribution})`
Constructs a `DiffLatentModel` for `d`-fold differencing with `latentmodel` as the
Constructs a `DiffLatentModel` for `d`-fold differencing with `latent_model` as the
undifferenced latent process. The `d` initial terms have priors given by the vector
`init_priors`, therefore `length(init_priors)` sets `d`.

Expand Down
28 changes: 28 additions & 0 deletions EpiAware/src/EpiLatentModels/modifiers/PrefixLatentModel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@doc raw"
Generate a latent model with a prefix. A lightweight wrapper around `EpiAwareUtils.prefix_submodel`.

# Constructors
- `PrefixLatentModel(model::M, prefix::P)`: Create a `PrefixLatentModel` with the latent model `model` and the prefix `prefix`.
- `PrefixLatentModel(; model::M, prefix::P)`: Create a `PrefixLatentModel` with the latent model `model` and the prefix `prefix`.

# Examples
```julia
using EpiAware
latent_model = PrefixLatentModel(model = HierarchicalNormal(), prefix = \"Test\")
mdl = generate_latent(latent_model, 10)
rand(mdl)
```
"
@kwdef struct PrefixLatentModel{M <: AbstractTuringLatentModel, P <: String} <:
AbstractTuringLatentModel
"The latent model"
model::M
"The prefix for the latent model"
prefix::P
end

@model function EpiAwareBase.generate_latent(latent_model::PrefixLatentModel, n)
@submodel submodel = prefix_submodel(
latent_model.model, generate_latent, latent_model.prefix, n)
return submodel
end
16 changes: 10 additions & 6 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ module EpiObsModels

using ..EpiAwareBase

using ..EpiAwareUtils: censored_pmf, HalfNormal
using ..EpiAwareUtils: censored_pmf, HalfNormal, prefix_submodel

using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek
using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek, PrefixLatentModel

using Turing, Distributions, DocStringExtensions, SparseArrays

Expand All @@ -18,15 +18,19 @@ export PoissonError, NegativeBinomialError
export generate_observation_error_priors, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, StackObservationModels
export LatentDelay, Ascertainment, PrefixObservationModel

# Observation model manipulators
export StackObservationModels

# helper functions
export ascertainment_dayofweek

include("docstrings.jl")
include("LatentDelay.jl")
include("ascertainment/Ascertainment.jl")
include("ascertainment/helpers.jl")
include("modifiers/LatentDelay.jl")
include("modifiers/ascertainment/Ascertainment.jl")
include("modifiers/ascertainment/helpers.jl")
include("modifiers/PrefixObservationModel.jl")
include("StackObservationModels.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
Expand Down
9 changes: 6 additions & 3 deletions EpiAware/src/EpiObsModels/StackObservationModels.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@doc raw"

A stack of observation models that are looped over to generate observations for
each model in the stack. Note that the model names are used to prefix the parameters in each model (so if I have a model named `cases` and a parameter `y_t`, the parameter in the model will be `cases.y_t`).
each model in the stack. Note that the model names are used to prefix the parameters in each model (so if I have a model named `cases` and a parameter `y_t`, the parameter in the model will be `cases.y_t`). Inside the constructor `PrefixObservationModel` is wrapped around each observation model.

## Constructors

Expand Down Expand Up @@ -48,7 +48,10 @@ deaths_y_t
N <: AbstractString
}
@assert length(models)==length(model_names) "The number of models and model names must be equal."
new{typeof(models), typeof(model_names)}(models, model_names)
wrapped_models = [PrefixObservationModel(models[i], model_names[i])
for i in eachindex(models)]
new{AbstractVector{<:AbstractTuringObservationModel}, typeof(model_names)}(
wrapped_models, model_names)
end

function StackObservationModels(models::NamedTuple{
Expand Down Expand Up @@ -77,7 +80,7 @@ Generate observations from a stack of observation models. Assumes a 1 to 1 mappi

obs = ()
for (model, model_name) in zip(obs_model.models, obs_model.model_names)
@submodel prefix=eval(model_name) obs_tmp=generate_observations(
@submodel obs_tmp = generate_observations(
model, y_t[Symbol(model_name)], Y_t[Symbol(model_name)])
obs = obs..., obs_tmp...
seabbs marked this conversation as resolved.
Show resolved Hide resolved
end
Expand Down
Loading
Loading