Skip to content

Commit

Permalink
Merge branch 'main' into optimise-mvnormal-scan
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Jul 17, 2024
2 parents c5532ad + 2ab2cf7 commit 26756df
Show file tree
Hide file tree
Showing 33 changed files with 331 additions and 131 deletions.
4 changes: 2 additions & 2 deletions EpiAware/docs/src/showcase/replications/mishra-2020/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ plt_ar_sample = let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = rand(ar_mdl) #Sample unconditionally the underlying parameters of the model
gen = generated_quantities(ar_mdl, θ)[1]
gen = generated_quantities(ar_mdl, θ)
end

plot(ar_mdl_samples,
Expand All @@ -157,7 +157,7 @@ let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = rand(cond_ar_mdl) #Sample unconditionally the underlying parameters of the model
gen = generated_quantities(cond_ar_mdl, θ)[1]
gen = generated_quantities(cond_ar_mdl, θ)
end

plot(ar_mdl_samples,
Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiAwareUtils/turing-methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ A `DynamicPPPL.Model` object.
y_t, time_steps, epi_model::AbstractTuringEpiModel;
latent_model::AbstractTuringLatentModel, observation_model::AbstractTuringObservationModel)
# Latent process
@submodel prefix="latent" Z_t, latent_model_aux=generate_latent(
@submodel prefix="latent" Z_t=generate_latent(
latent_model, time_steps)

# Transform into infections
@submodel I_t = generate_latent_infs(epi_model, Z_t)

# Predictive distribution of ascertained cases
@submodel prefix="obs" generated_y_t, generated_y_t_aux=generate_observations(
@submodel prefix="obs" generated_y_t=generate_observations(
observation_model, y_t, I_t)

# Generate quantities
return (;
generated_y_t, I_t, Z_t, process_aux = merge(latent_model_aux, generated_y_t_aux))
generated_y_t, I_t, Z_t)
end

"""
Expand Down
22 changes: 10 additions & 12 deletions EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,25 @@ Generate latent variables using a combination of multiple latent models.
- `n`: The number of latent variables to generate.
# Returns
- `combined_latents`: The combined latent variables generated from all the models.
- `latent_aux`: A tuple containing the auxiliary latent variables generated from each individual model.
- The combined latent variables generated from all the models.
# Example
"
@model function EpiAwareBase.generate_latent(latent_models::CombineLatentModels, n)
@submodel final_latent, latent_aux = _accumulate_latents(
latent_models.models, 1, fill(0.0, n), [], n, length(latent_models.models))
@submodel final_latent = _accumulate_latents(
latent_models.models, 1, fill(0.0, n), n, length(latent_models.models))

return final_latent, (; latent_aux...)
return final_latent
end

@model function _accumulate_latents(
models, index, acc_latent, acc_aux, n, n_models)
models, index, acc_latent, n, n_models)
if index > n_models
return acc_latent, (; acc_aux...)
return acc_latent
else
@submodel latent, new_aux = generate_latent(models[index], n)
@submodel updated_latent, updated_aux = _accumulate_latents(
models, index + 1, acc_latent .+ latent,
(; acc_aux..., new_aux...), n, n_models)
return updated_latent, (; updated_aux...)
@submodel latent = generate_latent(models[index], n)
@submodel updated_latent = _accumulate_latents(
models, index + 1, acc_latent .+ latent, n, n_models)
return updated_latent
end
end
23 changes: 12 additions & 11 deletions EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,23 @@ Generate latent variables by concatenating multiple latent models.
@assert all(x -> x > 0, dims) "Non-positive dimensions are not allowed"
@assert sum(dims)==n "Sum of dimensions must be equal to the dimension of the latent variables"

@submodel final_latent, latent_aux = _concat_latents(
latent_models.models, 1, Real[], [], dims, latent_models.no_models)
@submodel final_latent = _concat_latents(
latent_models.models, 1, nothing, dims, latent_models.no_models)

return final_latent, (; latent_aux...)
return final_latent
end

@model function _concat_latents(
models, index::Int, acc_latent::AbstractVector{<:Real}, acc_aux,
dims::AbstractVector{<:Int}, n_models::Int)
models, index::Int, acc_latent, dims::AbstractVector{<:Int}, n_models::Int)
if index > n_models
return acc_latent, (; acc_aux...)
return acc_latent
else
@submodel latent, new_aux = generate_latent(models[index], dims[index])
@submodel updated_latent, updated_aux = _concat_latents(
models, index + 1, vcat(acc_latent, latent),
(; acc_aux..., new_aux...), dims, n_models)
return updated_latent, (; updated_aux...)
@submodel latent = generate_latent(models[index], dims[index])

acc_latent = isnothing(acc_latent) ? latent : vcat(acc_latent, latent)
@submodel updated_latent = _concat_latents(
models, index + 1, acc_latent, dims, n_models
)
return updated_latent
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ Generates latent periods using the specified `model` and `n` number of samples.
"
@model function EpiAwareBase.generate_latent(model::BroadcastLatentModel, n)
m = broadcast_n(model.broadcast_rule, n, model.period)
@submodel latent_period, latent_period_aux = generate_latent(model.model, m)
@submodel latent_period = generate_latent(model.model, m)
broadcasted_latent = broadcast_rule(
model.broadcast_rule, latent_period, n, model.period)
return broadcasted_latent, (; latent_period_aux...)
return broadcasted_latent
end
3 changes: 1 addition & 2 deletions EpiAware/src/EpiLatentModels/models/AR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ Generate a latent AR series.
# Returns
- `ar::Vector{Float64}`: The generated AR series.
- `params::NamedTuple`: A named tuple containing the generated parameters (`σ_AR`, `ar_init`, `damp_AR`).
# Notes
- The length of `damp_prior` and `init_prior` must be the same.
Expand All @@ -84,7 +83,7 @@ Generate a latent AR series.

ar = accumulate_scan(ARStep(damp_AR), ar_init, σ_AR * ϵ_t)

return ar, (; σ_AR, ar_init, damp_AR)
return ar
end

@doc raw"
Expand Down
3 changes: 1 addition & 2 deletions EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ Generate latent variables from the hierarchical normal distribution.
# Returns
- `η_t`: Generated latent variables.
- `std`: Standard deviation used in the generation.
"
@model function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n)
std ~ obs_model.std_prior
ϵ_t ~ filldist(Normal(), n)

η_t = obs_model.mean .+ std .* ϵ_t
return η_t, (; std = std)
return η_t
end
6 changes: 2 additions & 4 deletions EpiAware/src/EpiLatentModels/models/Intercept.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ Generate a latent intercept series.
# Returns
- `intercept::Vector{Float64}`: The generated intercept series.
- `metadata::NamedTuple`: A named tuple containing the intercept value.
"
@model function EpiAwareBase.generate_latent(latent_model::Intercept, n)
intercept ~ latent_model.intercept_prior
return fill(intercept, n), (; intercept = intercept)
return fill(intercept, n)
end

@doc raw"
Expand Down Expand Up @@ -71,8 +70,7 @@ Generate a latent intercept series with a fixed intercept value.
# Returns
- `latent_vars`: An array of length `n` filled with the fixed intercept value.
- `metadata`: A named tuple containing the intercept value.
"
@model function EpiAwareBase.generate_latent(latent_model::FixedIntercept, n)
return fill(latent_model.intercept, n), (; intercept = latent_model.intercept)
return fill(latent_model.intercept, n)
end
4 changes: 2 additions & 2 deletions EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ Z_t
@assert n>d "n must be longer than d"
latent_init ~ latent_model.init_prior

@submodel diff_latent, diff_latent_aux = generate_latent(latent_model.model, n - d)
@submodel diff_latent = generate_latent(latent_model.model, n - d)

return _combine_diff(latent_init, diff_latent, d), (; latent_init, diff_latent_aux...)
return _combine_diff(latent_init, diff_latent, d)
end

function _combine_diff(init, diff, d)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ Generate latent variables using the specified `TransformLatentModel`.
- `n`: The number of latent variables to generate.
# Returns
- `transformed`: The transformed latent variables.
- `latent_aux`: Additional auxiliary variables generated by the underlying latent model.
- The transformed latent variables.
"""
@model function EpiAwareBase.generate_latent(model::TransformLatentModel, n)
@submodel untransformed, latent_aux = generate_latent(model.model, n)
@submodel untransformed = generate_latent(model.model, n)
latent = model.trans_function(untransformed)
return latent, (; latent_aux)
return latent
end
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ It dispatches to the `observation_error` function to generate the observation er
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t, priors
return y_t
end

@doc raw"
Expand Down
5 changes: 2 additions & 3 deletions EpiAware/src/EpiObsModels/StackObservationModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ Generate observations from a stack of observation models. Assumes a 1 to 1 mappi
@assert obs_model.model_names==keys(y_t) .|> string |> collect "The model names must match the keys of the observation datasets."
@assert keys(y_t)==keys(Y_t) "The keys of the observed and true values must match."

obs = ()
for (model, model_name) in zip(obs_model.models, obs_model.model_names)
obs = map(zip(obs_model.models, obs_model.model_names)) do (model, model_name)
@submodel obs_tmp = generate_observations(
model, y_t[Symbol(model_name)], Y_t[Symbol(model_name)])
obs = obs..., obs_tmp...
return obs_tmp
end
return obs
end
Expand Down
5 changes: 2 additions & 3 deletions EpiAware/src/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ Generates observations based on the `LatentDelay` observation model.
## Returns
- `y_t`: The updated observations.
- `obs_aux`: Additional observation-related variables.
"
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
Expand All @@ -69,8 +68,8 @@ Generates observations based on the `LatentDelay` observation model.
expected_obs = kernel * trunc_Y_t
complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 2), expected_obs)

