Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow AutoSemiDAIS to work without global variable #1665

Merged
merged 6 commits into from
Nov 12, 2023

Conversation

fehiepsi
Copy link
Member

No description provided.

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this PR support amortized q(z|x)?

@@ -1250,13 +1250,16 @@ def local_model(theta):
during partial momentum refreshments in HMC. Defaults to 0.9.
:param float init_scale: Initial scale for the standard deviation of the variational
distribution for each (unconstrained transformed) local latent variable. Defaults to 0.1.
:param str subsample_plate: Optional name of the subsample plate site. This is required
when the model does not have subsample plate (like in VAE settings).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just revised it to clarify that this is required when the model has a subsample plate without subsample_size specified.

numpyro/infer/autoguide.py Outdated Show resolved Hide resolved
@@ -1301,6 +1308,10 @@ def _setup_prototype(self, *args, **kwargs):
and isinstance(site["args"][1], int)
and site["args"][0] > site["args"][1]
}
if self.subsample_plate is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm a bit confused by these args/checks. afaik we should support the following cases:

  • there is a single plate but there is no subsampling
  • there is a single plate and it is subsampled
    any other scenario (e.g. 0 plates or > 1 plates) is not supported. is that right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC we allow multiple plates but only one subsample plate. I just added an elif not subsample_plates: ... branch to cover the case you mentioned above:

  • there is a single plate but there is no subsampling

event_dim=1,
)
if self.use_global_dais_params:
z_0_loc_init = jnp.zeros(D)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think this part makes sense. the z params should always be local.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I think amortized guide should have local_guide specified.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what if i just want non-amortized mean-field variational distributions for the locals? i would need to specify local_guide as opposed to relying on a convenient default behavior?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in that case, you might want to set use_global_dais_params=False. We might also change the semantics of this flag to global_dais_params=None/"dynamic"/"full". wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what exactly are the three options?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None is the current master behavior, or False in this PR. dynamic is your request. full is this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for being slow here (traveling). i think use_global_dais_params which controls e.g. betas should be entirely separate from whatever controls the distribution over z_0. what about the following behavior:

  • if local_guide is provided use that.
  • otherwise if local_guide=None and there exist local variables instantiate a auto mean-field guide?

Copy link
Member Author

@fehiepsi fehiepsi Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, that will imply that if users want to have global params for the base dist, they will need to use local_guide? I don't have a preference here. To me, base params play a similar role as betas; the dynamic will depend on model density etc. In the DAIS paper, the author even uses a fixed base dist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think there's any reason in practice why you'd really want global/shared params for the base dist q(z_0). having global params for e.g. beta makes sense because it's basically a "higher order quantity" and so harder to estimate and probably varies less from data point to data point. just in the same way that we'd probably generally be more comfortable in sharing scales/variances across data points than locs/means.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense to me, thanks!

@fehiepsi
Copy link
Member Author

yes, this supports amortized local_guide. Users need to specify it in the construction.

@fehiepsi fehiepsi merged commit 1528036 into pyro-ppl:master Nov 12, 2023
4 checks passed
amifalk pushed a commit to amifalk/numpyro that referenced this pull request Nov 20, 2023
* support vae for dais

* revise docstring

* use local params for base dist and fix kumaraswamy vmap tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants