Skip to content

Commit

Permalink
fixed failing doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning committed Nov 4, 2024
1 parent 2c69e9e commit 7aa4aaa
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ class SVGD(SteinVI):
>>> svgd_result = svgd.run(random.PRNGKey(0), 200, data)
>>> params = svgd_result.params
>>> predictive = Predictive(model, svgd.guide, params, num_samples=10, batch_ndims=1)
>>> predictive = Predictive(model, guide=svgd.guide, params=params, num_samples=10, batch_ndims=1)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param Callable model: Python callable with NumPyro primitives for the model.
Expand Down Expand Up @@ -640,7 +640,7 @@ class ASVGD(SVGD):
>>> from numpyro.distributions.constraints import positive
>>> from numpyro.optim import Adagrad
>>> from numpyro.contrib.einstein import SVGD, RBFKernel
>>> from numpyro.contrib.einstein import ASVGD, RBFKernel
>>> from numpyro.infer import Predictive
>>> def model(data):
Expand All @@ -658,7 +658,7 @@ class ASVGD(SVGD):
>>> asvgd_result = asvgd.run(random.PRNGKey(0), 200, data)
>>> params = asvgd_result.params
>>> predictive = Predictive(model, asvgd.guide, params, num_samples=10, batch_ndims=1)
>>> predictive = Predictive(model, guide=asvgd.guide, params=params, num_samples=10, batch_ndims=1)
>>> samples = predictive(random.PRNGKey(1), data=None)
:param Callable model: Python callable with NumPyro primitives for the model.
Expand Down

0 comments on commit 7aa4aaa

Please sign in to comment.