Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 51: First pass at broadcasting #171

Merged
merged 10 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion EpiAware/src/EpiAwareBase/EpiAwareBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@ export AbstractEpiProblem
#Export inference methods
export AbstractEpiMethod, AbstractEpiOptMethod, AbstractEpiSamplingMethod

#Export functions
# Export support types
export AbstractBroadcastRule

#Export generating functions
export generate_latent, generate_latent_infs, generate_observations

#Export support functions
export broadcast_rule, broadcast_n

include("docstrings.jl")
include("types.jl")
include("functions.jl")
include("generate_models.jl")

end
55 changes: 12 additions & 43 deletions EpiAware/src/EpiAwareBase/functions.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,19 @@
@doc raw"""
Constructor function for unobserved/latent infections based on the type of
`epi_model <: AbstractEpimodel` and a latent process path ``Z_t``.
@doc raw"
This function is used to define the behavior of broadcasting for a specific type of `AbstractBroadcastRule`.

The `generate_latent_infs` function implements a model of generating unobserved/latent
infections conditional on a latent process. Which model of generating unobserved/latent
infections to be implemented is set by the type of `epi_model`. If no implemention is
defined for the given `epi_model`, then `EpiAware` will return a warning and return
`nothing`.

## Interface to `Turing.jl` probablilistic programming language (PPL)

Apart from the no implementation fallback method, the `generate_latent_infs` implementation
function returns a constructor function for a
[`DynamicPPL.Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model)
object where the unobserved/latent infections are a generated quantity. Priors for model
parameters are fields of `epi_model`.
"""
function generate_latent_infs(epi_model::AbstractEpiModel, Z_t)
@warn "No concrete implementation for `generate_latent_infs` is defined."
The `broadcast_rule` function implements a model of broadcasting a latent process. Which model of broadcasting to be implemented is set by the type of `broadcast_rule`. If no implemention is defined for the given `broadcast_rule`, then `EpiAware` will return a warning and return `nothing`.
"
function broadcast_rule(broadcast_rule::AbstractBroadcastRule, n, period)
@info "No concrete implementation for broadcast_rule is defined."
return nothing
end

@doc raw"""
Constructor function for a latent process path ``Z_t`` of length `n`.

The `generate_latent` function implements a model of generating a latent process. Which
model for generating the latent process infections is implemented is set by the type of
`latent_model`. If no implemention is defined for the type of `latent_model`, then
`EpiAware` will pass a warning and return `nothing`.

## Interface to `Turing.jl` probablilistic programming language (PPL)

Apart from the no implementation fallback method, the `generate_latent` implementation
function should return a constructor function for a
[`DynamicPPL.Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model)
object. Sample paths of ``Z_t`` are generated quantities of the constructed model. Priors
for model parameters are fields of `epi_model`.
"""
function generate_latent(latent_model::AbstractLatentModel, n)
@info "No concrete implementation for generate_latent is defined."
return nothing
end
@doc raw"
This function is used to define the behavior of broadcasting for a specific type of `AbstractBroadcastRule`.

function generate_observations(obs_model::AbstractObservationModel,
y_t,
Y_t)
@info "No concrete implementation for generate_observations is defined."
The `broadcast_n` function returns the length of the latent periods to generate using the given `broadcast_rule`. Which model of broadcasting to be implemented is set by the type of `broadcast_rule`. If no implemention is defined for the given `broadcast_rule`, then `EpiAware` will return a warning and return `nothing`.
"
function broadcast_n(broadcast_rule::AbstractBroadcastRule, latent, n, period)
@info "No concrete implementation for broadcast_n is defined."
return nothing
end
55 changes: 55 additions & 0 deletions EpiAware/src/EpiAwareBase/generate_models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
@doc raw"
Constructor function for unobserved/latent infections based on the type of
`epi_model <: AbstractEpimodel` and a latent process path ``Z_t``.

The `generate_latent_infs` function implements a model of generating unobserved/latent
infections conditional on a latent process. Which model of generating unobserved/latent
infections to be implemented is set by the type of `epi_model`. If no implemention is
defined for the given `epi_model`, then `EpiAware` will return a warning and return
`nothing`.

## Interface to `Turing.jl` probablilistic programming language (PPL)

Apart from the no implementation fallback method, the `generate_latent_infs` implementation
function returns a constructor function for a
[`DynamicPPL.Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model)
object where the unobserved/latent infections are a generated quantity. Priors for model
parameters are fields of `epi_model`.
"
function generate_latent_infs(epi_model::AbstractEpiModel, Z_t)
@warn "No concrete implementation for `generate_latent_infs` is defined."
return nothing
end

@doc raw"
Constructor function for a latent process path ``Z_t`` of length `n`.

The `generate_latent` function implements a model of generating a latent process. Which
model for generating the latent process infections is implemented is set by the type of
`latent_model`. If no implemention is defined for the type of `latent_model`, then
`EpiAware` will pass a warning and return `nothing`.

## Interface to `Turing.jl` probablilistic programming language (PPL)

Apart from the no implementation fallback method, the `generate_latent` implementation
function should return a constructor function for a
[`DynamicPPL.Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model)
object. Sample paths of ``Z_t`` are generated quantities of the constructed model. Priors
for model parameters are fields of `epi_model`.
"
function generate_latent(latent_model::AbstractLatentModel, n)
@info "No concrete implementation for generate_latent is defined."
return nothing
end

