Skip to content

Commit

Permalink
Structs to define inference/generative methods (#155)
Browse files Browse the repository at this point in the history
* Extend abstract types in EpiAwareBase

* EpiAwareMethod structs

* some optimization method structs

* constructor unit tests
  • Loading branch information
SamuelBrand1 authored Mar 19, 2024
1 parent 43c448a commit e68a536
Show file tree
Hide file tree
Showing 16 changed files with 130 additions and 29 deletions.
12 changes: 9 additions & 3 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,17 @@ include("EpiObsModels/EpiObsModels.jl")
include("EpiInference/EpiInference.jl")
@reexport using .EpiInference

# Non-submodule exports
export make_epi_aware, EpiAwareProblem
#Export problems
export EpiProblem

#Export inference methods
export EpiMethod

#Export functions
export make_epi_aware

include("docstrings.jl")
include("epiawareproblems/epiawareprob.jl")
include("epiawareprob.jl")
include("make_epi_aware.jl")

end
17 changes: 12 additions & 5 deletions EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
module EpiAwareBase

"""
Module for defining abstract epidemiological types.
"""
module EpiAwareBase

using DocStringExtensions

export AbstractModel, AbstractEpiModel, AbstractLatentModel,
AbstractObservationModel, AbstractEpiAwareProblem, generate_latent,
generate_latent_infs, generate_observations
#Export models
export AbstractModel, AbstractEpiModel, AbstractLatentModel, AbstractObservationModel

#Export problems
export AbstractEpiProblem

#Export inference methods
export AbstractEpiMethod, AbstractEpiOptMethod, AbstractEpiSamplingMethod

#Export functions
export generate_latent, generate_latent_infs, generate_observations

include("docstrings.jl")
include("types.jl")
Expand Down
19 changes: 18 additions & 1 deletion EpiAware/src/EpiAwareBase/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,21 @@ abstract type AbstractObservationModel <: AbstractModel end
"""
Abstract supertype for all `EpiAware` problems.
"""
abstract type AbstractEpiAwareProblem end
abstract type AbstractEpiProblem end

"""
Abstract supertype for all `EpiAware` inference/generative modelling methods.
"""
abstract type AbstractEpiMethod end

"""
Abstract supertype for infence/generative methods that are based on optimization, e.g. MAP
estimation or variational inference.
"""
abstract type AbstractEpiOptMethod <: AbstractEpiMethod end

"""
Abstract supertype for infence/generative methods that are based on sampling from the
posterior distribution, e.g. NUTS.
"""
abstract type AbstractEpiSamplingMethod <: AbstractEpiMethod end
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module EpiAwareUtils

"""
Module for defining utility functions.
"""
module EpiAwareUtils

using ..EpiAwareBase

Expand All @@ -12,6 +11,7 @@ using Distributions: Distribution, cdf, Normal, truncated

using DocStringExtensions, QuadGK

#Export functions
export scan, spread_draws, create_discrete_pmf

include("docstrings.jl")
Expand Down
10 changes: 6 additions & 4 deletions EpiAware/src/EpiInfModels/EpiInfModels.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
module EpiInfModels

"""
Module for defining epidemiological models.
"""
module EpiInfModels

using ..EpiAwareBase

using ..EpiAwareUtils: scan, create_discrete_pmf

using Turing, Distributions, DocStringExtensions, LinearAlgebra

export EpiData, DirectInfections, ExpGrowthRate, Renewal,
R_to_r, r_to_R
#Export models
export EpiData, DirectInfections, ExpGrowthRate, Renewal

#Export functions
export R_to_r, r_to_R

include("docstrings.jl")
include("epidata.jl")
Expand Down
10 changes: 9 additions & 1 deletion EpiAware/src/EpiInference/EpiInference.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
module EpiInference

"""
Module for defining inference methods.
"""
module EpiInference

using ..EpiAwareBase: AbstractEpiMethod, AbstractEpiOptMethod,
AbstractEpiSamplingMethod
using Pathfinder: pathfinder, PathfinderResult

using DynamicPPL, DocStringExtensions

#Export inference methods
export AbstractNUTSMethod, EpiMethod, ManyPathfinder

#Export functions
export manypathfinder

include("docstrings.jl")
include("epiawaremethod.jl")
include("manypathfinder.jl")
include("nuts.jl")

end
11 changes: 11 additions & 0 deletions EpiAware/src/EpiInference/epiawaremethod.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
`EpiMethod` represents a method for performing EpiAware inference and/or generative
modelling, which combines a sequence of optimization steps to pass initialisation
information to a sampler method.
"""
@kwdef struct EpiMethod{
O <: AbstractEpiOptMethod, S <: AbstractEpiSamplingMethod} <:
AbstractEpiMethod
pre_sampler_steps::Vector{O}
sampler::S
end
12 changes: 12 additions & 0 deletions EpiAware/src/EpiInference/manypathfinder.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
A variational inference method that runs `manypathfinder`.
"""
@kwdef struct ManyPathfinder <: AbstractEpiOptMethod
"Number of many pathfinder runs."
nruns::Int = 4
"Maximum number of iterations for each run."
maxiters::Int = 50
"Maximum number of tries if all runs fail."
max_tries::Int = 100
end

"""
Run pathfinder multiple times and store the results in an array. Fails safely.
Expand Down
4 changes: 4 additions & 0 deletions EpiAware/src/EpiInference/nuts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
Abstract super type for NUTS methods.
"""
abstract type AbstractNUTSMethod <: AbstractEpiSamplingMethod end
4 changes: 2 additions & 2 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
module EpiLatentModels

"""
Module for defining latent models.
"""
module EpiLatentModels

using ..EpiAwareBase

using Turing, Distributions, DocStringExtensions

#Export models
export RandomWalk, AR, DiffLatentModel

include("docstrings.jl")
Expand Down
9 changes: 6 additions & 3 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
module EpiObsModels

"""
Module for defining observation models.
"""
module EpiObsModels

using ..EpiAwareBase

using ..EpiAwareUtils: create_discrete_pmf

using Turing, Distributions, DocStringExtensions, SparseArrays

export DelayObservations, default_delay_obs_priors
#Export models
export DelayObservations

#Export functions
export default_delay_obs_priors

include("docstrings.jl")
include("delayobservations.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""
Defines an inference/generative modelling problem for case data.
`EpiAwareProblem` wraps the underlying components of an epidemiological model:
`EpiProblem` wraps the underlying components of an epidemiological model:
- `epi_model`: An epidemiological model for unobserved infections.
- `latent_model`: A latent model for underlying latent process.
- `observation_model`: An observation model for observed cases.
Along with a `tspan` tuple for the time span of the case data.
"""
@kwdef struct EpiAwareProblem{
@kwdef struct EpiProblem{
E <: AbstractEpiModel, L <: AbstractLatentModel, O <: AbstractObservationModel} <:
AbstractEpiAwareProblem
AbstractEpiProblem
"Epidemiological model for unobserved infections."
epi_model::E
"Latent model for underlying latent process."
Expand Down
12 changes: 12 additions & 0 deletions EpiAware/test/test_epiawaremethod.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
@testitem "EpiMethod" begin
@testset "Constructor" begin
struct TestNUTSMethod <: AbstractNUTSMethod
end

pre_sampler_steps = [ManyPathfinder(), ManyPathfinder()]
sampler = TestNUTSMethod()
method = EpiMethod(pre_sampler_steps, sampler)
@test method.pre_sampler_steps == pre_sampler_steps
@test method.sampler == sampler
end
end
8 changes: 4 additions & 4 deletions EpiAware/test/test_epiawareprob.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "EpiAwareProblem Tests" begin
@testitem "EpiProblem Tests" begin
using Distributions
# Define test inputs
data = EpiData([0.2, 0.3, 0.5], exp)
Expand All @@ -12,10 +12,10 @@
obs_prior[:neg_bin_cluster_factor_prior])
tspan = (0, 365)

# Create an instance of EpiAwareProblem
problem = EpiAwareProblem(epi_model, latent_model, obs_model, tspan)
# Create an instance of EpiProblem
problem = EpiProblem(epi_model, latent_model, obs_model, tspan)

@test typeof(problem) <: EpiAwareProblem
@test typeof(problem) <: EpiProblem
@test typeof(problem.epi_model) <: DirectInfections
@test typeof(problem.latent_model) <: RandomWalk
@test typeof(problem.observation_model) <: DelayObservations
Expand Down
3 changes: 2 additions & 1 deletion EpiAware/test/test_inference-methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ end

@testset "Check always fails for bad models and throws correct Exception" begin
@model function bad_model()
x ~ Normal(0, 1)
x ~ truncated(Normal(0, 1), -Inf, -1e-3)
y ~ Normal(sqrt(x), 1.0)
return sqrt(x) #<-fails
end
badmdl = bad_model()
Expand Down
18 changes: 18 additions & 0 deletions EpiAware/test/test_manypathfinder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "ManyPathfinder constructor" begin
@testset "Default constructor" begin
method = ManyPathfinder()
@test method.nruns == 4
@test method.maxiters == 50
@test method.max_tries == 100
end

@testset "Constructor" begin
nruns = 5
maxiters = 10
max_tries = 10
method = ManyPathfinder(; nruns, maxiters, max_tries)
@test method.nruns == nruns
@test method.maxiters == maxiters
@test method.max_tries == max_tries
end
end

0 comments on commit e68a536

Please sign in to comment.