From 1042c12b8b42c2f14ed5058bb046a68111b0791e Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 15:25:36 +0100 Subject: [PATCH 01/17] add to package and reorg --- .../src/EpiLatentModels/EpiLatentModels.jl | 29 +++--- .../{ => manipulators}/CombineLatentModels.jl | 0 .../manipulators/ConcatLatentModels.jl | 96 +++++++++++++++++++ .../broadcast/LatentModel.jl | 0 .../{ => manipulators}/broadcast/helpers.jl | 0 .../{ => manipulators}/broadcast/rules.jl | 0 .../src/EpiLatentModels/{ => models}/AR.jl | 0 .../{ => models}/HierarchicalNormal.jl | 0 .../EpiLatentModels/{ => models}/Intercept.jl | 0 .../{ => models}/RandomWalk.jl | 0 .../{ => modifiers}/DiffLatentModel.jl | 0 .../{ => modifiers}/TransformLatentModel.jl | 0 .../{ => manipulators}/CombineLatentModels.jl | 0 .../manipulators/ConcatLatentModels.jl | 1 + .../test/EpiLatentModels/{ => models}/AR.jl | 0 .../{ => models}/FixedIntercept.jl | 0 .../{ => models}/HierarchicalNormal.jl | 0 .../EpiLatentModels/{ => models}/Intercept.jl | 0 .../{ => models}/RandomWalk.jl | 0 .../{ => modifiers}/DiffLatentModel.jl | 0 .../{ => modifiers}/TransformLatentModel.jl | 0 21 files changed, 114 insertions(+), 12 deletions(-) rename EpiAware/src/EpiLatentModels/{ => manipulators}/CombineLatentModels.jl (100%) create mode 100644 EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl rename EpiAware/src/EpiLatentModels/{ => manipulators}/broadcast/LatentModel.jl (100%) rename EpiAware/src/EpiLatentModels/{ => manipulators}/broadcast/helpers.jl (100%) rename EpiAware/src/EpiLatentModels/{ => manipulators}/broadcast/rules.jl (100%) rename EpiAware/src/EpiLatentModels/{ => models}/AR.jl (100%) rename EpiAware/src/EpiLatentModels/{ => models}/HierarchicalNormal.jl (100%) rename EpiAware/src/EpiLatentModels/{ => models}/Intercept.jl (100%) rename EpiAware/src/EpiLatentModels/{ => models}/RandomWalk.jl (100%) rename EpiAware/src/EpiLatentModels/{ => modifiers}/DiffLatentModel.jl (100%) rename EpiAware/src/EpiLatentModels/{ => modifiers}/TransformLatentModel.jl (100%) rename EpiAware/test/EpiLatentModels/{ => manipulators}/CombineLatentModels.jl (100%) create mode 100644 EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl rename EpiAware/test/EpiLatentModels/{ => models}/AR.jl (100%) rename EpiAware/test/EpiLatentModels/{ => models}/FixedIntercept.jl (100%) rename EpiAware/test/EpiLatentModels/{ => models}/HierarchicalNormal.jl (100%) rename EpiAware/test/EpiLatentModels/{ => models}/Intercept.jl (100%) rename EpiAware/test/EpiLatentModels/{ => models}/RandomWalk.jl (100%) rename EpiAware/test/EpiLatentModels/{ => modifiers}/DiffLatentModel.jl (100%) rename EpiAware/test/EpiLatentModels/{ => modifiers}/TransformLatentModel.jl (100%) diff --git a/EpiAware/src/EpiLatentModels/EpiLatentModels.jl b/EpiAware/src/EpiLatentModels/EpiLatentModels.jl index bd45990a5..a3e71764b 100644 --- a/EpiAware/src/EpiLatentModels/EpiLatentModels.jl +++ b/EpiAware/src/EpiLatentModels/EpiLatentModels.jl @@ -17,25 +17,30 @@ using Turing, Distributions, DocStringExtensions, LinearAlgebra export FixedIntercept, Intercept, RandomWalk, AR, HierarchicalNormal # Export tools for manipulating latent models -export CombineLatentModels, TransformLatentModel, DiffLatentModel, BroadcastLatentModel +export CombineLatentModels, ConcatLatentModels, BroadcastLatentModel # Export broadcast rules export RepeatEach, RepeatBlock # Export helper functions -export broadcast_dayofweek, broadcast_weekly +export broadcast_dayofweek, broadcast_weekly, equal_dimensions + +# Export tools for modifying latent models +export DiffLatentModel, TransformLatentModel include("docstrings.jl") -include("Intercept.jl") -include("RandomWalk.jl") -include("AR.jl") -include("HierarchicalNormal.jl") -include("CombineLatentModels.jl") -include("TransformLatentModel.jl") -include("DiffLatentModel.jl") -include("broadcast/LatentModel.jl") -include("broadcast/rules.jl") -include("broadcast/helpers.jl") +include("models/Intercept.jl") +include("models/RandomWalk.jl") +include("models/AR.jl") +include("models/HierarchicalNormal.jl") +include("modifiers/DiffLatentModel.jl") +include("modifiers/TransformLatentModel.jl") +include("manipulators/CombineLatentModels.jl") +include("manipulators/ConcatLatentModels.jl") + +include("manipulators/broadcast/LatentModel.jl") +include("manipulators/broadcast/rules.jl") +include("manipulators/broadcast/helpers.jl") include("utils.jl") end diff --git a/EpiAware/src/EpiLatentModels/CombineLatentModels.jl b/EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/CombineLatentModels.jl rename to EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl diff --git a/EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl b/EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl new file mode 100644 index 000000000..c4ec2db55 --- /dev/null +++ b/EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl @@ -0,0 +1,96 @@ +@doc raw" +The `ConcatLatentModels` struct. + +This struct is used to concatenate multiple latent models into a single latent model. + +# Constructors + +- `ConcatLatentModels(models::M, no_models::Int, dimension_adapter::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, and dimension adapter. The default dimension adapter is `equal_dimensions`. +- `ConcatLatentModels(models::M, dimension_adapter::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adapter, ensuring that there are at least two models. The default dimension adapter is `equal_dimensions`. +- `ConcatLatentModels(; models::M, dimension_adapter::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adapter, ensuring that there are at least two models. The default dimension adapter is `equal_dimensions`. + +# Examples + +```julia +using EpiAware, Distributions +combined_model = ConcatLatentModels([Intercept(Normal(2, 0.2)), AR()]) +latent_model = generate_latent(combined_model, 10) +latent_model() +" +struct ConcatLatentModels{ + M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function} <: + AbstractTuringLatentModel + "A vector of latent models" + models::M + "The number of models in the collection" + no_models::I + "The dimension function for the latent variables. By default this divides the number of latent variables by the number of models and returns a vector of dimensions rounding up the first element and rounding down the rest." + dimension_adaptor::F = equal_dimensions + + function CombineLatentModels(models::M, + no_models::Int, + dimension_adaptor::F) where { + M <: + AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function} + @assert length(models)>1 "At least two models are required" + @assert length(models)==no_models "no_models must be equal to the number of models" + # check all dimension functions take a single n and return an integer + check_dim = dimension_adaptor(no_models, no_models) + @assert all(isinteger, check_dim) + @assert all(x -> x > 0, check_dim) + @assert sum(check_dim) == no_models + @assert length(check_dim) == no_models + return new{AbstractVector{<:AbstractTuringLatentModel}, Int, Function}( + models, no_models, dimension_adaptor) + end + + function ConcatLatentModels(models::M, + dimension_adapter::Function = equal_dimensions) where {M <: + AbstractVector{<:AbstractTuringLatentModel}} + no_models = length(models) + return ConcatLatentModels(models, length(models), dimension_adapter) + end + + function ConcatLatentModels(; models::M, + dimension_adapter::Function = equal_dimensions) where {M <: + AbstractVector{<:AbstractTuringLatentModel}} + return ConcatLatentModels(models, dimension_adapter) + end +end + +function equal_dimensions(n::Int, m::Int) + return vcat(ceil(n / m), fill(floor(n / m), m - 1)) +end + +@doc raw" +Generate latent variables by concatenating multiple latent models. + +# Arguments +- `latent_models::ConcatLatentModels`: An instance of the `ConcatLatentModels` type representing the collection of latent models. +- `n`: The number of latent variables to generate. + +# Returns +- `concatenated_latents`: The combined latent variables generated from all the models. +- `latent_aux`: A tuple containing the auxiliary latent variables generated from each individual model. +" +@model function EpiAwareBase.generate_latent(latent_models::ConcatLatentModels, n) + @assert latent_models.no_models n_models + return acc_latent, (; acc_aux...) + else + @submodel latent, new_aux = generate_latent(models[index], dims[index]) + @submodel updated_latent, updated_aux = _concat_latents( + models, index + 1, vcat(acc_latent, latent), + (; acc_aux..., new_aux...), dims, n_models) + return updated_latent, (; updated_aux...) + end +end diff --git a/EpiAware/src/EpiLatentModels/broadcast/LatentModel.jl b/EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/broadcast/LatentModel.jl rename to EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl diff --git a/EpiAware/src/EpiLatentModels/broadcast/helpers.jl b/EpiAware/src/EpiLatentModels/manipulators/broadcast/helpers.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/broadcast/helpers.jl rename to EpiAware/src/EpiLatentModels/manipulators/broadcast/helpers.jl diff --git a/EpiAware/src/EpiLatentModels/broadcast/rules.jl b/EpiAware/src/EpiLatentModels/manipulators/broadcast/rules.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/broadcast/rules.jl rename to EpiAware/src/EpiLatentModels/manipulators/broadcast/rules.jl diff --git a/EpiAware/src/EpiLatentModels/AR.jl b/EpiAware/src/EpiLatentModels/models/AR.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/AR.jl rename to EpiAware/src/EpiLatentModels/models/AR.jl diff --git a/EpiAware/src/EpiLatentModels/HierarchicalNormal.jl b/EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/HierarchicalNormal.jl rename to EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl diff --git a/EpiAware/src/EpiLatentModels/Intercept.jl b/EpiAware/src/EpiLatentModels/models/Intercept.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/Intercept.jl rename to EpiAware/src/EpiLatentModels/models/Intercept.jl diff --git a/EpiAware/src/EpiLatentModels/RandomWalk.jl b/EpiAware/src/EpiLatentModels/models/RandomWalk.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/RandomWalk.jl rename to EpiAware/src/EpiLatentModels/models/RandomWalk.jl diff --git a/EpiAware/src/EpiLatentModels/DiffLatentModel.jl b/EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/DiffLatentModel.jl rename to EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl diff --git a/EpiAware/src/EpiLatentModels/TransformLatentModel.jl b/EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/TransformLatentModel.jl rename to EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl diff --git a/EpiAware/test/EpiLatentModels/CombineLatentModels.jl b/EpiAware/test/EpiLatentModels/manipulators/CombineLatentModels.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/CombineLatentModels.jl rename to EpiAware/test/EpiLatentModels/manipulators/CombineLatentModels.jl diff --git a/EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl b/EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl new file mode 100644 index 000000000..ae3e92f35 --- /dev/null +++ b/EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl @@ -0,0 +1 @@ +#using TestEnv; TestEnv.activate() diff --git a/EpiAware/test/EpiLatentModels/AR.jl b/EpiAware/test/EpiLatentModels/models/AR.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/AR.jl rename to EpiAware/test/EpiLatentModels/models/AR.jl diff --git a/EpiAware/test/EpiLatentModels/FixedIntercept.jl b/EpiAware/test/EpiLatentModels/models/FixedIntercept.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/FixedIntercept.jl rename to EpiAware/test/EpiLatentModels/models/FixedIntercept.jl diff --git a/EpiAware/test/EpiLatentModels/HierarchicalNormal.jl b/EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/HierarchicalNormal.jl rename to EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl diff --git a/EpiAware/test/EpiLatentModels/Intercept.jl b/EpiAware/test/EpiLatentModels/models/Intercept.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/Intercept.jl rename to EpiAware/test/EpiLatentModels/models/Intercept.jl diff --git a/EpiAware/test/EpiLatentModels/RandomWalk.jl b/EpiAware/test/EpiLatentModels/models/RandomWalk.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/RandomWalk.jl rename to EpiAware/test/EpiLatentModels/models/RandomWalk.jl diff --git a/EpiAware/test/EpiLatentModels/DiffLatentModel.jl b/EpiAware/test/EpiLatentModels/modifiers/DiffLatentModel.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/DiffLatentModel.jl rename to EpiAware/test/EpiLatentModels/modifiers/DiffLatentModel.jl diff --git a/EpiAware/test/EpiLatentModels/TransformLatentModel.jl b/EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/TransformLatentModel.jl rename to EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl From 3c9e97c6bd571256250d5b41a6b917443a2a998d Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 15:27:54 +0100 Subject: [PATCH 02/17] Revert "add to package and reorg" This reverts commit 31b89d39050b3d244e5be3ddd4c7363b8af64809. --- .../src/EpiLatentModels/{models => }/AR.jl | 0 .../{manipulators => }/CombineLatentModels.jl | 0 .../{modifiers => }/DiffLatentModel.jl | 0 .../src/EpiLatentModels/EpiLatentModels.jl | 29 +++--- .../{models => }/HierarchicalNormal.jl | 0 .../EpiLatentModels/{models => }/Intercept.jl | 0 .../{models => }/RandomWalk.jl | 0 .../{modifiers => }/TransformLatentModel.jl | 0 .../broadcast/LatentModel.jl | 0 .../{manipulators => }/broadcast/helpers.jl | 0 .../{manipulators => }/broadcast/rules.jl | 0 .../manipulators/ConcatLatentModels.jl | 96 ------------------- .../test/EpiLatentModels/{models => }/AR.jl | 0 .../{manipulators => }/CombineLatentModels.jl | 0 .../{modifiers => }/DiffLatentModel.jl | 0 .../{models => }/FixedIntercept.jl | 0 .../{models => }/HierarchicalNormal.jl | 0 .../EpiLatentModels/{models => }/Intercept.jl | 0 .../{models => }/RandomWalk.jl | 0 .../{modifiers => }/TransformLatentModel.jl | 0 .../manipulators/ConcatLatentModels.jl | 1 - 21 files changed, 12 insertions(+), 114 deletions(-) rename EpiAware/src/EpiLatentModels/{models => }/AR.jl (100%) rename EpiAware/src/EpiLatentModels/{manipulators => }/CombineLatentModels.jl (100%) rename EpiAware/src/EpiLatentModels/{modifiers => }/DiffLatentModel.jl (100%) rename EpiAware/src/EpiLatentModels/{models => }/HierarchicalNormal.jl (100%) rename EpiAware/src/EpiLatentModels/{models => }/Intercept.jl (100%) rename EpiAware/src/EpiLatentModels/{models => }/RandomWalk.jl (100%) rename EpiAware/src/EpiLatentModels/{modifiers => }/TransformLatentModel.jl (100%) rename EpiAware/src/EpiLatentModels/{manipulators => }/broadcast/LatentModel.jl (100%) rename EpiAware/src/EpiLatentModels/{manipulators => }/broadcast/helpers.jl (100%) rename EpiAware/src/EpiLatentModels/{manipulators => }/broadcast/rules.jl (100%) delete mode 100644 EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl rename EpiAware/test/EpiLatentModels/{models => }/AR.jl (100%) rename EpiAware/test/EpiLatentModels/{manipulators => }/CombineLatentModels.jl (100%) rename EpiAware/test/EpiLatentModels/{modifiers => }/DiffLatentModel.jl (100%) rename EpiAware/test/EpiLatentModels/{models => }/FixedIntercept.jl (100%) rename EpiAware/test/EpiLatentModels/{models => }/HierarchicalNormal.jl (100%) rename EpiAware/test/EpiLatentModels/{models => }/Intercept.jl (100%) rename EpiAware/test/EpiLatentModels/{models => }/RandomWalk.jl (100%) rename EpiAware/test/EpiLatentModels/{modifiers => }/TransformLatentModel.jl (100%) delete mode 100644 EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl diff --git a/EpiAware/src/EpiLatentModels/models/AR.jl b/EpiAware/src/EpiLatentModels/AR.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/models/AR.jl rename to EpiAware/src/EpiLatentModels/AR.jl diff --git a/EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl b/EpiAware/src/EpiLatentModels/CombineLatentModels.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl rename to EpiAware/src/EpiLatentModels/CombineLatentModels.jl diff --git a/EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl b/EpiAware/src/EpiLatentModels/DiffLatentModel.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl rename to EpiAware/src/EpiLatentModels/DiffLatentModel.jl diff --git a/EpiAware/src/EpiLatentModels/EpiLatentModels.jl b/EpiAware/src/EpiLatentModels/EpiLatentModels.jl index a3e71764b..bd45990a5 100644 --- a/EpiAware/src/EpiLatentModels/EpiLatentModels.jl +++ b/EpiAware/src/EpiLatentModels/EpiLatentModels.jl @@ -17,30 +17,25 @@ using Turing, Distributions, DocStringExtensions, LinearAlgebra export FixedIntercept, Intercept, RandomWalk, AR, HierarchicalNormal # Export tools for manipulating latent models -export CombineLatentModels, ConcatLatentModels, BroadcastLatentModel +export CombineLatentModels, TransformLatentModel, DiffLatentModel, BroadcastLatentModel # Export broadcast rules export RepeatEach, RepeatBlock # Export helper functions -export broadcast_dayofweek, broadcast_weekly, equal_dimensions - -# Export tools for modifying latent models -export DiffLatentModel, TransformLatentModel +export broadcast_dayofweek, broadcast_weekly include("docstrings.jl") -include("models/Intercept.jl") -include("models/RandomWalk.jl") -include("models/AR.jl") -include("models/HierarchicalNormal.jl") -include("modifiers/DiffLatentModel.jl") -include("modifiers/TransformLatentModel.jl") -include("manipulators/CombineLatentModels.jl") -include("manipulators/ConcatLatentModels.jl") - -include("manipulators/broadcast/LatentModel.jl") -include("manipulators/broadcast/rules.jl") -include("manipulators/broadcast/helpers.jl") +include("Intercept.jl") +include("RandomWalk.jl") +include("AR.jl") +include("HierarchicalNormal.jl") +include("CombineLatentModels.jl") +include("TransformLatentModel.jl") +include("DiffLatentModel.jl") +include("broadcast/LatentModel.jl") +include("broadcast/rules.jl") +include("broadcast/helpers.jl") include("utils.jl") end diff --git a/EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl b/EpiAware/src/EpiLatentModels/HierarchicalNormal.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl rename to EpiAware/src/EpiLatentModels/HierarchicalNormal.jl diff --git a/EpiAware/src/EpiLatentModels/models/Intercept.jl b/EpiAware/src/EpiLatentModels/Intercept.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/models/Intercept.jl rename to EpiAware/src/EpiLatentModels/Intercept.jl diff --git a/EpiAware/src/EpiLatentModels/models/RandomWalk.jl b/EpiAware/src/EpiLatentModels/RandomWalk.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/models/RandomWalk.jl rename to EpiAware/src/EpiLatentModels/RandomWalk.jl diff --git a/EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl b/EpiAware/src/EpiLatentModels/TransformLatentModel.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl rename to EpiAware/src/EpiLatentModels/TransformLatentModel.jl diff --git a/EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl b/EpiAware/src/EpiLatentModels/broadcast/LatentModel.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/manipulators/broadcast/LatentModel.jl rename to EpiAware/src/EpiLatentModels/broadcast/LatentModel.jl diff --git a/EpiAware/src/EpiLatentModels/manipulators/broadcast/helpers.jl b/EpiAware/src/EpiLatentModels/broadcast/helpers.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/manipulators/broadcast/helpers.jl rename to EpiAware/src/EpiLatentModels/broadcast/helpers.jl diff --git a/EpiAware/src/EpiLatentModels/manipulators/broadcast/rules.jl b/EpiAware/src/EpiLatentModels/broadcast/rules.jl similarity index 100% rename from EpiAware/src/EpiLatentModels/manipulators/broadcast/rules.jl rename to EpiAware/src/EpiLatentModels/broadcast/rules.jl diff --git a/EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl b/EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl deleted file mode 100644 index c4ec2db55..000000000 --- a/EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl +++ /dev/null @@ -1,96 +0,0 @@ -@doc raw" -The `ConcatLatentModels` struct. - -This struct is used to concatenate multiple latent models into a single latent model. - -# Constructors - -- `ConcatLatentModels(models::M, no_models::Int, dimension_adapter::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models, number of models, and dimension adapter. The default dimension adapter is `equal_dimensions`. -- `ConcatLatentModels(models::M, dimension_adapter::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adapter, ensuring that there are at least two models. The default dimension adapter is `equal_dimensions`. -- `ConcatLatentModels(; models::M, dimension_adapter::Function) where {M <: AbstractVector{<:AbstractTuringLatentModel}}`: Constructs a `ConcatLatentModels` instance with specified models and dimension adapter, ensuring that there are at least two models. The default dimension adapter is `equal_dimensions`. - -# Examples - -```julia -using EpiAware, Distributions -combined_model = ConcatLatentModels([Intercept(Normal(2, 0.2)), AR()]) -latent_model = generate_latent(combined_model, 10) -latent_model() -" -struct ConcatLatentModels{ - M <: AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function} <: - AbstractTuringLatentModel - "A vector of latent models" - models::M - "The number of models in the collection" - no_models::I - "The dimension function for the latent variables. By default this divides the number of latent variables by the number of models and returns a vector of dimensions rounding up the first element and rounding down the rest." - dimension_adaptor::F = equal_dimensions - - function CombineLatentModels(models::M, - no_models::Int, - dimension_adaptor::F) where { - M <: - AbstractVector{<:AbstractTuringLatentModel}, I <: Int, F <: Function} - @assert length(models)>1 "At least two models are required" - @assert length(models)==no_models "no_models must be equal to the number of models" - # check all dimension functions take a single n and return an integer - check_dim = dimension_adaptor(no_models, no_models) - @assert all(isinteger, check_dim) - @assert all(x -> x > 0, check_dim) - @assert sum(check_dim) == no_models - @assert length(check_dim) == no_models - return new{AbstractVector{<:AbstractTuringLatentModel}, Int, Function}( - models, no_models, dimension_adaptor) - end - - function ConcatLatentModels(models::M, - dimension_adapter::Function = equal_dimensions) where {M <: - AbstractVector{<:AbstractTuringLatentModel}} - no_models = length(models) - return ConcatLatentModels(models, length(models), dimension_adapter) - end - - function ConcatLatentModels(; models::M, - dimension_adapter::Function = equal_dimensions) where {M <: - AbstractVector{<:AbstractTuringLatentModel}} - return ConcatLatentModels(models, dimension_adapter) - end -end - -function equal_dimensions(n::Int, m::Int) - return vcat(ceil(n / m), fill(floor(n / m), m - 1)) -end - -@doc raw" -Generate latent variables by concatenating multiple latent models. - -# Arguments -- `latent_models::ConcatLatentModels`: An instance of the `ConcatLatentModels` type representing the collection of latent models. -- `n`: The number of latent variables to generate. - -# Returns -- `concatenated_latents`: The combined latent variables generated from all the models. -- `latent_aux`: A tuple containing the auxiliary latent variables generated from each individual model. -" -@model function EpiAwareBase.generate_latent(latent_models::ConcatLatentModels, n) - @assert latent_models.no_models n_models - return acc_latent, (; acc_aux...) - else - @submodel latent, new_aux = generate_latent(models[index], dims[index]) - @submodel updated_latent, updated_aux = _concat_latents( - models, index + 1, vcat(acc_latent, latent), - (; acc_aux..., new_aux...), dims, n_models) - return updated_latent, (; updated_aux...) - end -end diff --git a/EpiAware/test/EpiLatentModels/models/AR.jl b/EpiAware/test/EpiLatentModels/AR.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/models/AR.jl rename to EpiAware/test/EpiLatentModels/AR.jl diff --git a/EpiAware/test/EpiLatentModels/manipulators/CombineLatentModels.jl b/EpiAware/test/EpiLatentModels/CombineLatentModels.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/manipulators/CombineLatentModels.jl rename to EpiAware/test/EpiLatentModels/CombineLatentModels.jl diff --git a/EpiAware/test/EpiLatentModels/modifiers/DiffLatentModel.jl b/EpiAware/test/EpiLatentModels/DiffLatentModel.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/modifiers/DiffLatentModel.jl rename to EpiAware/test/EpiLatentModels/DiffLatentModel.jl diff --git a/EpiAware/test/EpiLatentModels/models/FixedIntercept.jl b/EpiAware/test/EpiLatentModels/FixedIntercept.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/models/FixedIntercept.jl rename to EpiAware/test/EpiLatentModels/FixedIntercept.jl diff --git a/EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl b/EpiAware/test/EpiLatentModels/HierarchicalNormal.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/models/HierarchicalNormal.jl rename to EpiAware/test/EpiLatentModels/HierarchicalNormal.jl diff --git a/EpiAware/test/EpiLatentModels/models/Intercept.jl b/EpiAware/test/EpiLatentModels/Intercept.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/models/Intercept.jl rename to EpiAware/test/EpiLatentModels/Intercept.jl diff --git a/EpiAware/test/EpiLatentModels/models/RandomWalk.jl b/EpiAware/test/EpiLatentModels/RandomWalk.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/models/RandomWalk.jl rename to EpiAware/test/EpiLatentModels/RandomWalk.jl diff --git a/EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl b/EpiAware/test/EpiLatentModels/TransformLatentModel.jl similarity index 100% rename from EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl rename to EpiAware/test/EpiLatentModels/TransformLatentModel.jl diff --git a/EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl b/EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl deleted file mode 100644 index ae3e92f35..000000000 --- a/EpiAware/test/EpiLatentModels/manipulators/ConcatLatentModels.jl +++ /dev/null @@ -1 +0,0 @@ -#using TestEnv; TestEnv.activate() From d47c1baeac8fa470fd365126475afca5b520a159 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 9 Jun 2024 23:52:03 +0100 Subject: [PATCH 03/17] add ObservationErrorModel type and functions --- .../AbstractTuringObservationModel.jl | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl new file mode 100644 index 000000000..f9a2d168c --- /dev/null +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -0,0 +1,30 @@ +abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationModel end + +@model function EpiAwareBase.generate_observations( + obs_model::AbstractTuringObservationErrorModel, + y_t, + Y_t) + @submodel priors = generate_obs_error_priors(obs_model, y_t, Y_t) + + if ismissing(y_t) + y_t = Vector{Int}(undef, length(Y_t)) + end + + Y_y = length(y_t) - length(Y_t) + + for i in eachindex(Y_t) + y_t[Y_y + i] ~ obs_error(obs_model, Y_t[i] + 1e-6) + end + + return y_t, priors +end + +@model function generate_obs_error_priors( + obs_model::AbstractTuringObservationErrorModel, y_t, Y_t) + return NamedTuple() +end + +function obs_error(obs_model::AbstractTuringObservationErrorModel, Y_t) + @info "No concrete implementation for `_apply_method` is defined." + return nothing +end From 2f890a49d5194f00bc29e9a3349a3644025a188f Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 00:06:35 +0100 Subject: [PATCH 04/17] move error models and make Poisson use new abstract type --- .../AbstractTuringObservationModel.jl | 6 ++-- .../NegativeBinomialError.jl | 0 .../PoissonError.jl | 29 +++---------------- 3 files changed, 7 insertions(+), 28 deletions(-) rename EpiAware/src/EpiObsModels/{ => ObservationErrorModels}/NegativeBinomialError.jl (100%) rename EpiAware/src/EpiObsModels/{ => ObservationErrorModels}/PoissonError.jl (56%) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index f9a2d168c..f443b1f71 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -4,7 +4,7 @@ abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationMo obs_model::AbstractTuringObservationErrorModel, y_t, Y_t) - @submodel priors = generate_obs_error_priors(obs_model, y_t, Y_t) + @submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t) if ismissing(y_t) y_t = Vector{Int}(undef, length(Y_t)) @@ -13,13 +13,13 @@ abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationMo Y_y = length(y_t) - length(Y_t) for i in eachindex(Y_t) - y_t[Y_y + i] ~ obs_error(obs_model, Y_t[i] + 1e-6) + y_t[Y_y + i] ~ obs_error(obs_model, Y_t[i]) end return y_t, priors end -@model function generate_obs_error_priors( +@model function generate_observation_error_priors( obs_model::AbstractTuringObservationErrorModel, y_t, Y_t) return NamedTuple() end diff --git a/EpiAware/src/EpiObsModels/NegativeBinomialError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl similarity index 100% rename from EpiAware/src/EpiObsModels/NegativeBinomialError.jl rename to EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl diff --git a/EpiAware/src/EpiObsModels/PoissonError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl similarity index 56% rename from EpiAware/src/EpiObsModels/PoissonError.jl rename to EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl index 3a760fb43..5157f7c57 100644 --- a/EpiAware/src/EpiObsModels/PoissonError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -1,6 +1,6 @@ @doc raw" The `PoissonError` struct represents an observation model for Poisson errors. It -is a subtype of `AbstractTuringObservationModel`. Note that +is a subtype of `AbstractTuringObservationErrorModel`. Note that when Y_t is shorter than y_t, then the first `length(y_t) - length(Y_t)` elements of y_t are assumed to be missing. ## Constructors @@ -15,7 +15,7 @@ poi_model = generate_observations(poi, missing, fill(10, 10)) rand(poi_model) ``` " -struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationModel +struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel "The positive shift value." pos_shift::T @@ -25,27 +25,6 @@ struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationModel end end -@doc raw" -Generate observations using the `PoissonError` observation model. - -# Arguments -- `obs_model::PoissonError`: The observation model. -- `y_t`: The observed values. -- `Y_t`: The true values. - -# Returns -- `y_t`: The generated observations. -- An empty named tuple. -" -@model function EpiAwareBase.generate_observations(obs_model::PoissonError, y_t, Y_t) - if ismissing(y_t) - y_t = Vector{Int}(undef, length(Y_t)) - end - Y_y = length(y_t) - length(Y_t) - - for i in eachindex(Y_t) - y_t[Y_y + i] ~ Poisson(Y_t[i] + obs_model.pos_shift) - end - - return y_t, NamedTuple() +function obs_error(obs_model::PoissonError, Y_t) + return Poisson(Y_t + obs_model.pos_shift) end From b361529009f7484783e617f2ba5cd6e2b44a52e7 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 00:28:28 +0100 Subject: [PATCH 05/17] tests passing --- EpiAware/src/EpiObsModels/EpiObsModels.jl | 9 +++-- .../AbstractTuringObservationModel.jl | 6 ++-- .../NegativeBinomialError.jl | 36 +++++-------------- .../ObservationErrorModels/PoissonError.jl | 2 +- .../NegativeBinomialError.jl | 0 .../PoissonError.jl | 0 6 files changed, 19 insertions(+), 34 deletions(-) rename EpiAware/test/EpiObsModels/{ => ObservationErrorModels}/NegativeBinomialError.jl (100%) rename EpiAware/test/EpiObsModels/{ => ObservationErrorModels}/PoissonError.jl (100%) diff --git a/EpiAware/src/EpiObsModels/EpiObsModels.jl b/EpiAware/src/EpiObsModels/EpiObsModels.jl index 2ab3d8716..74d6f2f27 100644 --- a/EpiAware/src/EpiObsModels/EpiObsModels.jl +++ b/EpiAware/src/EpiObsModels/EpiObsModels.jl @@ -11,6 +11,10 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek using Turing, Distributions, DocStringExtensions, SparseArrays +# Abstract observation model +export AbstractTuringObservationErrorModel, generate_observation_error_priors, + observation_error + # Observation models export PoissonError, NegativeBinomialError @@ -25,8 +29,9 @@ include("LatentDelay.jl") include("ascertainment/Ascertainment.jl") include("ascertainment/helpers.jl") include("StackObservationModels.jl") -include("PoissonError.jl") -include("NegativeBinomialError.jl") +include("ObservationErrorModels/AbstractTuringObservationModel.jl") +include("ObservationErrorModels/NegativeBinomialError.jl") +include("ObservationErrorModels/PoissonError.jl") include("utils.jl") end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index f443b1f71..653d8f468 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -13,7 +13,7 @@ abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationMo Y_y = length(y_t) - length(Y_t) for i in eachindex(Y_t) - y_t[Y_y + i] ~ obs_error(obs_model, Y_t[i]) + y_t[Y_y + i] ~ observation_error(obs_model, Y_t[i], priors...) end return y_t, priors @@ -24,7 +24,7 @@ end return NamedTuple() end -function obs_error(obs_model::AbstractTuringObservationErrorModel, Y_t) - @info "No concrete implementation for `_apply_method` is defined." +function observation_error(obs_model::AbstractTuringObservationErrorModel, Y_t) + @info "No concrete implementation for `observation_error` is defined." return nothing end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index 49dab2173..a041e9c07 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -15,7 +15,7 @@ rand(nb_model) ``` " struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: - AbstractTuringObservationModel + AbstractTuringObservationErrorModel "The prior distribution for the cluster factor." cluster_factor_prior::S "The positive shift value." @@ -35,34 +35,14 @@ struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: end end -@doc raw" -Generate observations using the NegativeBinomialError observation model. - -# Arguments -- `obs_model::NegativeBinomialError`: The observation model. -- `y_t`: The observed values. -- `Y_t`: The true values. - -# Returns -- `y_t`: The generated observations. -- `(; cluster_factor,)`: A named tuple containing the generated `cluster_factor`. -" -@model function EpiAwareBase.generate_observations(obs_model::NegativeBinomialError, - y_t, - Y_t) +@model function generate_observation_error_priors( + obs_model::NegativeBinomialError, Y_t, y_t) cluster_factor ~ obs_model.cluster_factor_prior sq_cluster_factor = cluster_factor^2 - if ismissing(y_t) - y_t = Vector{Int}(undef, length(Y_t)) - end - - Y_y = length(y_t) - length(Y_t) - - for i in eachindex(Y_t) - y_t[Y_y + i] ~ NegativeBinomialMeanClust( - Y_t[i] + obs_model.pos_shift, sq_cluster_factor - ) - end + return (; sq_cluster_factor) +end - return y_t, (; cluster_factor,) +function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor) + return NegativeBinomialMeanClust(Y_t + obs_model.pos_shift, + sq_cluster_factor) end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl index 5157f7c57..600d86a07 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -25,6 +25,6 @@ struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel end end -function obs_error(obs_model::PoissonError, Y_t) +function observation_error(obs_model::PoissonError, Y_t) return Poisson(Y_t + obs_model.pos_shift) end diff --git a/EpiAware/test/EpiObsModels/NegativeBinomialError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl similarity index 100% rename from EpiAware/test/EpiObsModels/NegativeBinomialError.jl rename to EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl diff --git a/EpiAware/test/EpiObsModels/PoissonError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl similarity index 100% rename from EpiAware/test/EpiObsModels/PoissonError.jl rename to EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl From cc163b6645d4a8f251815dd9da37fb79b3f29236 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 09:25:42 +0100 Subject: [PATCH 06/17] update handling of missingness to maintain expected_obs length --- EpiAware/src/EpiObsModels/LatentDelay.jl | 2 +- .../AbstractTuringObservationModel.jl | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/EpiAware/src/EpiObsModels/LatentDelay.jl b/EpiAware/src/EpiObsModels/LatentDelay.jl index 86f1bbf1a..12e21b8eb 100644 --- a/EpiAware/src/EpiObsModels/LatentDelay.jl +++ b/EpiAware/src/EpiObsModels/LatentDelay.jl @@ -70,7 +70,7 @@ Generates observations based on the `LatentDelay` observation model. expected_obs = kernel * Y_t @submodel y_t, obs_aux = generate_observations( - obs_model.model, y_t, expected_obs) + obs_model.model, y_t, complete_expected_obs) return y_t, (; obs_aux...) end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index 653d8f468..b64427d52 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -7,13 +7,13 @@ abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationMo @submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t) if ismissing(y_t) - y_t = Vector{Int}(undef, length(Y_t)) + y_t = Vector{Union{Real, Missing}}(missing, length(Y_t)) end - Y_y = length(y_t) - length(Y_t) - - for i in eachindex(Y_t) - y_t[Y_y + i] ~ observation_error(obs_model, Y_t[i], priors...) + for i in eachindex(y_t) + if (!ismissing(Y_t[i])) + y_t[i] ~ observation_error(obs_model, Y_t[i], priors...) + end end return y_t, priors From 6f943d9162cd4cd5c93e3163941bdffb2fca4e9b Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 09:34:57 +0100 Subject: [PATCH 07/17] update handling of missingness --- EpiAware/src/EpiObsModels/LatentDelay.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/EpiAware/src/EpiObsModels/LatentDelay.jl b/EpiAware/src/EpiObsModels/LatentDelay.jl index 12e21b8eb..890659fd2 100644 --- a/EpiAware/src/EpiObsModels/LatentDelay.jl +++ b/EpiAware/src/EpiObsModels/LatentDelay.jl @@ -64,13 +64,16 @@ Generates observations based on the `LatentDelay` observation model. " @model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t) - @assert length(obs_model.pmf)<=length(Y_t) "The delay PMF must be shorter than or equal to the observation vector" + first_Y_t = findfirst(!ismissing, Y_t) + trunc_Y_t = Y_t[first_Y_t:end] + @assert length(obs_model.pmf)<=length(trunc_Y_t) "The delay PMF must be shorter than or equal to the observation vector" - kernel = generate_observation_kernel(obs_model.pmf, length(Y_t), partial = false) - expected_obs = kernel * Y_t + kernel = generate_observation_kernel(obs_model.pmf, length(trunc_Y_t), partial = false) + expected_obs = kernel * trunc_Y_t + complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 1), expected_obs) @submodel y_t, obs_aux = generate_observations( - obs_model.model, y_t, complete_expected_obs) + obs_model.model, y_t, complete_obs) return y_t, (; obs_aux...) end From 634e7f7d24291c47cb09db5804770d61416ed698 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 09:41:52 +0100 Subject: [PATCH 08/17] change type of some tests --- EpiAware/test/EpiAwareUtils/generate_epiware.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/EpiAware/test/EpiAwareUtils/generate_epiware.jl b/EpiAware/test/EpiAwareUtils/generate_epiware.jl index d2fd32208..e73fd3878 100644 --- a/EpiAware/test/EpiAwareUtils/generate_epiware.jl +++ b/EpiAware/test/EpiAwareUtils/generate_epiware.jl @@ -32,7 +32,7 @@ gen = generated_quantities(test_mdl, rand(test_mdl)) #Check model sampled - @test eltype(gen.generated_y_t) <: Int + @test eltype(gen.generated_y_t) <: Union{Missing, Real} @test eltype(gen.I_t) <: AbstractFloat @test length(gen.I_t) == time_horizon end @@ -70,7 +70,7 @@ end gens = generated_quantities(test_mdl, chn) #Check model sampled - @test eltype(gens[1].generated_y_t) <: Int + @test eltype(gens[1].generated_y_t) <: Union{Missing, Real} @test eltype(gens[1].I_t) <: AbstractFloat @test length(gens[1].I_t) == time_horizon end @@ -110,7 +110,7 @@ end gens = generated_quantities(test_mdl, chn) #Check model sampled - @test eltype(gens[1].generated_y_t) <: Int + @test eltype(gens[1].generated_y_t) <: Union{Missing, Real} @test eltype(gens[1].I_t) <: AbstractFloat @test length(gens[1].I_t) == time_horizon end From 6d44b686e07c869a193cf0d7a32b2546b6991418 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 09:55:45 +0100 Subject: [PATCH 09/17] get tests passing --- EpiAware/src/EpiObsModels/LatentDelay.jl | 2 +- .../AbstractTuringObservationModel.jl | 7 +++---- EpiAware/test/EpiObsModels/LatentDelay.jl | 8 +++++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/EpiAware/src/EpiObsModels/LatentDelay.jl b/EpiAware/src/EpiObsModels/LatentDelay.jl index 890659fd2..40a7a4712 100644 --- a/EpiAware/src/EpiObsModels/LatentDelay.jl +++ b/EpiAware/src/EpiObsModels/LatentDelay.jl @@ -70,7 +70,7 @@ Generates observations based on the `LatentDelay` observation model. kernel = generate_observation_kernel(obs_model.pmf, length(trunc_Y_t), partial = false) expected_obs = kernel * trunc_Y_t - complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 1), expected_obs) + complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 2), expected_obs) @submodel y_t, obs_aux = generate_observations( obs_model.model, y_t, complete_obs) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index b64427d52..428344780 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -10,10 +10,9 @@ abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationMo y_t = Vector{Union{Real, Missing}}(missing, length(Y_t)) end - for i in eachindex(y_t) - if (!ismissing(Y_t[i])) - y_t[i] ~ observation_error(obs_model, Y_t[i], priors...) - end + first_Y_t = findfirst(!ismissing, Y_t) + for i in first_Y_t:length(Y_t) + y_t[i] ~ observation_error(obs_model, Y_t[i], priors...) end return y_t, priors diff --git a/EpiAware/test/EpiObsModels/LatentDelay.jl b/EpiAware/test/EpiObsModels/LatentDelay.jl index bba768912..e30486e70 100644 --- a/EpiAware/test/EpiObsModels/LatentDelay.jl +++ b/EpiAware/test/EpiObsModels/LatentDelay.jl @@ -130,17 +130,19 @@ end obs_model = LatentDelay(TestObs(), delay_int) I_t = [10.0, 20.0, 30.0, 40.0, 50.0] - expected_obs = [23.0, 33.0, 43.0] + expected_obs = [missing, missing, 23.0, 33.0, 43.0] @testset "Test with entirely missing data" begin mdl = generate_observations(obs_model, missing, I_t) - @test mdl()[1] == expected_obs + @test mdl()[1][3:end] == expected_obs[3:end] + @test sum(mdl()[1] .|> ismissing) == 2 end @testset "Test with missing data defined as a vector" begin mdl = generate_observations( obs_model, [missing, missing, missing, missing, missing], I_t) - @test mdl()[1] == expected_obs + @test mdl()[1][3:end] == expected_obs[3:end] + @test sum(mdl()[1] .|> ismissing) == 2 end @testset "Test with data" begin From c87bcd869fa6fa41abf70e660f334155181e2166 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 21:28:13 +0100 Subject: [PATCH 10/17] add docs for abstract methods --- .../AbstractTuringObservationModel.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index 428344780..b068b01bd 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -1,5 +1,12 @@ +@doc raw" +The abstract supertype for all structs that define a model for generating +observation errors. +" abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationModel end +@doc raw" +Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. +" @model function EpiAwareBase.generate_observations( obs_model::AbstractTuringObservationErrorModel, y_t, @@ -10,19 +17,24 @@ abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationMo y_t = Vector{Union{Real, Missing}}(missing, length(Y_t)) end - first_Y_t = findfirst(!ismissing, Y_t) - for i in first_Y_t:length(Y_t) + for i in findfirst(!ismissing, Y_t):length(Y_t) y_t[i] ~ observation_error(obs_model, Y_t[i], priors...) end return y_t, priors end +@doc raw" +Generates priors for the observation error model. This should return a named tuple containing the priors required for generating the observation error distribution. +" @model function generate_observation_error_priors( obs_model::AbstractTuringObservationErrorModel, y_t, Y_t) return NamedTuple() end +@doc raw" +Generates the observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`. +" function observation_error(obs_model::AbstractTuringObservationErrorModel, Y_t) @info "No concrete implementation for `observation_error` is defined." return nothing From 48eb1971f1818af9061471c4ae1a3db3793bd3b6 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 21:32:00 +0100 Subject: [PATCH 11/17] add docs for new specific methods --- .../AbstractTuringObservationModel.jl | 2 +- .../ObservationErrorModels/NegativeBinomialError.jl | 6 ++++++ .../EpiObsModels/ObservationErrorModels/PoissonError.jl | 7 +++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index b068b01bd..89a7d208a 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -33,7 +33,7 @@ Generates priors for the observation error model. This should return a named tup end @doc raw" -Generates the observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`. +The observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`. " function observation_error(obs_model::AbstractTuringObservationErrorModel, Y_t) @info "No concrete implementation for `observation_error` is defined." diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index a041e9c07..f4eeaeb3a 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -35,6 +35,9 @@ struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: end end +@doc raw" +Generates observation error priors based on the `NegativeBinomialError` observation model. This function generates the cluster factor prior for the negative binomial error model. +" @model function generate_observation_error_priors( obs_model::NegativeBinomialError, Y_t, y_t) cluster_factor ~ obs_model.cluster_factor_prior @@ -42,6 +45,9 @@ end return (; sq_cluster_factor) end +@doc raw" +This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution. +" function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor) return NegativeBinomialMeanClust(Y_t + obs_model.pos_shift, sq_cluster_factor) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl index 600d86a07..cea48fffd 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -1,7 +1,6 @@ @doc raw" The `PoissonError` struct represents an observation model for Poisson errors. It -is a subtype of `AbstractTuringObservationErrorModel`. Note that -when Y_t is shorter than y_t, then the first `length(y_t) - length(Y_t)` elements of y_t are assumed to be missing. +is a subtype of `AbstractTuringObservationErrorModel`. ## Constructors - `PoissonError(; pos_shift::AbstractFloat = 0.)`: Constructs a `PoissonError` @@ -25,6 +24,10 @@ struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel end end +@doc raw" +The observation error model for Poisson errors. This function generates the +observation error model based on the Poisson error model with a positive shift. +" function observation_error(obs_model::PoissonError, Y_t) return Poisson(Y_t + obs_model.pos_shift) end From d9dcbf35e72f02792bc67c402cce6f7bb52a051f Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Jun 2024 22:46:58 +0100 Subject: [PATCH 12/17] add unit tests --- .../AbstractTuringObservationModel.jl | 2 ++ .../AbstractTuringObservationError.jl | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 EpiAware/test/EpiObsModels/ObservationErrorModels/AbstractTuringObservationError.jl diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl index 89a7d208a..76ef9ffb4 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl @@ -15,6 +15,8 @@ Generates observations from an observation error model. It provides support for if ismissing(y_t) y_t = Vector{Union{Real, Missing}}(missing, length(Y_t)) + else + @assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length." end for i in findfirst(!ismissing, Y_t):length(Y_t) diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/AbstractTuringObservationError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/AbstractTuringObservationError.jl new file mode 100644 index 000000000..fe9fe78b2 --- /dev/null +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/AbstractTuringObservationError.jl @@ -0,0 +1,27 @@ +@testitem "Test specific generate_observations" begin + using Distributions: Normal + struct TestObs <: AbstractTuringObservationErrorModel end + + function EpiObsModels.observation_error(model::TestObs, Y_t) + Normal(Y_t, 1e-6) + end + + obs_model = TestObs() + + I_t = [10.0, 20.0, 30.0, 40.0, 50.0] + + @testset "Test with entirely missing data" begin + mdl = generate_observations(obs_model, missing, I_t) + @test isapprox(mdl()[1], I_t, atol = 1e-3) + end + + missing_I_t = vcat(missing, I_t) + + @testset "Test with leading missing expected observations" begin + mdl = generate_observations(obs_model, missing_I_t, vcat(20, I_t)) + draw = mdl()[1] + @test draw[2:end] == I_t + @test abs(draw[1] - 20) > 0 + @test isapprox(draw[1], 20, atol = 1e-3) + end +end From 0ec11bd86a47f1450f9893d3d644332de58003b6 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 11 Jun 2024 18:07:41 +0100 Subject: [PATCH 13/17] move abstract type --- EpiAware/src/EpiAwareBase/EpiAwareBase.jl | 11 ++++++++--- EpiAware/src/EpiAwareBase/types.jl | 6 ++++++ EpiAware/src/EpiObsModels/EpiObsModels.jl | 11 +++++------ .../{AbstractTuringObservationModel.jl => methods.jl} | 6 ------ .../{AbstractTuringObservationError.jl => methods.jl} | 0 5 files changed, 19 insertions(+), 15 deletions(-) rename EpiAware/src/EpiObsModels/ObservationErrorModels/{AbstractTuringObservationModel.jl => methods.jl} (90%) rename EpiAware/test/EpiObsModels/ObservationErrorModels/{AbstractTuringObservationError.jl => methods.jl} (100%) diff --git a/EpiAware/src/EpiAwareBase/EpiAwareBase.jl b/EpiAware/src/EpiAwareBase/EpiAwareBase.jl index 430c2987d..4fedd10c4 100644 --- a/EpiAware/src/EpiAwareBase/EpiAwareBase.jl +++ b/EpiAware/src/EpiAwareBase/EpiAwareBase.jl @@ -10,9 +10,14 @@ using DocStringExtensions #Export models export AbstractModel, AbstractEpiModel, AbstractLatentModel, AbstractObservationModel -# Export Turing-based models -export AbstractTuringEpiModel, AbstractTuringLatentModel, AbstractTuringIntercept, - AbstractTuringObservationModel, AbstractTuringRenewal +# Export Turing-based models EpiModels +export AbstractTuringEpiModel, AbstractTuringRenewal + +# Export Turing-based latent models +export AbstractTuringLatentModel, AbstractTuringIntercept + +# Export Turing-based observation models +export AbstractTuringObservationModel, AbstractTuringObservationErrorModel # Export support types export AbstractBroadcastRule diff --git a/EpiAware/src/EpiAwareBase/types.jl b/EpiAware/src/EpiAwareBase/types.jl index 7afefbab2..0b3a116d9 100644 --- a/EpiAware/src/EpiAwareBase/types.jl +++ b/EpiAware/src/EpiAwareBase/types.jl @@ -42,6 +42,12 @@ A abstract type representing a Turing-based observation model. """ abstract type AbstractTuringObservationModel <: AbstractObservationModel end +""" +The abstract supertype for all structs that defines a Turing-based model for +generating observation errors. +""" +abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationModel end + """ Abstract supertype for all `EpiAware` problems. """ diff --git a/EpiAware/src/EpiObsModels/EpiObsModels.jl b/EpiAware/src/EpiObsModels/EpiObsModels.jl index 74d6f2f27..a18d116aa 100644 --- a/EpiAware/src/EpiObsModels/EpiObsModels.jl +++ b/EpiAware/src/EpiObsModels/EpiObsModels.jl @@ -11,13 +11,12 @@ using ..EpiLatentModels: HierarchicalNormal, broadcast_dayofweek using Turing, Distributions, DocStringExtensions, SparseArrays -# Abstract observation model -export AbstractTuringObservationErrorModel, generate_observation_error_priors, - observation_error - -# Observation models +# Observation error models export PoissonError, NegativeBinomialError +# Observation error model functions +export generate_observation_error_priors, observation_error + # Observation model modifiers export LatentDelay, Ascertainment, StackObservationModels @@ -29,7 +28,7 @@ include("LatentDelay.jl") include("ascertainment/Ascertainment.jl") include("ascertainment/helpers.jl") include("StackObservationModels.jl") -include("ObservationErrorModels/AbstractTuringObservationModel.jl") +include("ObservationErrorModels/methods.jl") include("ObservationErrorModels/NegativeBinomialError.jl") include("ObservationErrorModels/PoissonError.jl") include("utils.jl") diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl similarity index 90% rename from EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl rename to EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl index 76ef9ffb4..b45cf0f00 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/AbstractTuringObservationModel.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl @@ -1,9 +1,3 @@ -@doc raw" -The abstract supertype for all structs that define a model for generating -observation errors. -" -abstract type AbstractTuringObservationErrorModel <: AbstractTuringObservationModel end - @doc raw" Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. " diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/AbstractTuringObservationError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl similarity index 100% rename from EpiAware/test/EpiObsModels/ObservationErrorModels/AbstractTuringObservationError.jl rename to EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl From ea2b6df960840efa8472adda6f1a9a3f15cde5f4 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 11 Jun 2024 19:32:26 +0100 Subject: [PATCH 14/17] remove numerical pading as an option and hard code --- .../NegativeBinomialError.jl | 27 +++++-------------- .../ObservationErrorModels/PoissonError.jl | 14 +++------- .../ObservationErrorModels/methods.jl | 8 ++++-- .../test/EpiAwareUtils/generate_epiware.jl | 9 ++----- .../NegativeBinomialError.jl | 11 ++++---- .../ObservationErrorModels/PoissonError.jl | 13 +++------ 6 files changed, 26 insertions(+), 56 deletions(-) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index f4eeaeb3a..26778ddad 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -3,8 +3,8 @@ The `NegativeBinomialError` struct represents an observation model for negative binomial errors. It is a subtype of `AbstractTuringObservationModel`. ## Constructors -- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1), pos_shift::AbstractFloat = 1e-6)`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior and positive shift. -- `NegativeBinomialError(cluster_factor_prior::Distribution; pos_shift::AbstractFloat = 1e-6)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior and default value for the positive shift. +- `NegativeBinomialError(; cluster_factor_prior::Distribution = HalfNormal(0.1))`: Constructs a `NegativeBinomialError` object with default values for the cluster factor prior. +- `NegativeBinomialError(cluster_factor_prior::Distribution)`: Constructs a `NegativeBinomialError` object with a specified cluster factor prior. ## Examples ```julia @@ -14,25 +14,10 @@ nb_model = generate_observations(nb, missing, fill(10, 10)) rand(nb_model) ``` " -struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: - AbstractTuringObservationErrorModel +@kwdef struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: + AbstractTuringObservationErrorModel "The prior distribution for the cluster factor." - cluster_factor_prior::S - "The positive shift value." - pos_shift::T - - function NegativeBinomialError(; - cluster_factor_prior::Distribution = HalfNormal(0.01), - pos_shift::AbstractFloat = 1e-6) - new{typeof(cluster_factor_prior), typeof(pos_shift)}( - cluster_factor_prior, pos_shift) - end - - function NegativeBinomialError(cluster_factor_prior::Distribution; - pos_shift::AbstractFloat = 1e-6) - new{typeof(cluster_factor_prior), typeof(pos_shift)}( - cluster_factor_prior, pos_shift) - end + cluster_factor_prior::S = HalfNormal(0.01) end @doc raw" @@ -49,6 +34,6 @@ end This function generates the observation error model based on the negative binomial error model with a positive shift. It dispatches to the `NegativeBinomialMeanClust` distribution. " function observation_error(obs_model::NegativeBinomialError, Y_t, sq_cluster_factor) - return NegativeBinomialMeanClust(Y_t + obs_model.pos_shift, + return NegativeBinomialMeanClust(Y_t, sq_cluster_factor) end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl index cea48fffd..1424c7a60 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -3,8 +3,7 @@ The `PoissonError` struct represents an observation model for Poisson errors. It is a subtype of `AbstractTuringObservationErrorModel`. ## Constructors -- `PoissonError(; pos_shift::AbstractFloat = 0.)`: Constructs a `PoissonError` -object with default values for the cluster factor prior and positive shift. +- `PoissonError()`: Constructs a `PoissonError` object. ## Examples ```julia @@ -15,19 +14,12 @@ rand(poi_model) ``` " struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel - "The positive shift value." - pos_shift::T - - function PoissonError(; pos_shift::AbstractFloat = 0.0) - @assert pos_shift>=0.0 "The positive shift value must be non-negative." - new{typeof(pos_shift)}(pos_shift) - end end @doc raw" The observation error model for Poisson errors. This function generates the -observation error model based on the Poisson error model with a positive shift. +observation error model based on the Poisson error model. " function observation_error(obs_model::PoissonError, Y_t) - return Poisson(Y_t + obs_model.pos_shift) + return Poisson(Y_t) end diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl index b45cf0f00..2bbab14fa 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl @@ -1,5 +1,7 @@ @doc raw" -Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. +Generates observations from an observation error model. It provides support for missing values in observations (`y_t`), and missing values at the beginning of the expected observations (`Y_t`). It also pads the expected observations with a small value (1e-6) to mitigate potential numerical issues. + +It dispatches to the `observation_error` function to generate the observation error distribution which uses priors generated by `generate_observation_error_priors` submodel. For most observation error models specific implementations of `observation_error` and `generate_observation_error_priors` are required but a specific implementation of `generate_observations` is not required. " @model function EpiAwareBase.generate_observations( obs_model::AbstractTuringObservationErrorModel, @@ -13,8 +15,10 @@ Generates observations from an observation error model. It provides support for @assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length." end + pad_Y_t = Y_t + 1e-6 + for i in findfirst(!ismissing, Y_t):length(Y_t) - y_t[i] ~ observation_error(obs_model, Y_t[i], priors...) + y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...) end return y_t, priors diff --git a/EpiAware/test/EpiAwareUtils/generate_epiware.jl b/EpiAware/test/EpiAwareUtils/generate_epiware.jl index e73fd3878..0b8175fa8 100644 --- a/EpiAware/test/EpiAwareUtils/generate_epiware.jl +++ b/EpiAware/test/EpiAwareUtils/generate_epiware.jl @@ -1,10 +1,8 @@ - @testitem "`generate_epiaware` with direct infections and RW latent process runs" begin using Distributions, Turing, DynamicPPL # Define test inputs y_t = missing # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], exp) - pos_shift = 1e-6 time_horizon = 100 #Define the epi_model @@ -42,7 +40,6 @@ end # Define test inputs y_t = missing# rand(1:10, 365) # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], exp) - pos_shift = 1e-6 #Define the epi_model epi_model = ExpGrowthRate(data, Normal()) @@ -56,7 +53,7 @@ end #Define the observation model - no delay model time_horizon = 5 obs_model = NegativeBinomialError( - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); pos_shift + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) ) # Create full epi model and sample from it @@ -80,7 +77,6 @@ end # Define test inputs y_t = missing# rand(1:10, 365) # Data will be generated from the model data = EpiData([0.2, 0.3, 0.5], exp) - pos_shift = 1e-6 #Define the epi_model epi_model = Renewal(data, Normal()) @@ -94,8 +90,7 @@ end #Define the observation model - no delay model time_horizon = 5 obs_model = NegativeBinomialError( - truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0); - pos_shift + truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0) ) # Create full epi model and sample from it diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index 9e6d0c70b..2b1c8fd54 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -2,20 +2,19 @@ using Distributions # Test default constructor nb = NegativeBinomialError() + @test typeof(nb) <: NegativeBinomialError + @test typeof(nb) <: TuringObservationErrorModel @test all(rand(nb.cluster_factor_prior, 100) .>= 0.0) @test isapprox(mean(nb.cluster_factor_prior), 0.01) - @test nb.pos_shift ≈ 1e-6 # Test constructor with custom prior prior = Gamma(2.0, 1.0) nb = NegativeBinomialError(prior) @test nb.cluster_factor_prior == prior - @test nb.pos_shift ≈ 1e-6 - # Test constructor with custom prior and pos_shift - nb = NegativeBinomialError(prior; pos_shift = 1e-3) + # Test constructor with custom prior + nb = NegativeBinomialError(prior) @test nb.cluster_factor_prior == prior - @test nb.pos_shift ≈ 1e-3 end @testitem "Testing NegativeBinomialError against theoretical properties" begin @@ -27,7 +26,7 @@ end α = 0.2 # Cluster factor (dispersion parameter) # Define the observation model - nb_obs_model = NegativeBinomialError(pos_shift = 0.0) + nb_obs_model = NegativeBinomialError() # Generate observations from the model Y_t = fill(μ, n) # True values diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl index 732709ce6..55bc6ff74 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -1,14 +1,9 @@ @testitem "PoissonErrorConstructor" begin using Distributions # Test default constructor - poi = PoissonError() - @test poi.pos_shift ≈ zero(Float64) - poi_float = PoissonError(; pos_shift = 0.0f0) - @test poi_float.pos_shift ≈ zero(Float32) - - # Test constructor with pos_shift - poi2 = PoissonError(; pos_shift = 1e-3) - @test poi2.pos_shift ≈ 1e-3 + poi = PoissonError + @test typeof(poi) <: PoissonError + @test typeof(poi) <: TuringObservationErrorModel end @testitem "Testing PoissonError against theoretical properties" begin @@ -19,7 +14,7 @@ end μ = 10.0 # Mean of the poisson distribution # Define the observation model - poi_obs_model = PoissonError(pos_shift = 0.0) + poi_obs_model = PoissonError() # Generate observations from the model Y_t = fill(μ, n) # True values From d5f847dd723cb89c4e5d62664fae73f7d3358e43 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 11 Jun 2024 19:47:24 +0100 Subject: [PATCH 15/17] correct type specification --- .../ObservationErrorModels/NegativeBinomialError.jl | 2 +- .../src/EpiObsModels/ObservationErrorModels/PoissonError.jl | 2 +- EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index 26778ddad..98aad3fb5 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -14,7 +14,7 @@ nb_model = generate_observations(nb, missing, fill(10, 10)) rand(nb_model) ``` " -@kwdef struct NegativeBinomialError{S <: Sampleable, T <: AbstractFloat} <: +@kwdef struct NegativeBinomialError{S <: Sampleable} <: AbstractTuringObservationErrorModel "The prior distribution for the cluster factor." cluster_factor_prior::S = HalfNormal(0.01) diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl index 1424c7a60..491cf13de 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -13,7 +13,7 @@ poi_model = generate_observations(poi, missing, fill(10, 10)) rand(poi_model) ``` " -struct PoissonError{T <: AbstractFloat} <: AbstractTuringObservationErrorModel +struct PoissonError <: AbstractTuringObservationErrorModel end @doc raw" diff --git a/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl index 2bbab14fa..342f90d7d 100644 --- a/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl +++ b/EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl @@ -15,7 +15,7 @@ It dispatches to the `observation_error` function to generate the observation er @assert length(y_t)==length(Y_t) "The observation vector and expected observation vector must have the same length." end - pad_Y_t = Y_t + 1e-6 + pad_Y_t = Y_t .+ 1e-6 for i in findfirst(!ismissing, Y_t):length(Y_t) y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...) From 78c54e3d097d74355fa6da0560f45b23ebfa67fa Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 11 Jun 2024 19:50:02 +0100 Subject: [PATCH 16/17] fix constructor tests --- .../ObservationErrorModels/NegativeBinomialError.jl | 2 +- .../test/EpiObsModels/ObservationErrorModels/PoissonError.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl index 2b1c8fd54..7739982ce 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/NegativeBinomialError.jl @@ -3,7 +3,7 @@ # Test default constructor nb = NegativeBinomialError() @test typeof(nb) <: NegativeBinomialError - @test typeof(nb) <: TuringObservationErrorModel + @test typeof(nb) <: AbstractTuringObservationErrorModel @test all(rand(nb.cluster_factor_prior, 100) .>= 0.0) @test isapprox(mean(nb.cluster_factor_prior), 0.01) diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl index 55bc6ff74..f3901233f 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/PoissonError.jl @@ -1,9 +1,9 @@ @testitem "PoissonErrorConstructor" begin using Distributions # Test default constructor - poi = PoissonError + poi = PoissonError() @test typeof(poi) <: PoissonError - @test typeof(poi) <: TuringObservationErrorModel + @test typeof(poi) <: AbstractTuringObservationErrorModel end @testitem "Testing PoissonError against theoretical properties" begin From 28182b3934b39a829d5637aa266694e628375ce4 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 11 Jun 2024 19:57:59 +0100 Subject: [PATCH 17/17] add a test for new Y_t == y_t check --- EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl b/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl index fe9fe78b2..8867680d5 100644 --- a/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl +++ b/EpiAware/test/EpiObsModels/ObservationErrorModels/methods.jl @@ -24,4 +24,6 @@ @test abs(draw[1] - 20) > 0 @test isapprox(draw[1], 20, atol = 1e-3) end + + @test_throws AssertionError generate_observations(obs_model, vcat(1, I_t), I_t)() end