-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add intercept struct and method * add a basic test * add more tests for intercept * add more tests for intercept * fix spelling * rename broadcast helpers * first pass at CombineeLatentModels * spin out TransformLatentModels * fix TransformLatentModel * add tests for TransformLatentModel and fix Intercept tests for new output structure * add basic tests for CombineLatentModels * add FixedIntercpt version of Intercept * fix FixedIntecept tests and drop uses of scale where we can * fix ascertainment test * add HierarchicalNormal * add abstract method for intercepts * Add intercept struct and method * add a basic test * add more tests for intercept * add more tests for intercept * fix spelling * rename broadcast helpers * first pass at CombineeLatentModels * spin out TransformLatentModels * fix TransformLatentModel * add tests for TransformLatentModel and fix Intercept tests for new output structure * add basic tests for CombineLatentModels * add FixedIntercpt version of Intercept * fix FixedIntecept tests and drop uses of scale where we can * fix ascertainment test * add HierarchicalNormal * add abstract method for intercepts * Test conditioning on CombineLatentModels and sampling using NUTS * modify test to do KS test against theoretical posterior distribution --------- Co-authored-by: Samuel Brand <[email protected]>
- Loading branch information
1 parent
4ab6cc2
commit 4262a9b
Showing
18 changed files
with
454 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
@doc raw" | ||
The `CombineLatentModels` struct. | ||
This struct is used to combine multiple latent models into a single latent model. | ||
# Constructors | ||
- `CombineLatentModels(models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models. | ||
- `CombineLatentModels(; models::M) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `CombineLatentModels` instance with specified models, ensuring that there are at least two models. | ||
# Examples | ||
```julia | ||
using EpiAware, Distributions | ||
combined_model = CombineLatentModels([Intercept(Normal(2, 0.2)), AR()]) | ||
latent_model = generate_latent(combined_model, 10) | ||
latent_model() | ||
``` | ||
" | ||
@kwdef struct CombineLatentModels{M <: AbstractVector{<:AbstractTuringLatentModel}} <: | ||
AbstractTuringLatentModel | ||
"A vector of latent models" | ||
models::M | ||
|
||
function CombineLatentModels(models::M) where {M <: | ||
AbstractVector{<:AbstractTuringLatentModel}} | ||
@assert length(models)>1 "At least two models are required" | ||
return new{AbstractVector{<:AbstractTuringLatentModel}}(models) | ||
end | ||
end | ||
|
||
@doc raw" | ||
Generate latent variables using a combination of multiple latent models. | ||
# Arguments | ||
- `latent_models::CombineLatentModels`: An instance of the `CombineLatentModels` type representing the collection of latent models. | ||
- `n`: The number of latent variables to generate. | ||
# Returns | ||
- `combined_latents`: The combined latent variables generated from all the models. | ||
- `latent_aux`: A tuple containing the auxiliary latent variables generated from each individual model. | ||
# Example | ||
" | ||
@model function EpiAwareBase.generate_latent(latent_models::CombineLatentModels, n) | ||
@submodel final_latent, latent_aux = _accumulate_latents( | ||
latent_models.models, 1, fill(0.0, n), [], n, length(latent_models.models)) | ||
|
||
return final_latent, (; latent_aux...) | ||
end | ||
|
||
@model function _accumulate_latents(models, index, acc_latent, acc_aux, n, n_models) | ||
if index > n_models | ||
return acc_latent, (; acc_aux...) | ||
else | ||
@submodel latent, new_aux = generate_latent(models[index], n) | ||
@submodel updated_latent, updated_aux = _accumulate_latents( | ||
models, index + 1, acc_latent .+ latent, | ||
(; acc_aux..., new_aux...), n, n_models) | ||
return updated_latent, (; updated_aux...) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
@doc raw" | ||
The `HierarchicalNormal` struct represents a non-centered hierarchical normal distribution. | ||
## Constructors | ||
- `HierarchicalNormal(mean, std_prior)`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior. | ||
- `HierarchicalNormal(; mean = 0.0, std_prior = truncated(Normal(0,1), 0, Inf))`: Constructs a `HierarchicalNormal` instance with the specified mean and standard deviation prior using named arguments and with default values. | ||
## Examples | ||
```julia | ||
using Distributions, EpiAware | ||
hnorm = HierarchicalNormal(0.0, truncated(Normal(0, 1), 0, Inf)) | ||
hnorm_model = generate_latent(hnorm, 10) | ||
hnorm_model() | ||
``` | ||
" | ||
@kwdef struct HierarchicalNormal{R <: Real, D <: Sampleable} <: AbstractTuringLatentModel | ||
mean::R = 0.0 | ||
std_prior::D = truncated(Normal(0, 1), 0, Inf) | ||
end | ||
|
||
@doc raw" | ||
function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n) | ||
Generate latent variables from the hierarchical normal distribution. | ||
# Arguments | ||
- `obs_model::HierarchicalNormal`: The hierarchical normal distribution model. | ||
- `n`: Number of latent variables to generate. | ||
# Returns | ||
- `η_t`: Generated latent variables. | ||
- `std`: Standard deviation used in the generation. | ||
" | ||
@model function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n) | ||
std ~ obs_model.std_prior | ||
ϵ_t ~ MvNormal(I(n)) | ||
η_t = obs_model.mean .+ std .* ϵ_t | ||
return η_t, (; std = std) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
@doc raw" | ||
The `Intercept` struct is used to model the intercept of a latent process. It | ||
broadcasts a single intercept value to a length `n` latent process. | ||
## Constructors | ||
- Intercept(intercept_prior) | ||
- Intercept(; intercept_prior) | ||
## Examples | ||
```julia | ||
using Distributions, Turing, EpiAware | ||
int = Intercept(Normal(0, 1)) | ||
int_model = generate_latent(int, 10) | ||
rand(int_model) | ||
int_model() | ||
``` | ||
" | ||
@kwdef struct Intercept{D <: Sampleable} <: AbstractTuringIntercept | ||
"Prior distribution for the intercept." | ||
intercept_prior::D | ||
end | ||
|
||
@doc raw" | ||
Generate a latent intercept series. | ||
# Arguments | ||
- `latent_model::Intercept`: The intercept model. | ||
- `n::Int`: The length of the intercept series. | ||
# Returns | ||
- `intercept::Vector{Float64}`: The generated intercept series. | ||
- `metadata::NamedTuple`: A named tuple containing the intercept value. | ||
" | ||
@model function EpiAwareBase.generate_latent(latent_model::Intercept, n) | ||
intercept ~ latent_model.intercept_prior | ||
return fill(intercept, n), (; intercept = intercept) | ||
end | ||
|
||
@doc raw" | ||
A variant of the `Intercept` struct that represents a fixed intercept value for a latent model. | ||
# Constructors | ||
- `FixedIntercept(intercept)` : Constructs a `FixedIntercept` instance with the specified intercept value. | ||
- `FixedIntercept(; intercept)` : Constructs a `FixedIntercept` instance with the specified intercept value using named arguments. | ||
# Examples | ||
```julia | ||
using EpiAware | ||
fi = FixedIntercept(2.0) | ||
fi_model = generate_latent(fi, 10) | ||
fi_model() | ||
``` | ||
" | ||
@kwdef struct FixedIntercept{F <: Real} <: AbstractTuringIntercept | ||
intercept::F | ||
end | ||
|
||
@doc raw" | ||
Generate a latent intercept series with a fixed intercept value. | ||
# Arguments | ||
- `latent_model::FixedIntercept`: The fixed intercept latent model. | ||
- `n`: The number of latent variables to generate. | ||
# Returns | ||
- `latent_vars`: An array of length `n` filled with the fixed intercept value. | ||
- `metadata`: A named tuple containing the intercept value. | ||
" | ||
@model function EpiAwareBase.generate_latent(latent_model::FixedIntercept, n) | ||
return fill(latent_model.intercept, n), (; intercept = latent_model.intercept) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
@doc raw" | ||
The `TransformLatentModel` struct represents a latent model that applies a transformation function to the latent variables generated by another latent model. | ||
## Constructors | ||
- `TransformLatentModel(model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function. | ||
- `TransformLatentModel(; model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function using named arguments. | ||
## Example | ||
```julia | ||
using EpiAware, Distributions | ||
trans = TransformLatentModel(Intercept(Normal(2, 0.2)), x -> x .|> exp) | ||
trans_model = generate_latent(trans, 5) | ||
trans_model() | ||
``` | ||
" | ||
@kwdef struct TransformLatentModel{M <: AbstractTuringLatentModel, F <: Function} <: | ||
AbstractTuringLatentModel | ||
"The latent model to transform." | ||
model::M | ||
"The transformation function." | ||
trans_function::F | ||
end | ||
|
||
""" | ||
generate_latent(model::TransformLatentModel, n) | ||
Generate latent variables using the specified `TransformLatentModel`. | ||
# Arguments | ||
- `model::TransformLatentModel`: The `TransformLatentModel` to generate latent variables from. | ||
- `n`: The number of latent variables to generate. | ||
# Returns | ||
- `transformed`: The transformed latent variables. | ||
- `latent_aux`: Additional auxiliary variables generated by the underlying latent model. | ||
""" | ||
@model function EpiAwareBase.generate_latent(model::TransformLatentModel, n) | ||
@submodel untransformed, latent_aux = generate_latent(model.model, n) | ||
latent = model.trans_function(untransformed) | ||
return latent, (; latent_aux) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.