Skip to content

Commit

Permalink
refine methods
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Mar 25, 2024
1 parent 1887336 commit f71e0f8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
4 changes: 2 additions & 2 deletions EpiAware/src/EpiLatentModels/broadcasthelpers.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Outer constructor for DayOfWeek
function DayOfWeek(model::BroadcastLatentModel)
function DayOfWeek(model::AbstractLatentModel)
return BroadcastLatentModel(model, 7, RepeatEach())
end

# Outer constructor for Weekly
function Weekly(model)
function Weekly(model::AbstractLatentModel)
return BroadcastLatentModel(model, 7, RepeatBlock())
end
53 changes: 31 additions & 22 deletions EpiAware/src/EpiLatentModels/broadcastlatentmodel.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,44 @@
struct RepeatEach end
struct RepeatBlock end
abstract type BroadcastRule end
struct RepeatEach <: BroadcastRule end
struct RepeatBlock <: BroadcastRule end

struct BroadcastLatentModel{M, B, P} <: AbstractLatentModel
struct BroadcastLatentModel{M <: AbstractLatentModel, P <: Int, B <: BroadcastRule} <:
AbstractLatentModel
model::M
period::Int
period::P
broadcast_rule::B
function BroadcastLatentModel(model::M, period::Int, broadcast_rule::B) where {M, B}
@assert period > 0 "period must be greater than 0"
new{M, B, typeof(period)}(model, period, broadcast_rule)
function BroadcastLatentModel(model::M; period::Int,
broadcast_rule::B) where {M <: AbstractLatentModel, B <: BroadcastRule}
@assert period>0 "period must be greater than 0"
new{typeof(model), typeof(period), typeof(broadcast_rule)}(
model, period, broadcast_rule)
end

function BroadcastLatentModel(model::M, period::Int,
broadcast_rule::B) where {M <: AbstractLatentModel, B <: BroadcastRule}
@assert period>0 "period must be greater than 0"
new{typeof(model), typeof(period), typeof(broadcast_rule)}(
model, period, broadcast_rule)
end
end

# Outer constructor for DayOfWeek
function DayOfWeek(model::BroadcastLatentModel)
return BroadcastLatentModel(model, 7, RepeatEach())
function broadcast_rule(::BroadcastRule)
error("broadcast_rule not implemented")
end

# Outer constructor for Weekly
function Weekly(model)
return BroadcastLatentModel(model, 7, RepeatBlock())
function broadcast_rule(::RepeatEach, latent, n, period)
latent = repeat(latent, outer = ceil(Int, n / period))
return latent[1:n]
end

function generate_latent(model::BroadcastLatentModel{<:Any, RepeatEach, <:Any}, n::Int)
latent_period = generate_latent(model.model, model.period)
broadcasted_latent = repeat(latent_period, outer = ceil(Int, n / model.period))
return broadcasted_latent[1:n]
function broadcast_rule(::RepeatBlock, latent, n, period)
indices = [ceil(Int, i / period) for i in 1:n]
return latent[indices]
end

function generate_latent(model::BroadcastLatentModel<:Any, RepeatBlock, <:Any}, n::Int)
latent_period = generate_latent(model.model, model.period)
indices = [ceil(Int, i / model.period) for i in 1:n]
broadcasted_latent = latent_period[indices]
return broadcasted_latent
@model function EpiAwareBase.generate_latent(model::BroadcastLatentModel, n)
@submodel latent_period, latent_period_aux = generate_latent(model.model, model.period)
broadcasted_latent = broadcast_rule(
model.broadcast_rule, latent_period, n, model.period)
return broadcasted_latent, (; latent_period_aux...)
end

0 comments on commit f71e0f8

Please sign in to comment.