-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
33 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
# Outer constructor for DayOfWeek | ||
function DayOfWeek(model::BroadcastLatentModel) | ||
function DayOfWeek(model::AbstractLatentModel) | ||
return BroadcastLatentModel(model, 7, RepeatEach()) | ||
end | ||
|
||
# Outer constructor for Weekly | ||
function Weekly(model) | ||
function Weekly(model::AbstractLatentModel) | ||
return BroadcastLatentModel(model, 7, RepeatBlock()) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,44 @@ | ||
struct RepeatEach end | ||
struct RepeatBlock end | ||
abstract type BroadcastRule end | ||
struct RepeatEach <: BroadcastRule end | ||
struct RepeatBlock <: BroadcastRule end | ||
|
||
struct BroadcastLatentModel{M, B, P} <: AbstractLatentModel | ||
struct BroadcastLatentModel{M <: AbstractLatentModel, P <: Int, B <: BroadcastRule} <: | ||
AbstractLatentModel | ||
model::M | ||
period::Int | ||
period::P | ||
broadcast_rule::B | ||
function BroadcastLatentModel(model::M, period::Int, broadcast_rule::B) where {M, B} | ||
@assert period > 0 "period must be greater than 0" | ||
new{M, B, typeof(period)}(model, period, broadcast_rule) | ||
function BroadcastLatentModel(model::M; period::Int, | ||
broadcast_rule::B) where {M <: AbstractLatentModel, B <: BroadcastRule} | ||
@assert period>0 "period must be greater than 0" | ||
new{typeof(model), typeof(period), typeof(broadcast_rule)}( | ||
model, period, broadcast_rule) | ||
end | ||
|
||
function BroadcastLatentModel(model::M, period::Int, | ||
broadcast_rule::B) where {M <: AbstractLatentModel, B <: BroadcastRule} | ||
@assert period>0 "period must be greater than 0" | ||
new{typeof(model), typeof(period), typeof(broadcast_rule)}( | ||
model, period, broadcast_rule) | ||
end | ||
end | ||
|
||
# Outer constructor for DayOfWeek | ||
function DayOfWeek(model::BroadcastLatentModel) | ||
return BroadcastLatentModel(model, 7, RepeatEach()) | ||
function broadcast_rule(::BroadcastRule) | ||
error("broadcast_rule not implemented") | ||
end | ||
|
||
# Outer constructor for Weekly | ||
function Weekly(model) | ||
return BroadcastLatentModel(model, 7, RepeatBlock()) | ||
function broadcast_rule(::RepeatEach, latent, n, period) | ||
latent = repeat(latent, outer = ceil(Int, n / period)) | ||
return latent[1:n] | ||
end | ||
|
||
function generate_latent(model::BroadcastLatentModel{<:Any, RepeatEach, <:Any}, n::Int) | ||
latent_period = generate_latent(model.model, model.period) | ||
broadcasted_latent = repeat(latent_period, outer = ceil(Int, n / model.period)) | ||
return broadcasted_latent[1:n] | ||
function broadcast_rule(::RepeatBlock, latent, n, period) | ||
indices = [ceil(Int, i / period) for i in 1:n] | ||
return latent[indices] | ||
end | ||
|
||
function generate_latent(model::BroadcastLatentModel<:Any, RepeatBlock, <:Any}, n::Int) | ||
latent_period = generate_latent(model.model, model.period) | ||
indices = [ceil(Int, i / model.period) for i in 1:n] | ||
broadcasted_latent = latent_period[indices] | ||
return broadcasted_latent | ||
@model function EpiAwareBase.generate_latent(model::BroadcastLatentModel, n) | ||
@submodel latent_period, latent_period_aux = generate_latent(model.model, model.period) | ||
broadcasted_latent = broadcast_rule( | ||
model.broadcast_rule, latent_period, n, model.period) | ||
return broadcasted_latent, (; latent_period_aux...) | ||
end |