Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 246: Standarise outputs v2 #365

Closed
wants to merge 16 commits into from
4 changes: 2 additions & 2 deletions EpiAware/docs/src/showcase/replications/mishra-2020/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ plt_ar_sample = let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = rand(ar_mdl) #Sample unconditionally the underlying parameters of the model
gen = generated_quantities(ar_mdl, θ)[1]
gen = generated_quantities(ar_mdl, θ)
end

plot(ar_mdl_samples,
Expand All @@ -157,7 +157,7 @@ let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = rand(cond_ar_mdl) #Sample unconditionally the underlying parameters of the model
gen = generated_quantities(cond_ar_mdl, θ)[1]
gen = generated_quantities(cond_ar_mdl, θ)
end

plot(ar_mdl_samples,
Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiAwareUtils/turing-methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ A `DynamicPPPL.Model` object.
y_t, time_steps, epi_model::AbstractTuringEpiModel;
latent_model::AbstractTuringLatentModel, observation_model::AbstractTuringObservationModel)
# Latent process
@submodel prefix="latent" Z_t, latent_model_aux=generate_latent(
@submodel prefix="latent" Z_t=generate_latent(
latent_model, time_steps)

# Transform into infections
@submodel I_t = generate_latent_infs(epi_model, Z_t)

# Predictive distribution of ascertained cases
@submodel prefix="obs" generated_y_t, generated_y_t_aux=generate_observations(
@submodel prefix="obs" generated_y_t=generate_observations(
observation_model, y_t, I_t)

# Generate quantities
return (;
generated_y_t, I_t, Z_t, process_aux = merge(latent_model_aux, generated_y_t_aux))
generated_y_t, I_t, Z_t)
end

"""
Expand Down
22 changes: 10 additions & 12 deletions EpiAware/src/EpiLatentModels/manipulators/CombineLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,25 @@ Generate latent variables using a combination of multiple latent models.
- `n`: The number of latent variables to generate.

# Returns
- `combined_latents`: The combined latent variables generated from all the models.
- `latent_aux`: A tuple containing the auxiliary latent variables generated from each individual model.
- The combined latent variables generated from all the models.

# Example
"
@model function EpiAwareBase.generate_latent(latent_models::CombineLatentModels, n)
@submodel final_latent, latent_aux = _accumulate_latents(
latent_models.models, 1, fill(0.0, n), [], n, length(latent_models.models))
@submodel final_latent = _accumulate_latents(
latent_models.models, 1, fill(0.0, n), n, length(latent_models.models))

return final_latent, (; latent_aux...)
return final_latent
end

@model function _accumulate_latents(
models, index, acc_latent, acc_aux, n, n_models)
models, index, acc_latent, n, n_models)
if index > n_models
return acc_latent, (; acc_aux...)
return acc_latent
else
@submodel latent, new_aux = generate_latent(models[index], n)
@submodel updated_latent, updated_aux = _accumulate_latents(
models, index + 1, acc_latent .+ latent,
(; acc_aux..., new_aux...), n, n_models)
return updated_latent, (; updated_aux...)
@submodel latent = generate_latent(models[index], n)
@submodel updated_latent = _accumulate_latents(
models, index + 1, acc_latent .+ latent, n, n_models)
return updated_latent
end
end
23 changes: 12 additions & 11 deletions EpiAware/src/EpiLatentModels/manipulators/ConcatLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,23 @@ Generate latent variables by concatenating multiple latent models.
@assert all(x -> x > 0, dims) "Non-positive dimensions are not allowed"
@assert sum(dims)==n "Sum of dimensions must be equal to the dimension of the latent variables"

@submodel final_latent, latent_aux = _concat_latents(
latent_models.models, 1, Real[], [], dims, latent_models.no_models)
@submodel final_latent = _concat_latents(
latent_models.models, 1, nothing, dims, latent_models.no_models)

return final_latent, (; latent_aux...)
return final_latent
end

@model function _concat_latents(
models, index::Int, acc_latent::AbstractVector{<:Real}, acc_aux,
dims::AbstractVector{<:Int}, n_models::Int)
models, index::Int, acc_latent, dims::AbstractVector{<:Int}, n_models::Int)
if index > n_models
return acc_latent, (; acc_aux...)
return acc_latent
else
@submodel latent, new_aux = generate_latent(models[index], dims[index])
@submodel updated_latent, updated_aux = _concat_latents(
models, index + 1, vcat(acc_latent, latent),
(; acc_aux..., new_aux...), dims, n_models)
return updated_latent, (; updated_aux...)
@submodel latent = generate_latent(models[index], dims[index])

acc_latent = isnothing(acc_latent) ? latent : vcat(acc_latent, latent)
@submodel updated_latent = _concat_latents(
models, index + 1, acc_latent, dims, n_models
)
return updated_latent
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ Generates latent periods using the specified `model` and `n` number of samples.
"
@model function EpiAwareBase.generate_latent(model::BroadcastLatentModel, n)
m = broadcast_n(model.broadcast_rule, n, model.period)
@submodel latent_period, latent_period_aux = generate_latent(model.model, m)
@submodel latent_period = generate_latent(model.model, m)
broadcasted_latent = broadcast_rule(
model.broadcast_rule, latent_period, n, model.period)
return broadcasted_latent, (; latent_period_aux...)
return broadcasted_latent
end
3 changes: 1 addition & 2 deletions EpiAware/src/EpiLatentModels/models/AR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ Generate a latent AR series.

# Returns
- `ar::Vector{Float64}`: The generated AR series.
- `params::NamedTuple`: A named tuple containing the generated parameters (`σ_AR`, `ar_init`, `damp_AR`).

# Notes
- The length of `damp_prior` and `init_prior` must be the same.
Expand All @@ -91,5 +90,5 @@ Generate a latent AR series.
ar[t] = damp_AR' * ar[(t - p):(t - 1)] + σ_AR * ϵ_t[t - p]
end

return ar, (; σ_AR, ar_init, damp_AR)
return ar
end
3 changes: 1 addition & 2 deletions EpiAware/src/EpiLatentModels/models/HierarchicalNormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ Generate latent variables from the hierarchical normal distribution.

# Returns
- `η_t`: Generated latent variables.
- `std`: Standard deviation used in the generation.
"
@model function EpiAwareBase.generate_latent(obs_model::HierarchicalNormal, n)
std ~ obs_model.std_prior
ϵ_t ~ MvNormal(Diagonal(Fill(one(eltype(std)), n)))

η_t = obs_model.mean .+ std .* ϵ_t
return η_t, (; std = std)
return η_t
end
6 changes: 2 additions & 4 deletions EpiAware/src/EpiLatentModels/models/Intercept.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ Generate a latent intercept series.
# Returns

- `intercept::Vector{Float64}`: The generated intercept series.
- `metadata::NamedTuple`: A named tuple containing the intercept value.
"
@model function EpiAwareBase.generate_latent(latent_model::Intercept, n)
intercept ~ latent_model.intercept_prior
return fill(intercept, n), (; intercept = intercept)
return fill(intercept, n)
end

@doc raw"
Expand Down Expand Up @@ -71,8 +70,7 @@ Generate a latent intercept series with a fixed intercept value.

# Returns
- `latent_vars`: An array of length `n` filled with the fixed intercept value.
- `metadata`: A named tuple containing the intercept value.
"
@model function EpiAwareBase.generate_latent(latent_model::FixedIntercept, n)
return fill(latent_model.intercept, n), (; intercept = latent_model.intercept)
return fill(latent_model.intercept, n)
end
2 changes: 1 addition & 1 deletion EpiAware/src/EpiLatentModels/models/RandomWalk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,5 @@ Z_t, _ = generated_quantities(rw_model, θ)
for t in 2:n
rw[t] = rw[t - 1] + σ_RW * ϵ_t[t]
end
return rw, (; σ_RW, rw_init)
return rw
end
4 changes: 2 additions & 2 deletions EpiAware/src/EpiLatentModels/modifiers/DiffLatentModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ Z_t
@assert n>d "n must be longer than d"
latent_init ~ latent_model.init_prior

@submodel diff_latent, diff_latent_aux = generate_latent(latent_model.model, n - d)
@submodel diff_latent = generate_latent(latent_model.model, n - d)

return _combine_diff(latent_init, diff_latent, d), (; latent_init, diff_latent_aux...)
return _combine_diff(latent_init, diff_latent, d)
end

function _combine_diff(init, diff, d)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ Generate latent variables using the specified `TransformLatentModel`.
- `n`: The number of latent variables to generate.

# Returns
- `transformed`: The transformed latent variables.
- `latent_aux`: Additional auxiliary variables generated by the underlying latent model.
- The transformed latent variables.

"""
@model function EpiAwareBase.generate_latent(model::TransformLatentModel, n)
@submodel untransformed, latent_aux = generate_latent(model.model, n)
@submodel untransformed = generate_latent(model.model, n)
latent = model.trans_function(untransformed)
return latent, (; latent_aux)
return latent
end
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ It dispatches to the `observation_error` function to generate the observation er
y_t[i] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t, priors
return y_t
end

@doc raw"
Expand Down
5 changes: 2 additions & 3 deletions EpiAware/src/EpiObsModels/StackObservationModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ Generate observations from a stack of observation models. Assumes a 1 to 1 mappi
@assert obs_model.model_names==keys(y_t) .|> string |> collect "The model names must match the keys of the observation datasets."
@assert keys(y_t)==keys(Y_t) "The keys of the observed and true values must match."

obs = ()
for (model, model_name) in zip(obs_model.models, obs_model.model_names)
obs = map(zip(obs_model.models, obs_model.model_names)) do (model, model_name)
@submodel obs_tmp = generate_observations(
model, y_t[Symbol(model_name)], Y_t[Symbol(model_name)])
obs = obs..., obs_tmp...
return obs_tmp
end
return obs
end
Expand Down
5 changes: 2 additions & 3 deletions EpiAware/src/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ Generates observations based on the `LatentDelay` observation model.

## Returns
- `y_t`: The updated observations.
- `obs_aux`: Additional observation-related variables.

"
@model function EpiAwareBase.generate_observations(obs_model::LatentDelay, y_t, Y_t)
Expand All @@ -69,8 +68,8 @@ Generates observations based on the `LatentDelay` observation model.
expected_obs = kernel * trunc_Y_t
complete_obs = vcat(fill(missing, length(obs_model.pmf) + first_Y_t - 2), expected_obs)

@submodel y_t, obs_aux = generate_observations(
@submodel y_t = generate_observations(
obs_model.model, y_t, complete_obs)

return y_t, (; obs_aux...)
return y_t
end
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ Generates observations based on the `LatentDelay` observation model.
- `obs_aux`: Additional observation-related variables.
"
@model function EpiAwareBase.generate_observations(obs_model::Ascertainment, y_t, Y_t)
@submodel expected_obs_mod, expected_aux = generate_latent(
@submodel expected_obs_mod = generate_latent(
obs_model.latent_model, length(Y_t))

expected_obs = Y_t .* obs_model.link(expected_obs_mod)

@submodel y_t, obs_aux = generate_observations(obs_model.model, y_t, expected_obs)
return y_t, (; expected_obs, expected_obs_mod, expected_aux..., obs_aux...)
@submodel y_t = generate_observations(obs_model.model, y_t, expected_obs)
return y_t
end
18 changes: 8 additions & 10 deletions EpiAware/test/EpiLatentModels/manipulators/CombineLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end

@model function EpiAware.EpiAwareBase.generate_latent(model::NextScale, n::Int)
scale = 2
return scale_vect = fill(scale, n), (; nscale = scale)
return fill(scale, n)
end

s = FixedIntercept(1)
Expand All @@ -44,15 +44,13 @@ end
comb_model_out = comb_model()

@test typeof(comb_model) <: DynamicPPL.Model
@test length(comb_model_out[1]) == 5
@test all(comb_model_out[1] .== fill(3.0, 5))
@test comb_model_out[2].intercept == 1.0
@test comb_model_out[2].nscale == 2.0
@test length(comb_model_out) == 5
@test all(comb_model_out .== fill(3.0, 5))
end

@testitem "CombineLatentModels generate_latent method works as expected: Intercept + AR" begin
using Turing
using Distributions: Normal
using Distributions
using HypothesisTests: ExactOneSampleKSTest, pvalue
using LinearAlgebra: Diagonal

Expand All @@ -65,21 +63,21 @@ end
# Test constant if conditioning on zero residuals
no_residual_mdl = comb_model |
(var"Combine.2.ϵ_t" = zeros(n - 1), var"Combine.2.ar_init" = [0.0])
y_const, θ_const = no_residual_mdl()
y_const = no_residual_mdl()

@test all(y_const .== fill(θ_const.intercept, n))
@test all(y_const .== y_const[1])

# Check against linear regression by conditioning on normal residuals
# Generate data
fix_intercept = 0.5
normal_res_mdl = comb_model |
(var"Combine.2.damp_AR" = [0.0], var"Combine.2.σ_AR" = 1.0,
var"Combine.1.intercept" = fix_intercept)
y, θ = normal_res_mdl()
y = normal_res_mdl()

# Fit no-slope linear regression as a model test
@model function no_slope_linear_regression(y)
@submodel y_pred, θ = generate_latent(comb, n)
@submodel y_pred = generate_latent(comb, n)
y ~ MvNormal(y_pred, Diagonal(ones(n)))
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end

@model function EpiAware.EpiAwareBase.generate_latent(model::NextScale, n::Int)
scale = 2
return scale_vect = fill(scale, n), (; nscale = scale)
return scale_vect = fill(scale, n)
end

s = FixedIntercept(1)
Expand All @@ -62,8 +62,6 @@ end
con_model_out = con_model()

@test typeof(con_model) <: DynamicPPL.Model
@test length(con_model_out[1]) == 5
@test all(con_model_out[1] .== vcat(fill(1.0, 3), fill(2.0, 2)))
@test con_model_out[2].intercept == 1.0
@test con_model_out[2].nscale == 2.0
@test length(con_model_out) == 5
@test all(con_model_out .== vcat(fill(1.0, 3), fill(2.0, 2)))
end
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,10 @@ end
rand_model = rand(broadcasted_model)

@test length(rand_model.ϵ_t) == 2
fix_model = fix(broadcasted_model, (σ_RW = 1.0, rw_init = 1.0))
sample_model = sample(fix_model, Prior(), 100; progress = false)
gen_model = sample_model |>
chn -> mapreduce(hcat, generated_quantities(fix_model, chn)) do gen
gen[1]
end

@testset "Testing gen_model matrix" begin
for col in eachcol(gen_model)
unique_values = unique(col)
@test length(unique_values) == 2

@test count(x -> x == unique_values[1], col) == 5
@test count(x -> x == unique_values[2], col) == 5
end
end
fix_model = fix(
broadcasted_model,
(σ_RW = 2.0, rw_init = 1.0, ϵ_t = [1, 2])
)
out = fix_model()
@test out == vcat(fill(3.0, 5), fill(7.0, 5))
end
2 changes: 1 addition & 1 deletion EpiAware/test/EpiLatentModels/models/AR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
n_samples = 100
samples = sample(fixed_model, Prior(), n_samples; progress = false) |>
chn -> mapreduce(vcat, generated_quantities(fixed_model, chn)) do gen
gen[1]
gen
end

theoretical_mean = 0.0
Expand Down
6 changes: 4 additions & 2 deletions EpiAware/test/EpiLatentModels/models/FixedIntercept.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ end
int = FixedIntercept(0.1)
int_model = generate_latent(int, 10)
int_model_out = int_model()
@test length(int_model_out[1]) == 10
@test all(x -> x == int_model_out[2].intercept, int_model_out[1])
rand_model = rand(int_model)
@test rand_model == NamedTuple()
@test length(int_model_out) == 10
@test all(x -> x == 0.1, int_model_out)
end
Loading
Loading