Instability across chains #177
Thanks for this. Can you put a bit more info into the title to help tighten up the focus of this issue? And also in the body of the issue. Is it just the renewal model or all of them, etc.?
@SamuelBrand1 has been investigating this and it looks like it's due to an interaction between …
@SamuelBrand1 where are we on this? It would be good to update here if we have a sensible fix or workaround so we can build on it.
Switching from a random walk to an AR process seems to have solved this for at least the simple getting started example.
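For intuition, here is a minimal plain-Julia sketch (not the EpiAware API; the function names are illustrative only) of why swapping the random walk for an AR process can tame things: the marginal variance of a random walk grows linearly in time, while a stationary AR(1) process has bounded variance, which tends to give the sampler a better-behaved posterior geometry.

```julia
# Illustrative only: compare marginal variance of a random-walk latent
# process against a stationary AR(1) process with the same innovation scale.
using Random, Statistics

Random.seed!(1)

# Random walk: z_t = z_{t-1} + sigma * eps_t, so Var(z_t) = t * sigma^2.
simulate_rw(n; sigma = 0.05) = cumsum(sigma .* randn(n))

# AR(1): z_t = damp * z_{t-1} + sigma * eps_t, so Var(z_t) is bounded
# by sigma^2 / (1 - damp^2).
function simulate_ar1(n; damp = 0.8, sigma = 0.05)
    z = zeros(n)
    for t in 2:n
        z[t] = damp * z[t - 1] + sigma * randn()
    end
    return z
end

# Monte Carlo estimate of the variance of the final state after 100 steps.
rw_var = var([simulate_rw(100)[end] for _ in 1:5000])
ar_var = var([simulate_ar1(100)[end] for _ in 1:5000])
```

Here `rw_var` is roughly `100 * 0.05^2 = 0.25`, while `ar_var` stays near `0.05^2 / (1 - 0.8^2) ≈ 0.007`.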
That doesn't really sound like resolving the root issue if it was linked to conditioning or not? Can you summarise what you found in that investigation and the likely PRs we might see to test it (i.e. the end-to-end tests you started)?
@SamuelBrand1 can you update on what the combinations were that were causing the problems, any scratch scripts you had for looking at that, and anything else you found while investigating.
So the super scratch script is:
let
docs_dir = dirname(dirname(@__DIR__))
pkg_dir = dirname(docs_dir)
using Pkg: Pkg
Pkg.activate(docs_dir)
Pkg.develop(; path = pkg_dir)
Pkg.instantiate()
end;
begin
using EpiAware
using Turing
using Distributions
using StatsPlots
using Random
using DynamicPPL
using Statistics
using DataFramesMeta
using LinearAlgebra
using Transducers
using ReverseDiff
Random.seed!(1)
end
# Model definition
rwp = EpiAware.RandomWalk(
init_prior = Normal(),
std_prior = HalfNormal(0.05)
)
weekly_rwp = BroadcastLatentModel(rwp, 7, RepeatBlock())
ar = AR(;
    damp_priors = [truncated(Normal(0.8, 0.05), 0, 1), truncated(Normal(0.1, 0.05), 0, 1)],
    std_prior = truncated(Normal(), 0, Inf),
    init_priors = [Normal(-1.0, 0.1), Normal(-1.0, 0.1)]
)
weekly_ar = BroadcastLatentModel(ar, 7, RepeatBlock())
truth_GI = Gamma(2, 5)
model_data = EpiData(gen_distribution = truth_GI,
D_gen = 10.0)
log_I0_prior = Normal(log(100.0), 1.0)
epi = DirectInfections(model_data, log_I0_prior)
obs = LatentDelay(
NegativeBinomialError(cluster_factor_prior = Gamma(10, 0.05 / 10)),
fill(0.25, 4),
)
obs_direct = NegativeBinomialError(cluster_factor_prior = Gamma(10, 0.05 / 10))
##
n_latent_inference = 20
@model function test_latent_inference(rwp, n)
@submodel Z, _ = generate_latent(rwp, n)
y_t ~ MvNormal(Z, 0.1)
return (Z = Z, y_t = y_t)
end
# NB: these two calls reference `cond_mdl`, which is only defined inside
# `test_inference` below; run them from that context, not at top level.
# chn_rad_f = sample(cond_mdl, NUTS(adtype = AutoReverseDiff(false)), 2_000)
# chn_rad_t = sample(cond_mdl, NUTS(adtype = AutoReverseDiff(true)), 2_000)
##
function test_inference(adtype; n = 28)
mdl = test_latent_inference(weekly_rwp, n)
test_y_t = vcat(randn(n ÷ 2) .+ 10 , randn(n ÷ 2) .+ 20)
cond_mdl = mdl | (y_t = test_y_t,)
chn_fwd = sample(cond_mdl, NUTS(adtype = adtype), 2_000)
gens = generated_quantities(cond_mdl, chn_fwd)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.Z, lab = "", alpha = 0.1, c = :grey)
end
plot!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
plt
end
inference_plts = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference(adtype)
end
plts = plot(inference_plts..., layout = (3, 1), size = (500, 400),
title = "rwp only: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .* " AD")
##
n = 20
@model function test_latent_and_latent_inf_inference(rwp, epi, n)
@submodel Z, _ = generate_latent(rwp, n)
@submodel I_t = generate_latent_infs(epi, Z)
y_t ~ MvNormal(log.(I_t), 0.1)
return (Z = Z, y_t = y_t, I_t)
end
##
function test_inference_lat_and_lat_inf(adtype; n = 20)
mdl = test_latent_and_latent_inf_inference(ar, epi, n)
test_y_t = randn(n) .+ 1
cond_mdl = mdl | (y_t = test_y_t,)
chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)
gens = generated_quantities(cond_mdl, chn)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
end
plot!(plt, exp.(test_y_t), lab = "Observed", lw = 2, c = :red)
plt
end
inference_plts_lat_and_lat_inf = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference_lat_and_lat_inf(adtype)
end
plts = plot(inference_plts_lat_and_lat_inf...,
layout = (3, 1),
size = (500, 400),
title = "rwp + lat inf only: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
" AD")
##
@model function test_inference_lat_and_lat_inf_obs(rwp, epi, obs, n, obs_yt = missing)
@submodel Z, _ = generate_latent(rwp, n)
@submodel I_t = generate_latent_infs(epi, Z)
@submodel gen_y_t, _ = generate_observations(obs, obs_yt, I_t)
return (; Z, I_t, gen_y_t)
end
##
function test_inference_lat_and_lat_inf_obs(adtype; n = 20)
test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
cond_mdl = test_inference_lat_and_lat_inf_obs(rwp, epi, obs, n, test_y_t)
# cond_mdl = mdl | (y_t = test_y_t,)
chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)
gens = generated_quantities(cond_mdl, chn)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
end
scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
hline!(plt, [(exp(1))], lab = "", c = :black)
plt
end
inference_plts_lat_and_lat_inf_obs = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference_lat_and_lat_inf_obs(adtype)
end
plts = plot(inference_plts_lat_and_lat_inf_obs...,
layout = (3, 1),
size = (500, 400),
title = "rwp + lat inf + obs only: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
" AD")
##
function test_inference_gen_epiaware(adtype; n = 20)
test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
cond_mdl = generate_epiaware(test_y_t, n, epi; latent_model = rwp, observation_model = obs)
# cond_mdl = mdl | (y_t = test_y_t,)
chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)
gens = generated_quantities(cond_mdl, chn)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
end
scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
hline!(plt, [(exp(1))], lab = "", c = :black)
plt
end
inference_plts_gen_epiaware = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference_gen_epiaware(adtype)
end
plts = plot(inference_plts_gen_epiaware...,
layout = (3, 1),
size = (500, 400),
title = "epiaware: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
" AD")
##
function test_inference_gen_epiaware2(adtype; n = 20)
test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
epi_prob = EpiProblem(epi, rwp, obs, (1, n))
cond_mdl = generate_epiaware(epi_prob, (y_t = test_y_t,))
# cond_mdl = mdl | (y_t = test_y_t,)
chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)
gens = generated_quantities(cond_mdl, chn)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
end
scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
hline!(plt, [(exp(1))], lab = "", c = :black)
plt
end
inference_plts_gen_epiaware = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference_gen_epiaware2(adtype)
end
plts = plot(inference_plts_gen_epiaware...,
layout = (3, 1),
size = (500, 400),
title = "epiprob: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
" AD")
## Multiple chains
function test_inference_gen_epiaware_multi(adtype; n = 20)
test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
epi_prob = EpiProblem(epi, rwp, obs, (1, n))
cond_mdl = generate_epiaware(epi_prob, (y_t = test_y_t,))
# cond_mdl = mdl | (y_t = test_y_t,)
chn = sample(cond_mdl, NUTS(adtype = adtype), MCMCThreads(), 500, 4)
gens = generated_quantities(cond_mdl, chn)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
end
scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
hline!(plt, [(exp(1))], lab = "", c = :black)
plt
end
inference_plts_gen_epiaware = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference_gen_epiaware_multi(adtype)
end
plts = plot(inference_plts_gen_epiaware...,
layout = (3, 1),
size = (500, 400),
title = "epiprob multi: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
" AD")
## Raw latent delay
##
function test_inference_gen_epiaware_direct(adtype; n = 20)
test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
cond_mdl = generate_epiaware(test_y_t, n, epi; latent_model = weekly_ar, observation_model = obs_direct)
# cond_mdl = mdl | (y_t = test_y_t,)
chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)
gens = generated_quantities(cond_mdl, chn)
plt = plot()
for gen in gens[1:10:end]
plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
end
scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
hline!(plt, [(exp(1))], lab = "", c = :black)
plt
end
inference_plts_gen_epiaware_dir = map([
AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
test_inference_gen_epiaware_direct(adtype)
end
plts = plot(inference_plts_gen_epiaware_dir...,
layout = (3, 1),
size = (500, 400),
title = "epiaware direct obs: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
" AD")
## To be run from the context of the …
Obviously that's horrible, but that's the basic idea... Watch this space!
Nice, thank you. To be explicit, the issue only happened when we were using DynamicPPL.condition and an observation model that nested a negative binomial inside a latent delay?
Correct.
I think we should close this in favour of a more specific issue targeted towards anything that needs to be done.
Yes. Especially since it's a weird bug I can't seem to replicate atm (did we fix it without knowing?)
Didn't we just do other things so the edge case didn't come up?
Whilst working on addressing #174 I've noticed that there seems to be an increased number of divergent chains in NUTS sampling. The main output pathology seems to be that different chains return divergent `cluster_factor` estimates, which allows other parameters to be pretty divergent too. Looking into this is ongoing.
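One way to quantify "different chains return divergent estimates" is the potential scale reduction statistic. Below is a hand-rolled sketch (MCMCChains.jl provides a proper `rhat`; this simplified version is illustrative only) applied to hypothetical per-chain draws of a parameter like `cluster_factor`:

```julia
# Illustrative only: a simplified potential-scale-reduction (R-hat) check.
# Values well above 1 flag the between-chain disagreement described above.
using Random, Statistics

Random.seed!(1)

function rhat(chains::Vector{Vector{Float64}})
    n = length(first(chains))
    chain_means = mean.(chains)
    B = n * var(chain_means)   # between-chain variance component
    W = mean(var.(chains))     # mean within-chain variance
    var_hat = (n - 1) / n * W + B / n
    return sqrt(var_hat / W)
end

# Chains that agree give R-hat near 1; a chain stuck at a different
# value (as with the divergent cluster_factor estimates) inflates it.
good = [randn(1000) for _ in 1:4]
bad  = [randn(1000), randn(1000), randn(1000), randn(1000) .+ 5.0]
```

With `good` the statistic lands close to 1; shifting one chain by 5 standard deviations pushes it well above that.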