
Poor sampling performance with some complex posteriors compared to HMC #295

Open
sefffal opened this issue Oct 29, 2024 · 35 comments

@sefffal

sefffal commented Oct 29, 2024

Hi all, I have isolated a relatively lightweight example where I find that HMC significantly outperforms PT. I tested most of the variations supported in Pigeons: the SliceSampler and AutoMALA explorers, with fixed, variational, and stabilized-variational references.

A corner & trace plot are attached below with HMC in blue and PT in gold.

A script to produce this plot is available here: https://github.com/sefffal/OrbitPosteriorDB/blob/main/models/astrom-GL229A.jl
Use the latest #main commit of Octofitter (e.g. ] add Octofitter#main).

I would be curious to understand better why PT is struggling so much on this target, and if there is a way to improve performance to be at least within the same ball-park.

The HMC series in this plot is not exactly converged since there's pretty high correlation between samples, but it nonetheless successfully explores the posterior while PT is stuck in a much smaller region.

[image: corner & trace plots, HMC in blue, PT in gold]
@alexandrebouchard
Member

Thanks William, I'm excited to see what happens here and hopefully make improvements based on what we observe! So far the focus has been on problems where HMC performs poorly, but I agree that having reasonable performance on problems where HMC works well is very useful too.

I would have expected AutoMALA to do well here, since there are some funnel-like shapes in the pair plot. It could be the round-based initial step size adaptation running wild; that part of AutoMALA is less well understood than the per-step automatic step size selection.

Another possibility is that long trajectories are really needed for this one. With @tgliu0406 we recently prototyped an autoHMC; it would be interesting to see if it performs better than AutoMALA on this problem.

Since multiple chains are not needed here, it might be worth focusing initially on running Pigeons with a single chain.

I'll start with trying the reproducibility script. Thanks again!

@alexandrebouchard
Member

For reference, here are the stats from the HMC run:

  iterations:                                   5100
  ratio_divergent_transitions:                  0.0
  ratio_divergent_transitions_during_adaption:  0.03
  n_steps:                                      4095
  is_accept:                                    true
  acceptance_rate:                              0.9686561630493195
  log_density:                                  19.1778839507093
  hamiltonian_energy:                           -15.535406645231566
  hamiltonian_energy_error:                     0.04783412097392059
  max_hamiltonian_energy_error:                 0.088189176494609
  tree_depth:                                   12
  numerical_error:                              false
  step_size:                                    0.0004985427615975861
  nom_step_size:                                0.0004985427615975861
  is_adapt:                                     false
  mass_matrix:                                  DenseEuclideanMetric(diag=[0.16992803377878063, 7.203 ...])

One other guess is that the preconditioner might be off...

@alexandrebouchard
Member

Some stats, based on the following run:

chain_pt,pt = octofit_pigeons(
    model,
    # 12 rounds is about the same wall-clock time as the above HMC
    n_rounds=15,
    n_chains=1,
    n_chains_variational=0,
    variational=nothing,#GaussianReference(),
    # explorer=Compose(SliceSampler(),AutoMALA()),
    explorer=AutoMALA(),
)
julia> pt.shared.explorer.step_size
0.004920727982147332

julia> pt.shared.explorer.estimated_target_std_deviations
12-element Vector{Float64}:
 0.009694861110282131
 0.011218920887184406
 9.84596575830136e-5
 0.008446406116308973
 0.007492746456207485
 0.006611899705535479
 0.02192204882563213
 0.006264372034808068
 0.006598652035268915
 0.014764926897502917
 0.006588233373333677
 0.0019844912584754994

@miguelbiron
Collaborator

> For reference, here are the stats from the HMC run: […stats quoted from the previous comment…] One other guess is that the preconditioner might be off...

Some questions/comments

  • Is n_steps the number of leapfrog steps in an average iteration? Because wow, that is 4 times the default cap (2^10) in Stan, if I'm not mistaken.
  • The DenseEuclideanMetric might also be doing a lot of the heavy lifting. We haven't implemented that yet in Pigeons.
  • Finally, the nasty concentration on lower-dimensional submanifolds that are not axis-aligned should definitely kill the slicer, so that part is not too surprising.
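To quantify why non-axis-aligned concentration hurts a coordinate-wise slicer: on a correlated Gaussian, an axis-aligned move only sees the conditional scale, not the marginal one. This is a toy calculation for intuition, not tied to the actual model in this issue:

```julia
# For a 2-D Gaussian with unit marginals and correlation rho, a coordinate-wise
# update moves on the conditional scale sqrt(1 - rho^2), while the distribution
# itself extends over scale 1 along the ridge.
rho = 0.999
conditional_std = sqrt(1 - rho^2)   # ≈ 0.0447
ratio = 1.0 / conditional_std       # ≈ 22.4: many small axis-aligned moves needed
```

When the ridge is axis-aligned, a single coordinate update can still traverse it; when it is rotated, every coordinate sees only the small conditional scale, which matches the failure mode described above.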

@sefffal
Author

sefffal commented Nov 1, 2024

Hi @miguelbiron , thanks for taking a look.

Yes, I find that increasing the maximum number of steps is important for getting acceptable performance from HMC on these problems. Sometimes the data are not that constraining, in which case NUTS cuts the trajectories short and we average around 2^7 leapfrog steps.
In this case, almost every iteration uses 2^12 steps so an even higher limit might be warranted.

The DenseEuclideanMetric is critical, and HMC effectively does not work on these problems without it.

I'll note that while surely there are some reparameterizations that can help with this specific dataset, it is very hard to find parameterizations that solve the problem for all datasets.

@miguelbiron
Collaborator

No problem and thanks for clarifying! It's good to know at least that NUTS works here because you're taking it to extreme tree sizes (and even then, I see it goes beyond the limit 10% of the time), and also using the Dense preconditioner.

From a Pigeons perspective, we instead would prefer to stick with a cheaper explorer and instead crank up the number of chains. But it's weird that that doesn't work. So I'll be doing some tests now to see if I can find a combination of dumb explorer with more chains that gives as good performance as NUTS in the same wallclock time. I'll keep you posted.

@sefffal
Author

sefffal commented Nov 1, 2024

Thanks @miguelbiron, I would say this matches my experiences with Pigeons too. The SliceSampler combined with ~40 chains and ~40 variational chains can usually power through just about anything. Unfortunately there are some challenging cases where even that leads to some weird results.

For example, this figure (from a different but related posterior) shows how sometimes PT leaves these gaps along the posterior that ought to be connected (see e.g. inclination vs eccentricity):

[image: pair plot showing gaps along the posterior, e.g. inclination vs eccentricity]

@miguelbiron
Collaborator

The axis-aligned gaps kinda make sense as a slice sampler failure mode. That could be fixed by substituting in the Hit-and-Run wrapper for the slicer that we implemented for the AutoStep paper with @tgliu0406, or even by fully switching to autoRWMH, which showed promising results. We should be releasing a package with these additional explorers in the short term.
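The idea behind the Hit-and-Run wrapper can be sketched in a few lines: propose an isotropic random direction, then run a standard stepping-out/shrinkage slice update along that line. This is a stdlib-only toy sketch of the general technique, not the actual implementation from the AutoStep work:

```julia
using Random, LinearAlgebra

# One hit-and-run slice step: draw a uniform random direction, then slice-sample
# the log-density along that line. The current point x sits at t = 0.
function hit_and_run_step(rng, logp, x; w=1.0, max_shrink=100)
    d = randn(rng, length(x))
    d ./= norm(d)                       # uniform direction on the sphere
    logy = logp(x) + log(rand(rng))     # slice height under the current point
    lo = -w * rand(rng)                 # random initial bracket of width w around t = 0
    hi = lo + w
    while logp(x .+ lo .* d) > logy     # step out to the left
        lo -= w
    end
    while logp(x .+ hi .* d) > logy     # step out to the right
        hi += w
    end
    for _ in 1:max_shrink               # shrink the bracket until a point is accepted
        t = lo + (hi - lo) * rand(rng)
        z = x .+ t .* d
        logp(z) > logy && return z
        t < 0 ? (lo = t) : (hi = t)
    end
    return x                            # fallback; essentially never reached
end

# Toy usage: a strongly correlated 2-D Gaussian that defeats axis-aligned moves.
rho = 0.99
logp(x) = -0.5 * (x[1]^2 + (x[2] - rho * x[1])^2 / (1 - rho^2))
x1 = hit_and_run_step(MersenneTwister(0), logp, [0.0, 0.0])
```

Because the proposal direction is isotropic, no gap in the posterior is axis-aligned from the sampler's point of view, which is exactly the failure mode the wrapper targets.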

I'll go back to the other reproducible example now.

@alexandrebouchard
Member

Good point... I think this would be a good test bed for the auto step methods + Hit and Run...

@alexandrebouchard
Member

I noticed that the script does not give exactly the same result when I rerun it. Looking at the output of the initializer, I think it might be the cause. In my first run I got

┌ Info: Starting point drawn
└   initial_logpost = -850885.8759688964

vs for the second:

┌ Info: Starting point drawn
└   initial_logpost = -1.5588419366108126e6

@sefffal
Author

sefffal commented Nov 2, 2024

Hi @alexandrebouchard, my bad. If the starting points aren't given explicitly, they are sampled from the default RNG.

You should be able to:

  • seed it with Random.seed!(0)
  • Initialize with an explicit RNG argument Octofitter.default_initializer!(rng, model)
  • Specify the starting points manually (make sure there’s at least one per chain):
model.starting_points = [
    model.link([1,2,3]), # start values for each parameter 
    # repeat for each chain, or use fill(…)
]

@miguelbiron
Collaborator

miguelbiron commented Nov 2, 2024

I also noticed the non-determinism, which is more annoying in that the HMC samples themselves are completely different from one run to the next. I've never seen the sinusoidal shapes you show above in my runs. So unless I f'd up something on my end, I wouldn't trust those samples too much either.

BTW, I'm doing a 20-round run with 60 chains = 4Λ, using autoRWMH with n_refresh=32. I did this following my heuristic advice from the NRST paper, and I actually get

  • max energy autocorrelation <~ 90% (due to increased n_refresh)
  • almost exact equirejection
  • reversibility rate and acceptance probs consistent with autoRWMH in other test sets.
  • With this I observe roundtrips from round 13 onwards, which is right about when the normalizing-constant estimate begins to stabilize.

Main takeaway: you haven't reached adaptation convergence if you don't see

  • Λ ~ 15
  • logZ ~ -16

[image: round-by-round diagnostics plots]

@sefffal
Author

sefffal commented Nov 2, 2024

Okay, I can reproduce the non-determinism of the pathfinder initialization. I'll try to track that down.

@sefffal
Author

sefffal commented Nov 2, 2024

Wow, this is frustrating.

I think I found non-determinism issues with both Pathfinder.jl and AdvancedHMC.jl.
Clearly, I need to add some integration tests here.

For Pathfinder, it is non-deterministic when using multiple threads (this can be worked around by passing nruns=1 to Octofitter.default_initializer!).

For AdvancedHMC, I am passing an rng object explicitly, but it is ignoring the argument and using the default global RNG.

@sefffal
Author

sefffal commented Nov 2, 2024

I pushed an update to the script that works around those issues for now.

@miguelbiron
Collaborator

Update: at round 17, the roundtrip rate stabilized -- i.e., it doubled wrt the previous round. Also a slight change in logZ ~ -16.7. All other measurements look healthy.

[image: diagnostics plots at round 17]

@sefffal
Author

sefffal commented Nov 2, 2024

It's crazy that the first restart is in round 13; I usually see restarts beginning in rounds 8-10. Is that related to the use of autoRWMH, do you think?

@miguelbiron
Collaborator

I don't think so; I tried many combinations of samplers, and all take about 12-13 rounds to give restarts on this problem. I think the issue is that the model is very challenging for our adaptation algorithm. Basically, the restarts begin here at the same round that the logZ estimate approaches the true value.

@alexandrebouchard
Member

Thanks Miguel, got similar observations with a long run of AutoMALA.

chain_pt,pt = octofit_pigeons(
    model,
    # 12 rounds is about the same wall-clock time as the above HMC
    n_rounds=18,
    n_chains=20,
    n_chains_variational=0,
    variational=nothing,#GaussianReference(),
    # explorer=Compose(SliceSampler(),AutoMALA()),
    explorer=AutoMALA(),
)

got

─────────────────────────────────────────────────────────────────────────────────────────────────────────────
  scans     restarts      Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2          0       1.93       3.34    3.1e+08  -9.95e+06          0      0.899      0.349      0.599 
        4          0       3.71      0.147   3.32e+07  -3.34e+03          0      0.805      0.373      0.596 
        8          0       4.55      0.131   6.61e+05  -1.63e+03   1.1e-198      0.761      0.619      0.661 
       16          0       4.92      0.276   1.27e+06   -1.4e+03   0.000308      0.741      0.485      0.628 
       32          0       7.51      0.565   2.37e+06   -1.4e+03     0.0927      0.605      0.581      0.643 
       64          0       9.31       1.09   4.52e+06   -1.4e+03     0.0535       0.51      0.528       0.61 
      128          0       9.84       2.36   8.82e+06  -1.39e+03     0.0455      0.482      0.585      0.626 
      256          0       10.8       4.52   1.76e+07  -1.39e+03     0.0208      0.431        0.6      0.633 
      512          0       11.3       8.89   3.48e+07  -1.38e+03      0.262      0.407      0.551      0.616 
 1.02e+03          0       11.9        917   6.95e+07  -1.38e+03      0.176      0.373      0.563      0.619 
 2.05e+03          0       12.9        823   1.39e+08  -1.35e+03     0.0782      0.321       0.59      0.629 
  4.1e+03          2       14.2       76.9   2.78e+08       -886    0.00274      0.253      0.543      0.638 
 8.19e+03          0       15.3        121   5.53e+08       -437   0.000244      0.195       0.51      0.622 
 1.64e+04          1       15.6        790   1.11e+09       -205   0.000628      0.181      0.515      0.627 
 3.28e+04          1       14.6        527   2.22e+09      -77.4     0.0264      0.233      0.511      0.623 
 6.55e+04          2       12.9   1.03e+04   4.44e+09      -30.5      0.181      0.319      0.539      0.625 
 1.31e+05          6         13   5.63e+03   8.83e+09        -22      0.242      0.315      0.499      0.623 
 2.62e+05         11         13   1.69e+04   1.76e+10      -17.8      0.266      0.314      0.501      0.619 
─────────────────────────────────────────────────────────────────────────────────────────────────────────────

and the following pair plot (note I ran HMC for longer as well: octofit(model, verbosity=4, adaptation=10000, iterations=40100)).

[image: pair plot]

@alexandrebouchard
Member

Some thoughts:

  • Let's write an auto-RieMALA kernel. While Riemannian HMC is tricky (an implicit integrator is needed), Riemannian MALA is comparatively much easier to write, I think. The key is to use not the involution perspective but instead the normal proposal perspective. Question for you William: would this challenging model support (1) Hessian calculation? (2) Hessian-vector products (if we go the low-rank route)?
  • It might be worth putting more effort into initialization/burn-in. For example, we could consider a draw from a Laplace approximation (either full or pathfinder-style), in the style that William is using, but on each intermediate distribution.
  • Even the diagonal preconditioning could be improved. I see a potential failure mode, which I think may be happening here: if the initialization is not over-dispersed, then the marginal standard deviation will be substantially underestimated. This in turn makes MALA/HMC methods really slow. With the samples from the prior we can get out of that, but the information has to propagate potentially only one level per round. That could explain why it always takes 12-13 rounds to get round trips. An alternative is auto-ApproxRieMALA. If people are interested we could talk briefly on Monday about how to implement it. Related question to the group: do you guys know if it is possible to get the diagonal entries of the Hessian matrix in O(d) with some non-standard flavour of autodiff?
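On the diagonal-of-the-Hessian question: exact extraction in O(d) is not available from plain forward or reverse mode, but one well-known stochastic workaround is a Hutchinson-type estimator built from Hessian-vector products (not exact, but often good enough for preconditioning). A stdlib-only sketch; the matrix A below is a toy stand-in for a real Hessian operator:

```julia
using Random

# Hutchinson-type diagonal estimator: for Rademacher probes v (entries ±1),
# E[v .* (H * v)] = diag(H), so the diagonal can be estimated from
# Hessian-vector products alone, one HVP per probe.
function hutchinson_diag(hvp, d, nprobes; rng=Random.default_rng())
    est = zeros(d)
    for _ in 1:nprobes
        v = rand(rng, (-1.0, 1.0), d)   # Rademacher probe
        est .+= v .* hvp(v)
    end
    est ./ nprobes
end

# Toy check against a known symmetric "Hessian".
A = [2.0 0.3 0.0;
     0.3 1.0 0.1;
     0.0 0.1 3.0]
hutchinson_diag(v -> A * v, 3, 4000; rng=MersenneTwister(1))  # ≈ [2.0, 1.0, 3.0]
```

The estimator is unbiased but noisy (the per-entry variance is the squared off-diagonal mass of that row of H), so it suits preconditioning better than anything requiring high accuracy.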

@miguelbiron
Collaborator

miguelbiron commented Nov 3, 2024

Re Hessian-vector products: it seems like SparseDiffTools.jl is the one package offering them. Despite the name, it seems like it still works for arbitrary dense Hessians.
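For intuition, the numerical flavour of a Hessian-vector product only needs two extra gradient evaluations, whatever the dimension; a central difference of the gradient is the standard trick. A stdlib-only sketch with a toy potential, not the SparseDiffTools implementation itself:

```julia
# Numerical Hessian-vector product via central differences of the gradient:
#   H(x) * v ≈ (∇f(x + εv) − ∇f(x − εv)) / (2ε)
hvp_fd(gradf, x, v; eps=1e-5) =
    (gradf(x .+ eps .* v) .- gradf(x .- eps .* v)) ./ (2eps)

# Toy log-potential whose Hessian is diagonal with entries 1 + 3x_i^2.
f(x) = 0.5 * sum(abs2, x) + 0.25 * sum(x .^ 4)
gradf(x) = x .+ x .^ 3

hvp_fd(gradf, [1.0, -2.0, 0.5], [0.0, 1.0, 0.0])  # ≈ [0.0, 13.0, 0.0]
```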

@miguelbiron
Collaborator

Finished my 20-round run (without running a similarly long chain for HMC).

The way I understand it is that the problem is fully unidentifiable: there are several (uncountably many, perhaps?) strict submanifolds where the distribution of the logpotential restricted to that region is exactly equal to the unrestricted one. Our auto-XYZ samplers tend to be content with staying in those regions -- they basically only look at the problem through the distribution of the logpotential. Therefore, the only way you can fulfill the ergodicity guarantees of PT is via restarts. Hence, the splattered patterns.

[image: pair plot from the 20-round run]

@miguelbiron
Collaborator

For completeness, the full report.

[image: full diagnostics report]

@sefffal
Author

sefffal commented Nov 3, 2024

@miguelbiron & @alexandrebouchard , thank you for looking into this so thoroughly.

> The way I understand it is that the problem is fully unidentifiable: there are several (uncountably many, perhaps?) strict submanifolds where the distribution of the logpotential restricted to that region is exactly equal to the unrestricted one.

I admit I’m a bit out of my depth here, but I’m glad there is an explanation for the splatter patterns. Actually, we’ve seen this before with ptemcee.py too. Maybe we could discuss tomorrow?

> Question for you William: would this challenging model support (1) Hessian calculation? (2) …

I think I have done this before using ForwardDiff without issue.

@miguelbiron
Collaborator

miguelbiron commented Nov 3, 2024

Yes, happy to discuss!

Just to finish the thought about the difference with NUTS: in contrast to our samplers, the no-U-turn condition actually looks at the state vector (x,p), forcing the sampler to go far in phase space -- not just in logdensity space. It would be interesting to think about what sort of modification the autoXYZ approach would require to force movement in phase space too.

Edit: BTW, if this interpretation is correct, then a Riemannian approach would not be sufficient to match NUTS' performance. At most, it would reduce the required n_refresh back to 1-3, instead of the 32 or higher needed to get good mixing in the logpotential. But hopefully we can prove this theory wrong :D

@miguelbiron
Collaborator

Also for the record, this is what the samples look like if I just do Chains(pt) (I guess the Octofitter version applies transformations and appends other variables). I'm not sure if these are in the space where the sampling occurs or if they have already been back-transformed to the constrained space.
[image: pair plot of raw Chains(pt) samples]

@sefffal
Author

sefffal commented Nov 3, 2024

@miguelbiron I wonder if the samples have some transformations applied that way too. Might be safest to try:

pairplot(stack(get_sample(pt)))

The last row will be the log posterior density.

@miguelbiron
Collaborator

Hmm, I lost the pt object (had to turn off the PC). How would you do that with the HMC samples? That should run faster.

@sefffal
Author

sefffal commented Nov 3, 2024

That should be possible too if the HMC chains object is in memory (it isn't saved to disk via savechain / loadchain).

the following should plot the HMC samples in the unconstrained space:

pairplot(stack(chain_hmc.info.samples_transformed))

@sefffal
Author

sefffal commented Nov 4, 2024

Re: autodiff for hessians and hessian-vector products.

I tested SparseDiffTools' numauto_hesvec and it seems to be quite efficient. It only takes about 15x longer than a regular call to the log density function.

@miguelbiron
Collaborator

Ok so I reran the samplers to draw the pair plots in unconstrained space. I'd say the diagnosis is unchanged wrt the picture in constrained space---a proper house of horrors of complex geometry.
[image: pair plot in unconstrained space]

@alexandrebouchard
Member

Beautiful!

@sefffal
Author

sefffal commented Nov 5, 2024

To add one observation to this discussion, I notice that another orbit-fitting code often produces similar splatter patterns.

This code uses "ptemcee", a (reversible) parallel-tempered affine-invariant ensemble sampler.

Do you think this symptom could have a common cause?

This is from the paper: https://iopscience.iop.org/article/10.3847/1538-3881/ac042e

@alexandrebouchard
Member

Maybe we could see if we can reproduce the splatter pattern with a simple synthetic example, e.g. like Example 7.2 in @nikola-sur's paper: https://arxiv.org/pdf/2405.11384

@miguelbiron
Collaborator

What is pi0 in that example?
