Skip to content

Commit

Permalink
fix some deprecation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Sep 21, 2023
1 parent c823024 commit 7a27112
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
5 changes: 2 additions & 3 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import operator

from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.random import KeyArray
from jax.tree_util import tree_map

from numpyro import handlers
Expand Down Expand Up @@ -370,10 +369,10 @@ def _update_force(attr_force, rep_force, jac):
)
return jnp.linalg.norm(particle_grads), res_grads

def init(self, rng_key: KeyArray, *args, **kwargs):
def init(self, rng_key, *args, **kwargs):
"""Register random variable transformations, constraints and determine initialize positions of the particles.
:param KeyArray rng_key: Random number generator seed.
:param rng_key: Random number generator seed.
:param args: Arguments to the model / guide.
:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
Expand Down
2 changes: 1 addition & 1 deletion numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax.experimental.pjit import pjit_p
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.linear_util as lu
import jax.extend.linear_util as lu
import jax.numpy as jnp


Expand Down
2 changes: 1 addition & 1 deletion test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ def model(cov):
cov = np.zeros((5, 5))
cov[:2, :2] = w_cov
cov[2:4, 2:4] = xy_cov
cov[4, 4] = z_var
cov[4, 4] = z_var[0]

kernel = NUTS(model, dense_mass=[("w",), ("x", "y")])
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1)
Expand Down

0 comments on commit 7a27112

Please sign in to comment.