-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow AutoSemiDAIS to work without global variable #1665
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this PR support amortized q(z|x)?
numpyro/infer/autoguide.py
Outdated
@@ -1250,13 +1250,16 @@ def local_model(theta): | |||
during partial momentum refreshments in HMC. Defaults to 0.9. | |||
:param float init_scale: Initial scale for the standard deviation of the variational | |||
distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1. | |||
:param str subsample_plate: Optional name of the subsample plate site. This is required | |||
when the model does not have subsample plate (like in VAE settings). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just revised it to clarify that this is required when the model has a subsample plate without subsample_size
specified.
@@ -1301,6 +1308,10 @@ def _setup_prototype(self, *args, **kwargs): | |||
and isinstance(site["args"][1], int) | |||
and site["args"][0] > site["args"][1] | |||
} | |||
if self.subsample_plate is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm a bit confused by these args/checks. afaik we should support the following cases:
- there is a single plate but there is no subsampling
- there is a single plate and it is subsampled
any other scenario (e.g. 0 plates or > 1 plates) is not supported. is that right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC we allow multiple plates but only one subsample plate. I just added an elif not subsample_plates: ...
branch to cover the case you mentioned above:
- there is a single plate but there is no subsampling
numpyro/infer/autoguide.py
Outdated
event_dim=1, | ||
) | ||
if self.use_global_dais_params: | ||
z_0_loc_init = jnp.zeros(D) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't think this part makes sense. the z
params should always be local.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, I think amortized guide should have local_guide
specified.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but what if i just want non-amortized mean-field variational distributions for the locals? i would need to specify local_guide
as opposed to relying on a convenient default behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, in that case, you might want to set use_global_dais_params=False
. We might also change the semantics of this flag to global_dais_params=None/"dynamic"/"full"
. wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what exactly are the three options?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None is the current master behavior, or False in this PR. dynamic is your request. full is this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry for being slow here (traveling). i think use_global_dais_params
which controls e.g. betas
should be entirely separate from whatever controls the distribution over z_0
. what about the following behavior:
- if
local_guide
is provided use that. - otherwise if
local_guide=None
and there exist local variables instantiate a auto mean-field guide?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, that will imply that if users want to have global params for the base dist, they will need to use local_guide? I don't have a preference here. To me, base params play a similar role as betas; the dynamic will depend on model density etc. In the DAIS paper, the author even uses a fixed base dist.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't think there's any reason in practice why you'd really want global/shared params for the base dist q(z_0). having global params for e.g. beta makes sense because it's basically a "higher order quantity" and so harder to estimate and probably varies less from data point to data point. just in the same way that we'd probably generally be more comfortable in sharing scales/variances across data points than locs/means.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense to me, thanks!
yes, this supports amortized |
* support vae for dais * revise docstring * use local params for base dist and fix kumaraswamy vmap tests
No description provided.