Skip to content

Commit

Permalink
Merge pull request #128 from CDCgov/issue53
Browse files Browse the repository at this point in the history
Issue 53: First pass at differenced AR
  • Loading branch information
seabbs authored Mar 15, 2024
2 parents 147c09e + de68031 commit 13398e0
Show file tree
Hide file tree
Showing 18 changed files with 357 additions and 21 deletions.
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export spread_draws, scan, create_discrete_pmf
include("EpiLatentModels/EpiLatentModels.jl")
using .EpiLatentModels

export RandomWalk, default_rw_priors
export RandomWalk, AR

include("EpiInfModels/EpiInfModels.jl")
using .EpiInfModels
Expand Down
1 change: 1 addition & 0 deletions EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export AbstractModel, AbstractEpiModel, AbstractLatentModel,
AbstractObservationModel, generate_latent,
generate_latent_infs, generate_observations

include("docstrings.jl")
include("types.jl")
include("functions.jl")

Expand Down
24 changes: 24 additions & 0 deletions EpiAware/src/EpiAwareBase/docstrings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@template (FUNCTIONS, METHODS, MACROS) = """
$(TYPEDSIGNATURES)
$(DOCSTRING)
"""

@template (TYPES) = """
$(TYPEDEF)
$(DOCSTRING)
---
## Fields
$(TYPEDFIELDS)
"""

@template MODULES = """
$(DOCSTRING)
---
## Exports
$(EXPORTS)
---
## Imports
$(IMPORTS)
"""
1 change: 1 addition & 0 deletions EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using DocStringExtensions, QuadGK

export scan, spread_draws, create_discrete_pmf

include("docstrings.jl")
include("prior-tools.jl")
include("distributions.jl")
include("scan.jl")
Expand Down
24 changes: 24 additions & 0 deletions EpiAware/src/EpiAwareUtils/docstrings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@template (FUNCTIONS, METHODS, MACROS) = """
$(TYPEDSIGNATURES)
$(DOCSTRING)
"""

@template (TYPES) = """
$(TYPEDEF)
$(DOCSTRING)
---
## Fields
$(TYPEDFIELDS)
"""

@template MODULES = """
$(DOCSTRING)
---
## Exports
$(EXPORTS)
---
## Imports
$(IMPORTS)
"""
1 change: 1 addition & 0 deletions EpiAware/src/EpiInfModels/EpiInfModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using Turing, Distributions, DocStringExtensions, LinearAlgebra
export EpiData, DirectInfections, ExpGrowthRate, Renewal,
R_to_r, r_to_R

include("docstrings.jl")
include("epidata.jl")
include("directinfections.jl")
include("expgrowthrate.jl")
Expand Down
24 changes: 24 additions & 0 deletions EpiAware/src/EpiInfModels/docstrings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@template (FUNCTIONS, METHODS, MACROS) = """
$(TYPEDSIGNATURES)
$(DOCSTRING)
"""

@template (TYPES) = """
$(TYPEDEF)
$(DOCSTRING)
---
## Fields
$(TYPEDFIELDS)
"""

@template MODULES = """
$(DOCSTRING)
---
## Exports
$(EXPORTS)
---
## Imports
$(IMPORTS)
"""
1 change: 1 addition & 0 deletions EpiAware/src/EpiInference/EpiInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using DynamicPPL, DocStringExtensions

export manypathfinder

include("docstrings.jl")
include("manypathfinder.jl")

end
24 changes: 24 additions & 0 deletions EpiAware/src/EpiInference/docstrings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@template (FUNCTIONS, METHODS, MACROS) = """
$(TYPEDSIGNATURES)
$(DOCSTRING)
"""

@template (TYPES) = """
$(TYPEDEF)
$(DOCSTRING)
---
## Fields
$(TYPEDFIELDS)
"""

@template MODULES = """
$(DOCSTRING)
---
## Exports
$(EXPORTS)
---
## Imports
$(IMPORTS)
"""
6 changes: 4 additions & 2 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ using ..EpiAwareBase

using Turing, Distributions, DocStringExtensions

export RandomWalk, default_rw_priors
export RandomWalk, AR

include("docstrings.jl")
include("randomwalk.jl")

include("autoregressive.jl")
include("utils.jl")
end
95 changes: 95 additions & 0 deletions EpiAware/src/EpiLatentModels/autoregressive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
@doc raw"
The autoregressive (AR) model struct.
# Constructors
- `AR(damp_prior::Distribution, std_prior::Distribution, init_prior::Distribution; p::Int = 1)`: Constructs an AR model with the specified prior distributions for damping coefficients, standard deviation, and initial conditions. The order of the AR model can also be specified.
- `AR(; damp_priors::Vector{D} = [truncated(Normal(0.0, 0.05))], std_prior::Distribution = truncated(Normal(0.0, 0.05), 0.0, Inf), init_priors::Vector{I} = [Normal()]) where {D <: Distribution, I <: Distribution}`: Constructs an AR model with the specified prior distributions for damping coefficients, standard deviation, and initial conditions. The order of the AR model is determined by the length of the `damp_priors` vector.
- `AR(damp_prior::Distribution, std_prior::Distribution, init_prior::Distribution, p::Int)`: Constructs an AR model with the specified prior distributions for damping coefficients, standard deviation, and initial conditions. The order of the AR model is explicitly specified.
# Examples
```julia
using Distributions
using EpiAware
ar = AR()
ar_model = generate_latent(ar, 10)
rand(ar_model)
```
"
struct AR{D <: Sampleable, S <: Sampleable, I <: Sampleable, P <: Int} <:
AbstractLatentModel
"Prior distribution for the damping coefficients."
damp_prior::D
"Prior distribution for the standard deviation."
std_prior::S
"Prior distribution for the initial conditions"
init_prior::I
"Order of the AR model."
p::P
function AR(damp_prior::Distribution, std_prior::Distribution,
init_prior::Distribution; p::Int = 1)
damp_priors = fill(damp_prior, p)
init_priors = fill(init_prior, p)
return AR(; damp_priors = damp_priors, std_prior = std_prior,
init_priors = init_priors)
end

