
Instability across chains #177

Closed
SamuelBrand1 opened this issue Mar 28, 2024 · 13 comments
Labels
bug Something isn't working EpiAware

Comments

@SamuelBrand1
Collaborator

Whilst working on #174, I've noticed what seems to be an increased number of divergent chains in NUTS sampling.

The main pathology in the output seems to be that different chains return divergent cluster_factor estimates, which lets other parameters diverge quite a lot between chains too.

Looking into this is ongoing.
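One way to put a number on "different chains return divergent cluster_factor estimates" is the potential scale reduction factor (R-hat), which compares between-chain and within-chain variance. The sketch below is a plain-Julia version of the classic (non-split) Gelman-Rubin statistic. It assumes the draws for a single parameter, such as cluster_factor, have already been pulled out into an iterations × chains matrix, and `gelman_rubin_rhat` is a hypothetical helper, not part of EpiAware or Turing. In practice MCMCChains already reports its own R-hat estimate in the rhat column of the chain summary, and that is what you would normally check.

```julia
using Random, Statistics

# Hypothetical helper: classic (non-split) Gelman-Rubin R-hat for one scalar
# parameter. `draws` is an (iterations × chains) matrix of posterior draws.
function gelman_rubin_rhat(draws::AbstractMatrix{<:Real})
    n, m = size(draws)                      # n draws per chain, m chains
    chain_means = vec(mean(draws; dims = 1))
    chain_vars = vec(var(draws; dims = 1))  # within-chain sample variances
    W = mean(chain_vars)                    # mean within-chain variance
    B = n * var(chain_means)                # between-chain variance
    var_plus = (n - 1) / n * W + B / n      # pooled estimate of posterior variance
    return sqrt(var_plus / W)
end

rng = MersenneTwister(1)
well_mixed = randn(rng, 500, 4)                        # all chains share one target
disagreeing = randn(rng, 500, 4) .+ [0.0 0.0 3.0 3.0]  # two chains sit elsewhere

gelman_rubin_rhat(well_mixed)   # close to 1
gelman_rubin_rhat(disagreeing)  # well above 1
```

Values clearly above 1 (common rules of thumb flag anything over roughly 1.01 to 1.1) mean the chains have not settled on the same distribution, which is the symptom described above.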

@SamuelBrand1 SamuelBrand1 added bug Something isn't working EpiAware labels Mar 28, 2024
@seabbs
Collaborator

seabbs commented Mar 28, 2024

Thanks for this. Can you put a bit more info into the title to help tighten up the focus of this issue? And into the body of the issue too. Is it just the renewal model, or all models?

@seabbs
Collaborator

seabbs commented Apr 19, 2024

@SamuelBrand1 has been investigating this, and it looks like it's due to an interaction between Turing.condition and either LatentDelay or NegativeBinomialError.
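To make the LatentDelay part of that combination concrete: wrapping an observation model in a latent delay means the expected observation at time t is a weighted sum of recent latent infections, with weights taken from a delay PMF (the scratch script later in this thread uses fill(0.25, 4)). The sketch below assumes a plain discrete convolution that returns only the time points with a full delay history. `delay_expected_obs` is a hypothetical helper, and EpiAware's own LatentDelay may treat the first few time points differently.

```julia
# Hypothetical helper (not EpiAware's implementation): delay latent infections
# `I_t` by a discrete PMF `pmf`, where pmf[d] is the probability of a delay of
# (d - 1) days. Only time points with a full delay history are returned, so the
# output has length(I_t) - length(pmf) + 1 entries.
function delay_expected_obs(I_t::AbstractVector{<:Real}, pmf::AbstractVector{<:Real})
    n, k = length(I_t), length(pmf)
    return [sum(pmf[d] * I_t[t - d + 1] for d in 1:k) for t in k:n]
end

I_t = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
delay_pmf = fill(0.25, 4)            # same PMF as the LatentDelay in the script
delay_expected_obs(I_t, delay_pmf)   # → [25.0, 35.0, 45.0]
```

In the script, these delayed expectations would then feed the negative binomial observation model, so any trouble identifying cluster_factor sits downstream of this smoothing step.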

@seabbs
Collaborator

seabbs commented Apr 24, 2024

@SamuelBrand1 where are we on this? It would be good to post an update here if we have a sensible fix or workaround, so we can build on it.

@SamuelBrand1
Collaborator Author

Switching from a random walk to an AR process seems to have solved this, at least for the simple getting started example.
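One possible (unconfirmed) reason the switch helped: under a random-walk prior the marginal variance of the latent process grows linearly with time, whereas a stationary AR process keeps it bounded, which may leave chains less room to drift into different regions. The sketch below simulates both with a hypothetical helper `simulate_paths`, using an AR(1) with ρ = 0.8 for simplicity (the script in this thread uses EpiAware's AR with two damping priors), and compares the spread across replicates at the final time point.

```julia
using Random, Statistics

# Hypothetical helper: simulate `reps` independent paths of length `n` from
#   x_t = ρ * x_{t-1} + σ * ε_t,  ε_t ~ N(0, 1),  x_0 = 0.
# ρ = 1 gives a random walk; |ρ| < 1 gives a stationary AR(1).
function simulate_paths(rng, ρ, σ, n, reps)
    x = zeros(n, reps)
    for r in 1:reps
        prev = 0.0
        for t in 1:n
            prev = ρ * prev + σ * randn(rng)
            x[t, r] = prev
        end
    end
    return x
end

rng = MersenneTwister(42)
rw_paths = simulate_paths(rng, 1.0, 1.0, 100, 2_000)
ar_paths = simulate_paths(rng, 0.8, 1.0, 100, 2_000)

# Spread across replicates at the last time point
var(rw_paths[end, :])  # roughly n * σ^2, i.e. about 100 here
var(ar_paths[end, :])  # roughly σ^2 / (1 - ρ^2), i.e. about 2.8 here
```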

@seabbs
Collaborator

seabbs commented Apr 24, 2024

That doesn't really sound like it resolves the root issue, whether or not that was linked to conditioning. Can you summarise what you found in that investigation, and any likely PRs we might see to test it (i.e. the end-to-end tests you started)?

@seabbs
Collaborator

seabbs commented May 1, 2024

@SamuelBrand1 can you give an update on which combinations were causing the problems, share any scratch scripts you used to look at them, and note anything else you found while investigating?

@SamuelBrand1
Collaborator Author

SamuelBrand1 commented May 1, 2024

So the super scratch script is

let
    docs_dir = dirname(dirname(@__DIR__))
    pkg_dir = dirname(docs_dir)

    using Pkg: Pkg
    Pkg.activate(docs_dir)
    Pkg.develop(; path = pkg_dir)
    Pkg.instantiate()
end;

begin
    using EpiAware
    using Turing
    using Distributions
    using StatsPlots
    using Random
    using DynamicPPL
    using Statistics
    using DataFramesMeta
    using LinearAlgebra
    using Transducers
    using ReverseDiff
    Random.seed!(1)
end

# Model definition

rwp = EpiAware.RandomWalk(
    init_prior = Normal(),
    std_prior = HalfNormal(0.05)
)

weekly_rwp = BroadcastLatentModel(rwp, 7, RepeatBlock())
ar = AR(;
    damp_priors = [truncated(Normal(0.8, 0.05), 0, 1), truncated(Normal(0.1, 0.05), 0, 1)],
    std_prior = truncated(Normal(), 0, Inf),
    init_priors = [Normal(-1.0, 0.1), Normal(-1.0, 0.1)]
)

weekly_ar = BroadcastLatentModel(ar, 7, RepeatBlock())

truth_GI = Gamma(2, 5)
model_data = EpiData(gen_distribution = truth_GI,
    D_gen = 10.0)

log_I0_prior = Normal(log(100.0), 1.0)
epi = DirectInfections(model_data, log_I0_prior)

obs = LatentDelay(
    NegativeBinomialError(cluster_factor_prior = Gamma(10, 0.05 / 10)),
    fill(0.25, 4),
)

obs_direct = NegativeBinomialError(cluster_factor_prior = Gamma(10, 0.05 / 10))

##
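n_latent_inference = 20

@model function test_latent_inference(rwp, n)
    @submodel Z, _ = generate_latent(rwp, n)
    y_t ~ MvNormal(Z, 0.1^2 * I)  # isotropic noise with sd 0.1
    return (Z = Z, y_t = y_t)
end

# Condition the latent-only model on synthetic data (built as in `test_inference`
# below) so that `cond_mdl` is defined before sampling
test_y_latent = vcat(randn(n_latent_inference ÷ 2) .+ 10,
    randn(n_latent_inference ÷ 2) .+ 20)
cond_mdl = test_latent_inference(rwp, n_latent_inference) | (y_t = test_y_latent,)

chn_rad_f = sample(cond_mdl, NUTS(adtype = AutoReverseDiff(false)), 2_000)
chn_rad_t = sample(cond_mdl, NUTS(adtype = AutoReverseDiff(true)), 2_000)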

##
function test_inference(adtype; n = 28)
    mdl = test_latent_inference(weekly_rwp, n)
    test_y_t = vcat(randn(n ÷ 2) .+ 10, randn(n ÷ 2) .+ 20)
    cond_mdl = mdl | (y_t = test_y_t,)
    chn_fwd = sample(cond_mdl, NUTS(adtype = adtype), 2_000)

    gens = generated_quantities(cond_mdl, chn_fwd)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.Z, lab = "", alpha = 0.1, c = :grey)
    end
    plot!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
    plt
end

inference_plts = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
    test_inference(adtype)
end

plts = plot(inference_plts..., layout = (3, 1), size = (500, 400),
    title = "rwp only: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .* " AD")
##

n = 20
@model function test_latent_and_latent_inf_inference(rwp, epi, n)
    @submodel Z, _ = generate_latent(rwp, n)
    @submodel I_t = generate_latent_infs(epi, Z)
    y_t ~ MvNormal(log.(I_t), 0.1^2 * I)  # isotropic noise with sd 0.1
    return (; Z, y_t, I_t)
end

##

function test_inference_lat_and_lat_inf(adtype; n = 20)
    mdl = test_latent_and_latent_inf_inference(ar, epi, n)
    test_y_t = randn(n) .+ 1
    cond_mdl = mdl | (y_t = test_y_t,)
    chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)

    gens = generated_quantities(cond_mdl, chn)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
    end
    plot!(plt, exp.(test_y_t), lab = "Observed", lw = 2, c = :red)
    plt
end

inference_plts_lat_and_lat_inf = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
    test_inference_lat_and_lat_inf(adtype)
end

plts = plot(inference_plts_lat_and_lat_inf...,
    layout = (3, 1),
    size = (500, 400),
    title = "rwp + lat inf only: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
            " AD")

##

@model function test_inference_lat_and_lat_inf_obs(rwp, epi, obs, n, obs_yt = missing)
    @submodel Z, _ = generate_latent(rwp, n)
    @submodel I_t = generate_latent_infs(epi, Z)
    @submodel gen_y_t, _ = generate_observations(obs, obs_yt, I_t)

    return (; Z, I_t, gen_y_t)
end

##

function test_inference_lat_and_lat_inf_obs(adtype; n = 20)
    test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
    cond_mdl = test_inference_lat_and_lat_inf_obs(rwp, epi, obs, n, test_y_t)

    # cond_mdl = mdl | (y_t = test_y_t,)
    chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)

    gens = generated_quantities(cond_mdl, chn)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
    end
    scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
    hline!(plt, [(exp(1))], lab = "", c = :black)

    plt
end

inference_plts_lat_and_lat_inf_obs = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
        test_inference_lat_and_lat_inf_obs(adtype)
end

plts = plot(inference_plts_lat_and_lat_inf_obs...,
    layout = (3, 1),
    size = (500, 400),
    title = "rwp + lat inf + obs only: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
            " AD")


##

function test_inference_gen_epiaware(adtype; n = 20)
    test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)

    cond_mdl = generate_epiaware(test_y_t, n, epi;
        latent_model = rwp, observation_model = obs)

    # cond_mdl = mdl | (y_t = test_y_t,)
    chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)

    gens = generated_quantities(cond_mdl, chn)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
    end
    scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
    hline!(plt, [(exp(1))], lab = "", c = :black)

    plt
end

inference_plts_gen_epiaware = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
        test_inference_gen_epiaware(adtype)
end

plts = plot(inference_plts_gen_epiaware...,
    layout = (3, 1),
    size = (500, 400),
    title = "epiaware: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
            " AD")


##

function test_inference_gen_epiaware2(adtype; n = 20)
    test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
    epi_prob = EpiProblem(epi, rwp, obs, (1, n))
    cond_mdl = generate_epiaware(epi_prob, (y_t = test_y_t,))

    # cond_mdl = mdl | (y_t = test_y_t,)
    chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)

    gens = generated_quantities(cond_mdl, chn)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
    end
    scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
    hline!(plt, [(exp(1))], lab = "", c = :black)

    plt
end


inference_plts_gen_epiaware = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
        test_inference_gen_epiaware2(adtype)
end

plts = plot(inference_plts_gen_epiaware...,
    layout = (3, 1),
    size = (500, 400),
    title = "epiprob: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
            " AD")


## Multiple chains


function test_inference_gen_epiaware_multi(adtype; n = 20)
    test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)
    epi_prob = EpiProblem(epi, rwp, obs, (1, n))
    cond_mdl = generate_epiaware(epi_prob, (y_t = test_y_t,))

    # cond_mdl = mdl | (y_t = test_y_t,)
    chn = sample(cond_mdl, NUTS(adtype = adtype), MCMCThreads(), 500, 4)

    gens = generated_quantities(cond_mdl, chn)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
    end
    scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
    hline!(plt, [(exp(1))], lab = "", c = :black)

    plt
end


inference_plts_gen_epiaware = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
        test_inference_gen_epiaware_multi(adtype)
end

plts = plot(inference_plts_gen_epiaware...,
    layout = (3, 1),
    size = (500, 400),
    title = "epiprob multi: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
            " AD")

## Raw latent delay
##

function test_inference_gen_epiaware_direct(adtype; n = 20)
    test_y_t = randn(n) .+ 1 .|> exp .|> _y -> round(Int64, _y)

    cond_mdl = generate_epiaware(test_y_t, n, epi;
        latent_model = weekly_ar, observation_model = obs_direct)

    # cond_mdl = mdl | (y_t = test_y_t,)
    chn = sample(cond_mdl, NUTS(adtype = adtype), 2_000)

    gens = generated_quantities(cond_mdl, chn)
    plt = plot()
    for gen in gens[1:10:end]
        plot!(plt, gen.I_t, lab = "", alpha = 0.1, c = :grey)
    end
    scatter!(plt, test_y_t, lab = "Observed", lw = 2, c = :red)
    hline!(plt, [(exp(1))], lab = "", c = :black)

    plt
end

inference_plts_gen_epiaware_dir = map([
    AutoForwardDiff(), AutoReverseDiff(true), AutoReverseDiff(false)]) do adtype
        test_inference_gen_epiaware_direct(adtype)
end

plts = plot(inference_plts_gen_epiaware_dir...,
    layout = (3, 1),
    size = (500, 400),
    title = "epiaware direct obs: " .* ["Forward" "Reverse (true)" "Reverse (false)"] .*
            " AD")


##

To be run from the context of the docs env.

@SamuelBrand1
Collaborator Author

Obviously that's horrible, but that's the basic idea... Watch this space!

@seabbs
Copy link
Collaborator

seabbs commented May 1, 2024

Nice, thank you.

To be explicit: the issue only happened when we were using DynamicPPL.condition together with an observation model that nested a negative binomial inside a latent delay?

@SamuelBrand1
Copy link
Collaborator Author

SamuelBrand1 commented May 1, 2024 via email

@seabbs
Copy link
Collaborator

seabbs commented May 22, 2024

I think we should close this in favour of a more specific issue targeted at anything that still needs to be done.

@SamuelBrand1
Copy link
Collaborator Author

Yes. Especially since it's a weird bug that I can't seem to replicate at the moment (did we fix it without knowing?)

@seabbs
Copy link
Collaborator

seabbs commented May 22, 2024

did we fix it without knowing

Didn't we just do other things, so the edge case didn't come up?