@doc raw"
Constructor function for generating observations based on the given observation model.

The `generate_observations` function implements a model of generating observations based on the given observation model. Which model of generating observations to be implemented is set by the type of `obs_model`. If no implemention is defined for the given `obs_model`, then `EpiAware` will return a warning and return `nothing`.
"
function generate_observations(obs_model::AbstractObservationModel,
y_t,
Y_t)
@info "No concrete implementation for generate_observations is defined."
return nothing
end
8 changes: 8 additions & 0 deletions EpiAware/src/EpiAwareBase/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ used in `EpiAware` models.
"""
abstract type AbstractLatentModel <: AbstractModel end

"""
An abstract type representing a broadcast rule.
"""
abstract type AbstractBroadcastRule end

"""
A type representing an abstract observation model that is a subtype of `AbstractModel`.
"""
abstract type AbstractObservationModel <: AbstractModel end

"""
Expand Down
11 changes: 10 additions & 1 deletion EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@ using ..EpiAwareUtils: HalfNormal
using Turing, Distributions, DocStringExtensions

#Export models
export RandomWalk, AR, DiffLatentModel
export RandomWalk, AR, DiffLatentModel, BroadcastLatentModel

# Export broadcast rules
export RepeatEach, RepeatBlock

# Export helper functions
export dayofweek, weekly

include("docstrings.jl")
include("randomwalk.jl")
include("autoregressive.jl")
include("difflatentmodel.jl")
include("broadcast/latentmodel.jl")
include("broadcast/rules.jl")
include("broadcast/helpers.jl")
include("utils.jl")

end
25 changes: 25 additions & 0 deletions EpiAware/src/EpiLatentModels/broadcast/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@doc raw"
Constructs a `BroadcastLatentModel` appropriate for modelling the day of the week for a given `AbstractLatentModel`.

# Arguments
- `model::AbstractLatentModel`: The latent model to be repeated.

# Returns
- `BroadcastLatentModel`: The broadcast latent model.
"
function dayofweek(model::AbstractLatentModel)
return BroadcastLatentModel(model, 7, RepeatEach())
end

@doc raw"
Constructs a `BroadcastLatentModel` appropriate for modelling piecewise constant weekly processes for a given `AbstractLatentModel`.

# Arguments
- `model::AbstractLatentModel`: The latent model to be repeated.

# Returns
- `BroadcastLatentModel`: The broadcast latent model.
"
function weekly(model::AbstractLatentModel)
return BroadcastLatentModel(model, 7, RepeatBlock())
end
61 changes: 61 additions & 0 deletions EpiAware/src/EpiLatentModels/broadcast/latentmodel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
@doc raw"
The `BroadcastLatentModel` struct represents a latent model that supports broadcasting of latent periods.

## Constructors
- `BroadcastLatentModel(;model::M; period::Int, broadcast_rule::B)`: Constructs a `BroadcastLatentModel` with the given `model`, `period`, and `broadcast_rule`.
- `BroadcastLatentModel(model::M, period::Int, broadcast_rule::B)`: An alternative constructor that allows the `model`, `period`, and `broadcast_rule` to be specified without keyword arguments.

## Examples
```julia
using EpiAware, Turing
each_model = BroadcastLatentModel(RandomWalk(), 7, RepeatEach())
gen_each_model = generate_latent(each_model, 10)
rand(gen_each_model)

block_model = BroadcastLatentModel(RandomWalk(), 3, RepeatBlock())
gen_block_model = generate_latent(block_model, 10)
rand(gen_block_model)
```
"
struct BroadcastLatentModel{
M <: AbstractLatentModel, P <: Integer, B <: AbstractBroadcastRule} <:
AbstractLatentModel
"The underlying latent model."
model::M
"The period of the broadcast."
period::P
"The broadcast rule to be applied."
broadcast_rule::B

function BroadcastLatentModel(model::M; period::Integer,
broadcast_rule::B) where {M <: AbstractLatentModel, B <: AbstractBroadcastRule}
BroadcastLatentModel(model, period, broadcast_rule)
end

function BroadcastLatentModel(model::M, period::Integer,
broadcast_rule::B) where {M <: AbstractLatentModel, B <: AbstractBroadcastRule}
@assert period>0 "period must be greater than 0"
new{typeof(model), typeof(period), typeof(broadcast_rule)}(
model, period, broadcast_rule)
end
end

@doc raw"
Generates latent periods using the specified `model` and `n` number of samples.

## Arguments
- `model::BroadcastLatentModel`: The broadcast latent model.
- `n::Any`: The number of samples to generate.