@submodel y_t, obs_aux = generate_observations(
@submodel y_t = generate_observations(
obs_model.model, y_t, complete_obs)

return y_t, (; obs_aux...)
return y_t
end
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ Generates observations based on the `LatentDelay` observation model.
- `obs_aux`: Additional observation-related variables.
"
@model function EpiAwareBase.generate_observations(obs_model::Ascertainment, y_t, Y_t)
@submodel expected_obs_mod, expected_aux = generate_latent(
@submodel expected_obs_mod = generate_latent(
obs_model.latent_model, length(Y_t))

expected_obs = Y_t .* obs_model.link(expected_obs_mod)

@submodel y_t, obs_aux = generate_observations(obs_model.model, y_t, expected_obs)
return y_t, (; expected_obs, expected_obs_mod, expected_aux..., obs_aux...)
@submodel y_t = generate_observations(obs_model.model, y_t, expected_obs)
return y_t
end
18 changes: 8 additions & 10 deletions EpiAware/test/EpiLatentModels/manipulators/CombineLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end

@model function EpiAware.EpiAwareBase.generate_latent(model::NextScale, n::Int)
scale = 2
return scale_vect = fill(scale, n), (; nscale = scale)
return fill(scale, n)
end

s = FixedIntercept(1)
Expand All @@ -44,15 +44,13 @@ end
comb_model_out = comb_model()

@test typeof(comb_model) <: DynamicPPL.Model
@test length(comb_model_out[1]) == 5
@test all(comb_model_out[1] .== fill(3.0, 5))
@test comb_model_out[2].intercept == 1.0
@test comb_model_out[2].nscale == 2.0
@test length(comb_model_out) == 5
@test all(comb_model_out .== fill(3.0, 5))
end

@testitem "CombineLatentModels generate_latent method works as expected: Intercept + AR" begin
using Turing
using Distributions: Normal
using Distributions
using HypothesisTests: ExactOneSampleKSTest, pvalue
using LinearAlgebra: Diagonal

Expand All @@ -65,21 +63,21 @@ end
# Test constant if conditioning on zero residuals
no_residual_mdl = comb_model |
(var"Combine.2.ϵ_t" = zeros(n - 1), var"Combine.2.ar_init" = [0.0])
y_const, θ_const = no_residual_mdl()
y_const = no_residual_mdl()

@test all(y_const .== fill(θ_const.intercept, n))
@test all(y_const .== y_const[1])

# Check against linear regression by conditioning on normal residuals
# Generate data
fix_intercept = 0.5
normal_res_mdl = comb_model |
(var"Combine.2.damp_AR" = [0.0], var"Combine.2.σ_AR" = 1.0,
var"Combine.1.intercept" = fix_intercept)
y, θ = normal_res_mdl()
y = normal_res_mdl()

# Fit no-slope linear regression as a model test
@model function no_slope_linear_regression(y)
@submodel y_pred, θ = generate_latent(comb, n)
@submodel y_pred = generate_latent(comb, n)
y ~ MvNormal(y_pred, Diagonal(ones(n)))
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end

@model function EpiAware.EpiAwareBase.generate_latent(model::NextScale, n::Int)
scale = 2
return scale_vect = fill(scale, n), (; nscale = scale)
return scale_vect = fill(scale, n)
end

s = FixedIntercept(1)
Expand All @@ -62,8 +62,6 @@ end
con_model_out = con_model()

@test typeof(con_model) <: DynamicPPL.Model
@test length(con_model_out[1]) == 5
@test all(con_model_out[1] .== vcat(fill(1.0, 3), fill(2.0, 2)))
@test con_model_out[2].intercept == 1.0
@test con_model_out[2].nscale == 2.0
@test length(con_model_out) == 5
@test all(con_model_out .== vcat(fill(1.0, 3), fill(2.0, 2)))
end
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,14 @@ end
@testitem "generate_latent function with BroadcastLatentModel" begin
using Turing, DynamicPPL
model = BroadcastLatentModel(RandomWalk(), 5, RepeatBlock())
broadcasted_model = generate_latent(model, 10)
broadcasted_model = generate_latent(model, 15)
rand_model = rand(broadcasted_model)

@test length(rand_model.ϵ_t) == 1
fix_model = fix(broadcasted_model, (σ_RW = 1.0, rw_init = 1.0))
sample_model = sample(fix_model, Prior(), 100; progress = false)
gen_model = sample_model |>
chn -> mapreduce(hcat, generated_quantities(fix_model, chn)) do gen
gen[1]
end

@testset "Testing gen_model matrix" begin
for col in eachcol(gen_model)
unique_values = unique(col)
@test length(unique_values) == 2

@test count(x -> x == unique_values[1], col) == 5
@test count(x -> x == unique_values[2], col) == 5
end
end
@test length(rand_model.ϵ_t) == 2
fix_model = fix(
broadcasted_model,
(σ_RW = 2.0, rw_init = 1.0, ϵ_t = [1, 2])
)
out = fix_model()
@test out == vcat(fill(1.0, 5), fill(3.0, 5), fill(7.0, 5))
end
2 changes: 1 addition & 1 deletion EpiAware/test/EpiLatentModels/models/AR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
n_samples = 100
samples = sample(fixed_model, Prior(), n_samples; progress = false) |>
chn -> mapreduce(vcat, generated_quantities(fixed_model, chn)) do gen
gen[1]
gen
end

theoretical_mean = 0.0
Expand Down
6 changes: 4 additions & 2 deletions EpiAware/test/EpiLatentModels/models/FixedIntercept.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ end
int = FixedIntercept(0.1)
int_model = generate_latent(int, 10)
int_model_out = int_model()
@test length(int_model_out[1]) == 10
@test all(x -> x == int_model_out[2].intercept, int_model_out[1])
rand_model = rand(int_model)
@test rand_model == NamedTuple()
@test length(int_model_out) == 10
@test all(x -> x == 0.1, int_model_out)
end
Loading

0 comments on commit 26756df

Please sign in to comment.