-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring SteinVI #1883
Refactoring SteinVI #1883
Conversation
@fehiepsi Let me know if you want me to split it into multiple PRs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending minor comments.
from jax import numpy as jnp, random, vmap | ||
from jax.tree_util import tree_flatten, tree_map |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is deprecated. You can use jax.tree.flatten
and jax.tree.map
now.
numpyro/contrib/einstein/steinvi.py
Outdated
from jax.flatten_util import ravel_pytree | ||
from jax.tree_util import tree_map |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
from jax import numpy as jnp | ||
from jax.tree_util import tree_flatten, tree_map |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
Thanks! I'll change it to |
SVGD has been extended to a few varieties (ASVGD, GSVGD, HSVGD) which I want to include in the einstein module. This PR consists of the changes to
SteinVI
to allow for the extensions and introduction of SVGD and ASVGD. GSVGD and HSVGD will be in subsequent PRs.To this end, I've:
1/m
form
particles.num_elbo_particles
because it's always 1 for SVGD.setup_run
method to SteinVI.setup_run
to change theloss_temperature
.AutoIAFNormal
,AutoBNAFNormal
,AutoDAIS
,AutoSemiDAIS
andAutoSurrogateLikelihoodDAIS
.Misc changes
ProbabilityProductKernel
has been removed. The kernel is still a proper kernel; however, this version avoids vanishing/exploding when the guide variances deviate from 1 for "high" dimensional models.enum
parameter as it is currently unsupported.TODO
Fix SteinVI documentationI'll do this in a separate PR.