## Returns
- `broadcasted_latent`: The generated broadcasted latent periods.
- `latent_period_aux...`: Additional auxiliary information about the latent periods.

"
@model function EpiAwareBase.generate_latent(model::BroadcastLatentModel, n)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
m = broadcast_n(model.broadcast_rule, n, model.period)
@submodel latent_period, latent_period_aux = generate_latent(model.model, m)
broadcasted_latent = broadcast_rule(
model.broadcast_rule, latent_period, n, model.period)
return broadcasted_latent, (; latent_period_aux...)
end
98 changes: 98 additions & 0 deletions EpiAware/src/EpiLatentModels/broadcast/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
@doc raw"
`RepeatEach` is a struct that represents a broadcasting rule. It is a subtype of `AbstractBroadcastRule`.

It repeats the latent process at each period. An example of this rule is to repeat the latent process at each day of the week (though for this we also provide the `dayofweek` helper function).

## Examples
```julia
using EpiAware
rule = RepeatEach()
latent = [1, 2, 3]
n = 10
period = 2
broadcast_rule(rule, latent, n, period)
```
"
struct RepeatEach <: AbstractBroadcastRule end

@doc raw"
A function that returns the length of the latent periods to generate using the `RepeatEach` rule which is equal to the period.

## Arguments
- `rule::RepeatEach`: The broadcasting rule.
- `n`: The number of samples to generate.
- `period`: The period of the broadcast.

## Returns
- `m`: The length of the latent periods to generate.
"
function EpiAwareBase.broadcast_n(::RepeatEach, n, period)
m = period
return m
end

@doc raw"
`broadcast_rule` is a function that applies the `RepeatEach` rule to the latent process `latent` to generate `n` samples.

## Arguments
- `rule::RepeatEach`: The broadcasting rule.
- `latent::Vector`: The latent process.
- `n`: The number of samples to generate.
- `period`: The period of the broadcast.

## Returns
- `latent`: The generated broadcasted latent periods.
"
function EpiAwareBase.broadcast_rule(::RepeatEach, latent, n, period)
@assert length(latent)==period "length(latent) must be equal to period"
broadcast_latent = repeat(latent, outer = ceil(Int, n / period))
return broadcast_latent[1:n]
end

@doc raw"
`RepeatBlock` is a struct that represents a broadcasting rule. It is a subtype of `AbstractBroadcastRule`.

It repeats the latent process in blocks of size `period`. An example of this rule is to repeat the latent process in blocks of size 7 to model a weekly process (though for this we also provide the `weekly` helper function).

## Examples
```julia
using EpiAware
rule = RepeatBlock()
latent = [1, 2, 3, 4, 5]
n = 10
period = 2
broadcast_rule(rule, latent, n, period)
```
"
struct RepeatBlock <: AbstractBroadcastRule end

@doc raw"
A function that returns the length of the latent periods to generate using the `RepeatBlock` rule which is equal n divided by the period and rounded up to the nearest integer.

## Arguments
- `rule::RepeatBlock`: The broadcasting rule.
- `n`: The number of samples to generate.
- `period`: The period of the broadcast.
"
function EpiAwareBase.broadcast_n(::RepeatBlock, n, period)
m = ceil(Int, n / period)
return m
end

@doc raw"
`broadcast_rule` is a function that applies the `RepeatBlock` rule to the latent process `latent` to generate `n` samples.

## Arguments
- `rule::RepeatBlock`: The broadcasting rule.
- `latent::Vector`: The latent process.
- `n`: The number of samples to generate.
- `period`: The period of the broadcast.

## Returns
- `latent`: The generated broadcasted latent periods.
"
function EpiAwareBase.broadcast_rule(::RepeatBlock, latent, n, period)
@assert n<=period * length(latent) "n must be less than or equal to period * length(latent)"
broadcast_latent = [latent[j] for j in 1:length(latent) for i in 1:period]
return broadcast_latent[1:n]
end
20 changes: 6 additions & 14 deletions EpiAware/test/EpiAwareBase/functions.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
@testitem "generate_latent_infs function: default" begin
latent_model = [0.1, 0.2, 0.3]
init_incidence = 10.0

struct TestEpiModel <: EpiAware.EpiAwareBase.AbstractEpiModel
@testitem "Testing broadcast_n default" begin
struct TestBroadcastModel <: EpiAware.EpiAwareBase.AbstractBroadcastRule
end

@test isnothing(generate_latent_infs(TestEpiModel(), latent_model))
@test isnothing(broadcast_n(TestBroadcastModel(), missing, missing, missing))
end

@testitem "Testing generate_observations default" begin
struct TestObsModel <: EpiAware.EpiAwareBase.AbstractObservationModel
@testitem "Testing broadcast_rule default" begin
struct TestBroadcastModel <: EpiAware.EpiAwareBase.AbstractBroadcastRule
end

@test try
generate_observations(TestObsModel(), missing, missing)
true
catch
false
end
@test isnothing(broadcast_rule(TestBroadcastModel(), missing, missing))
end
Loading
Loading