Skip to content

Commit

Permalink
Require sample_shape passed as keyword argument to AutoGuides (#1659)
Browse files Browse the repository at this point in the history
* Require sample_shape passed as keyword argument

* Remove PriorModelGen
  • Loading branch information
tare authored Oct 9, 2023
1 parent 1a96406 commit dca5b2b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
3 changes: 1 addition & 2 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Model,
NestedSamplerResults,
Prior,
PriorModelGen,
TerminationCondition,
plot_cornerplot,
plot_diagnostics,
Expand Down Expand Up @@ -243,7 +242,7 @@ def run(self, rng_key, *args, **kwargs):
loglik_fn = local_dict["loglik_fn"]

# use NestedSampler with identity prior chain
def prior_model() -> PriorModelGen:
def prior_model():
params = []
for name in param_names:
shape = prototype_trace[name]["fn"].shape()
Expand Down
10 changes: 5 additions & 5 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
"""
Generate samples from the approximate posterior over the latent
sites in the model.
Expand Down Expand Up @@ -444,7 +444,7 @@ def _constrain(self, latent_samples):
else:
return self._postprocess_fn(latent_samples)

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs}
with handlers.seed(rng_seed=rng_key):
Expand Down Expand Up @@ -776,7 +776,7 @@ def get_posterior(self, params):
transform = self.get_transform(params)
return dist.TransformedDistribution(base_dist, transform)

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
latent_sample = handlers.substitute(
handlers.seed(self._sample_latent, rng_key), params
)(sample_shape=sample_shape)
Expand Down Expand Up @@ -965,7 +965,7 @@ def scan_body(carry, eps_beta):

return z

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
def _single_sample(_rng_key):
latent_sample = handlers.substitute(
handlers.seed(self._sample_latent, _rng_key), params
Expand Down Expand Up @@ -1916,7 +1916,7 @@ def get_posterior(self, params):
transform = self.get_transform(params)
return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
latent_sample = self.get_posterior(params).sample(rng_key, sample_shape)
return self._unpack_and_constrain(latent_sample, params)

Expand Down

0 comments on commit dca5b2b

Please sign in to comment.