Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring SteinVI #1883

Merged
merged 11 commits into from
Oct 19, 2024
Merged

Conversation

OlaRonning
Copy link
Member

@OlaRonning OlaRonning commented Oct 8, 2024

SVGD has been extended to a few varieties (ASVGD, GSVGD, HSVGD) which I want to include in the einstein module. This PR consists of the changes to SteinVI to allow for the extensions and introduction of SVGD and ASVGD. GSVGD and HSVGD will be in subsequent PRs.

To this end, I've:

  1. Added a specialized constructor for SVGD.
    • The SVGD constructor does not require a guide because it's always a delta
    • It correctly sets the scaling on the attractive force to 1/m for m particles.
    • It doesn't allow users to change num_elbo_particles because it's always 1 for SVGD.
  2. Added a setup_run method to SteinVI.
    • The method encapsulates a step of SteinVI, which inheriting constructors can manipulate.
  3. Added a constructor for ASVGD.
    • ASVGD introduces an annealing schedule on the attractive force.
    • This inherits from SVGD and overwrites setup_run to change the loss_temperature.
  4. Removed Jacobian projection from the Stein force
    • The projection is unnecessary as we do not allow AutoIAFNormal, AutoBNAFNormal, AutoDAIS, AutoSemiDAIS and AutoSurrogateLikelihoodDAIS.
    • This simplifies the force computation to attractive+repulsive.

Misc changes

  1. The normalization factor in the ProbabilityProductKernel has been removed. The kernel is still a proper kernel; however, this version avoids vanishing/exploding when the guide variances deviate from 1 for "high" dimensional models.
  2. Added a rng_key to the kernel. This is convenient when experimenting with kernels.
  3. Removed the enum parameter as it is currently unsupported.
  4. Kernel smoothing of the attractive force and the repulsion is now computed on particles in unconstraint space, consistent with particles moving in unconstraint space. NB: Constraint particles are removed from the test.
  5. Added the manual computation for kernel tests as comments.
    • I inadvertently changed the value of one of the particles when doing the manual computations, which is why the evaluation has changed.

TODO

  • Add docstrings to ASVGD and SVGD.
  • Add manual computation to the IMQKernel and GraphicalKernel tests.
  • Fix SteinVI documentation I'll do this in a separate PR.

@OlaRonning OlaRonning added the WIP label Oct 8, 2024
@OlaRonning OlaRonning marked this pull request as draft October 8, 2024 12:44
docs/source/conf.py Outdated Show resolved Hide resolved
@OlaRonning OlaRonning changed the title [WIP] Refactoring SteinVI Refactoring SteinVI Oct 9, 2024
@OlaRonning OlaRonning removed the WIP label Oct 9, 2024
@OlaRonning OlaRonning marked this pull request as ready for review October 10, 2024 10:27
@OlaRonning OlaRonning requested a review from fehiepsi October 11, 2024 15:50
@OlaRonning
Copy link
Member Author

@fehiepsi Let me know if you want me to split it into multiple PRs.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending minor comments.

from jax import numpy as jnp, random, vmap
from jax.tree_util import tree_flatten, tree_map
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is deprecated. You can use jax.tree.flatten and jax.tree.map now.

from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

from jax import numpy as jnp
from jax.tree_util import tree_flatten, tree_map
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

@OlaRonning
Copy link
Member Author

Thanks! I'll change it to jax.tree instead.

@OlaRonning OlaRonning merged commit 8ace34f into pyro-ppl:master Oct 19, 2024
4 checks passed
@OlaRonning OlaRonning deleted the refactor/SteinVI branch October 19, 2024 07:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants