From 7a27112d506effe0f82bdadae2cadaede75051ba Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 21 Sep 2023 17:44:58 -0400 Subject: [PATCH] fix some deprecation warnings --- numpyro/contrib/einstein/steinvi.py | 5 ++--- numpyro/ops/provenance.py | 2 +- test/infer/test_mcmc.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 5ff91f03a..0436b10ed 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -10,7 +10,6 @@ import operator from jax import grad, jacfwd, numpy as jnp, random, vmap -from jax.random import KeyArray from jax.tree_util import tree_map from numpyro import handlers @@ -370,10 +369,10 @@ def _update_force(attr_force, rep_force, jac): ) return jnp.linalg.norm(particle_grads), res_grads - def init(self, rng_key: KeyArray, *args, **kwargs): + def init(self, rng_key, *args, **kwargs): """Register random variable transformations, constraints and determine initialize positions of the particles. - :param KeyArray rng_key: Random number generator seed. + :param rng_key: Random number generator seed. :param args: Arguments to the model / guide. :param kwargs: Keyword arguments to the model / guide. :return: initial :data:`SteinVIState` diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 5386435f7..a53f26ae8 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -7,7 +7,7 @@ from jax.experimental.pjit import pjit_p from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic from jax.interpreters.pxla import xla_pmap_p -import jax.linear_util as lu +import jax.extend.linear_util as lu import jax.numpy as jnp diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 18b21d529..368fde593 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -917,7 +917,7 @@ def model(cov): cov = np.zeros((5, 5)) cov[:2, :2] = w_cov cov[2:4, 2:4] = xy_cov - cov[4, 4] = z_var + cov[4, 4] = z_var[0] kernel = NUTS(model, dense_mass=[("w",), ("x", "y")]) mcmc = MCMC(kernel, num_warmup=1000, num_samples=1)