-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add AutoGuideList #1644
Add AutoGuideList #1644
Changes from 7 commits
7a16de2
9dffe9b
e52b77a
3f57f4e
b4b2a31
afc07c7
ec6c463
5cdb72c
0a10e53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -213,6 +213,103 @@ def quantiles(self, params, quantiles): | |
raise NotImplementedError | ||
|
||
|
||
class AutoGuideList(AutoGuide): | ||
""" | ||
Container class to combine multiple automatic guides. | ||
|
||
Example usage:: | ||
|
||
rng_key_init = random.PRNGKey(0) | ||
guide = AutoGuideList(my_model) | ||
guide.append( | ||
AutoNormal( | ||
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"]) | ||
) | ||
) | ||
guide.append( | ||
AutoDelta( | ||
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"]) | ||
) | ||
) | ||
svi = SVI(model, guide, optim, Trace_ELBO()) | ||
svi_state = svi.init(rng_key_init, data, labels) | ||
params = svi.get_params(svi_state) | ||
|
||
:param callable model: a NumPyro model | ||
""" | ||
|
||
def __init__( | ||
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None | ||
): | ||
self._guides = [] | ||
super().__init__( | ||
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates | ||
) | ||
|
||
def append(self, part): | ||
""" | ||
Add an automatic or custom guide for part of the model. The guide should | ||
have been created by blocking the model to restrict to a subset of | ||
sample sites. No two parts should operate on any one sample site. | ||
|
||
:param part: a partial guide to add | ||
:type part: AutoGuide | ||
""" | ||
self._guides.append(part) | ||
|
||
def __call__(self, *args, **kwargs): | ||
if self.prototype_trace is None: | ||
# run model to inspect the model structure | ||
self._setup_prototype(*args, **kwargs) | ||
|
||
# create all plates | ||
self._create_plates(*args, **kwargs) | ||
|
||
# run slave guides | ||
result = {} | ||
for part in self._guides: | ||
result.update(part(*args, **kwargs)) | ||
return result | ||
|
||
def __getitem__(self, key): | ||
return self._guides[key] | ||
|
||
def __len__(self): | ||
return len(self._guides) | ||
|
||
def __iter__(self): | ||
yield from self._guides | ||
|
||
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): | ||
result = {} | ||
for part in self._guides: | ||
if isinstance(part, numpyro.infer.autoguide.AutoDelta) or isinstance( | ||
part, numpyro.infer.autoguide.AutoSemiDAIS | ||
): | ||
result.update( | ||
part.sample_posterior( | ||
rng_key, params, *args, sample_shape=sample_shape, **kwargs | ||
) | ||
) | ||
else: | ||
result.update( | ||
part.sample_posterior(rng_key, params, sample_shape=sample_shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you usage signature of other algorithm to be the same as AutoDelta? That signature is more general I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To reduce the chance of misunderstanding, can you please clarify what you mean. The reason for the if-else clause is that
I suppose we could break things if we change the signature of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good point! Maybe we can keep your current implementation here as-is and adding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds reasonable to me! I'm happy to create another PR to address this if you want. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, thank you! |
||
) | ||
return result | ||
|
||
def median(self, params): | ||
result = {} | ||
for part in self._guides: | ||
result.update(part.median(params)) | ||
return result | ||
|
||
def quantiles(self, params, quantiles): | ||
result = {} | ||
for part in self._guides: | ||
result.update(part.quantiles(params, quantiles)) | ||
return result | ||
|
||
|
||
class AutoNormal(AutoGuide): | ||
""" | ||
This implementation of :class:`AutoGuide` uses Normal distributions | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably you can raise ValueError in the constructor if the guide is either AutoDAIS or AutoSemiDAIS. I'm not sure if the algorithm gives reasonable results under AutoGuideList.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also
AutoSurrogateLikelihoodDAIS