function AR(; damp_priors::Vector{D} = [truncated(Normal(0.0, 0.05), 0, 1)],
std_prior::Distribution = truncated(Normal(0.0, 0.05), 0.0, Inf),
init_priors::Vector{I} = [Normal()]) where {
D <: Distribution, I <: Distribution}
p = length(damp_priors)
damp_prior = _expand_dist(damp_priors)
init_prior = _expand_dist(init_priors)
return AR(damp_prior, std_prior, init_prior, p)
end

function AR(damp_prior::Distribution, std_prior::Distribution,
init_prior::Distribution, p::Int)
@assert p>0 "p must be greater than 0"
@assert length(damp_prior)==length(init_prior) "damp_prior and init_prior must have the same length"
@assert p==length(damp_prior) "p must be equal to the length of damp_prior"
new{typeof(damp_prior), typeof(std_prior), typeof(init_prior), typeof(p)}(
damp_prior, std_prior, init_prior, p
)
end
end

@doc raw"
Generate a latent AR series.
# Arguments
- `latent_model::AR`: The AR model.
- `n::Int`: The length of the AR series.
# Returns
- `ar::Vector{Float64}`: The generated AR series.
- `params::NamedTuple`: A named tuple containing the generated parameters (`σ_AR`, `ar_init`, `damp_AR`).
# Notes
- The length of `damp_prior` and `init_prior` must be the same.
- `n` must be longer than the order of the autoregressive process.
"
@model function EpiAwareBase.generate_latent(latent_model::AR, n)
p = latent_model.p
ϵ_t ~ MvNormal(ones(n - p))
σ_AR ~ latent_model.std_prior
ar_init ~ latent_model.init_prior
damp_AR ~ latent_model.damp_prior

@assert n>p "n must be longer than order of the autoregressive process"

# Initialize the AR series with the initial values
ar = Vector{Float64}(undef, n)
ar[1:p] = ar_init

# Generate the rest of the AR series
for t in (p + 1):n
ar[t] = damp_AR' * ar[(t - p):(t - 1)] + σ_AR * ϵ_t[t - p]
end

return ar, (; σ_AR, ar_init, damp_AR)
end
24 changes: 24 additions & 0 deletions EpiAware/src/EpiLatentModels/docstrings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@template (FUNCTIONS, METHODS, MACROS) = """
$(TYPEDSIGNATURES)
$(DOCSTRING)
"""

@template (TYPES) = """
$(TYPEDEF)
$(DOCSTRING)
---
## Fields
$(TYPEDFIELDS)
"""

@template MODULES = """
$(DOCSTRING)
---
## Exports
$(EXPORTS)
---
## Imports
$(IMPORTS)
"""
11 changes: 3 additions & 8 deletions EpiAware/src/EpiLatentModels/randomwalk.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
init_prior::D
std_prior::S
end

function default_rw_priors()
return (:var_RW_prior => truncated(Normal(0.0, 0.05), 0.0, Inf),
:init_rw_value_prior => Normal()) |> Dict
@kwdef struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
init_prior::D = Normal()
std_prior::S = truncated(Normal(0.0, 0.05), 0.0, Inf)
end

@model function EpiAwareBase.generate_latent(latent_model::RandomWalk, n)
Expand Down
6 changes: 6 additions & 0 deletions EpiAware/src/EpiLatentModels/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
function _expand_dist(dist::Vector{D} where {D <: Distribution})
d = length(dist)
product_dist = all(first(dist) .== dist) ?
filldist(first(dist), d) : arraydist(dist)
return product_dist
end
1 change: 1 addition & 0 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using Turing, Distributions, DocStringExtensions, SparseArrays

export DelayObservations, default_delay_obs_priors

include("docstrings.jl")
include("delayobservations.jl")
include("utils.jl")

Expand Down
24 changes: 24 additions & 0 deletions EpiAware/src/EpiObsModels/docstrings.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@template (FUNCTIONS, METHODS, MACROS) = """
$(TYPEDSIGNATURES)
$(DOCSTRING)
"""

@template (TYPES) = """
$(TYPEDEF)
$(DOCSTRING)
---
## Fields
$(TYPEDFIELDS)
"""

@template MODULES = """
$(DOCSTRING)
---
## Exports
$(EXPORTS)
---
## Imports
$(IMPORTS)
"""
Loading

0 comments on commit 13398e0

Please sign in to comment.