REMOVED deprecated jax.tree_util
Commit 4e2296e (1 parent: c34dd66), committed by OlaRonning on Oct 19, 2024
Showing 3 changed files with 10 additions and 13 deletions.
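Across all three files the change is mechanical: tree_flatten and tree_map, previously imported from jax.tree_util, become tree.flatten and tree.map via the jax.tree namespace that newer JAX releases recommend. A minimal sketch of the equivalence on a toy pytree (the dict below is illustrative, not taken from the diff):

    from jax import numpy as jnp, tree

    # Toy pytree standing in for a params dict (illustrative only).
    params = {"loc": jnp.zeros((3, 2)), "scale": jnp.ones(3)}

    # Old style: from jax.tree_util import tree_flatten, tree_map
    # New style: the jax.tree namespace exposes the same functions.
    leaves, treedef = tree.flatten(params)       # list of leaves + tree structure
    doubled = tree.map(lambda x: 2 * x, params)  # leafwise map, same structure

The jax.tree functions are aliases of their jax.tree_util counterparts, so the rename changes no behavior.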
7 changes: 3 additions & 4 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -6,8 +6,7 @@
 from typing import Optional
 import warnings
 
-from jax import numpy as jnp, random, vmap
-from jax.tree_util import tree_flatten, tree_map
+from jax import numpy as jnp, random, tree, vmap
 
 from numpyro.handlers import substitute
 from numpyro.infer import Predictive
@@ -77,7 +76,7 @@ def __init__(
 
         self.guide = guide
         self.return_sites = return_sites
-        self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
+        self.num_mixture_components = jnp.shape(tree.flatten(params)[0][0])[0]
         self.mixture_assignment_sitename = mixture_assignment_sitename
 
     def _call_with_params(self, rng_key, params, args, kwargs):
@@ -113,7 +112,7 @@ def __call__(self, rng_key, *args, **kwargs):
             minval=0,
             maxval=self.num_mixture_components,
         )
-        predictive_assign = tree_map(
+        predictive_assign = tree.map(
             lambda arr: vmap(lambda i, assign: arr[i, assign])(
                 jnp.arange(self._batch_shape[0]), assigns
             ),
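In mixture_guide_predictive.py the renamed calls carry the interesting logic: tree.flatten(params)[0][0] grabs the first leaf, whose leading axis length is the number of mixture components, and tree.map later gathers, for each posterior sample, the prediction of its randomly assigned component. A sketch of that gather under assumed shapes (the params dict, samples dict, and sizes are hypothetical):

    from jax import numpy as jnp, random, tree, vmap

    # Hypothetical particle params: every leaf's leading axis is the
    # number of mixture components (3 here).
    params = {"w": jnp.zeros((3, 5)), "b": jnp.zeros((3,))}
    num_components = jnp.shape(tree.flatten(params)[0][0])[0]  # -> 3

    # One random component assignment per sample, as in __call__.
    assigns = random.randint(random.PRNGKey(0), (4,), minval=0, maxval=num_components)

    # Each site holds (num_samples, num_components, ...); pick, for
    # sample i, the slice belonging to its assigned component.
    samples = {"y": jnp.arange(24.0).reshape(4, 3, 2)}
    picked = tree.map(
        lambda arr: vmap(lambda i, assign: arr[i, assign])(jnp.arange(4), assigns),
        samples,
    )  # picked["y"].shape == (4, 2)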
7 changes: 3 additions & 4 deletions numpyro/contrib/einstein/steinvi.py
@@ -8,9 +8,8 @@
 from itertools import chain
 import operator
 
-from jax import grad, numpy as jnp, random, vmap
+from jax import grad, numpy as jnp, random, tree, vmap
 from jax.flatten_util import ravel_pytree
-from jax.tree_util import tree_map
 
 from numpyro import handlers
 from numpyro.contrib.einstein.stein_loss import SteinLoss
@@ -346,7 +345,7 @@ def loss_fn(particle, i):
         stein_param_grads = unravel_pytree_batched(particle_grads)
 
         # 8. Return loss and gradients (based on parameter forces)
-        res_grads = tree_map(
+        res_grads = tree.map(
             lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
         )
         return jnp.linalg.norm(particle_grads), res_grads
@@ -405,7 +404,7 @@ def init(self, rng_key, *args, **kwargs):
                 if site["name"] in guide_init_params:
                     pval = guide_init_params[site["name"]]
                     if self.non_mixture_params_fn(site["name"]):
-                        pval = tree_map(lambda x: x[0], pval)
+                        pval = tree.map(lambda x: x[0], pval)
                 else:
                     pval = site["value"]
                 params[site["name"]] = transform.inv(pval)
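Both steinvi.py call sites are plain leafwise transforms, so only the namespace changes. A sketch of the two patterns on a stand-in gradient pytree (the dict name and shapes are hypothetical, not from the diff):

    from jax import numpy as jnp, tree

    # Stand-in for {**non_mixture_param_grads, **stein_param_grads}:
    # Stein forces for two particles.
    grads = {"auto_loc": jnp.ones((2, 3)), "auto_scale": jnp.full((2, 3), 0.5)}

    # loss_fn negates the forces so optimizers can treat them as gradients.
    res_grads = tree.map(lambda x: -x, grads)

    # init() takes the first particle's slice of each leaf so non-mixture
    # parameters start from a single shared value.
    pval = tree.map(lambda x: x[0], grads)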
9 changes: 4 additions & 5 deletions test/contrib/einstein/test_steinvi_util.py
@@ -8,8 +8,7 @@
 import pytest
 import scipy
 
-from jax import numpy as jnp
-from jax.tree_util import tree_flatten, tree_map
+from jax import numpy as jnp, tree
 
 from numpyro.contrib.einstein.stein_util import batch_ravel_pytree, posdef, sqrth
 
@@ -82,10 +81,10 @@ def test_sqrth_shape(batch_shape):
 def test_ravel_pytree_batched(pytree, nbatch_dims):
     flat, _, unravel_fn = batch_ravel_pytree(pytree, nbatch_dims)
     unravel = unravel_fn(flat)
-    tree_flatten(tree_map(lambda x, y: assert_allclose(x, y), unravel, pytree))
+    tree.flatten(tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree))
     assert all(
-        tree_flatten(
-            tree_map(
+        tree.flatten(
+            tree.map(
                 lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel, pytree
             )
         )[0]
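The test idiom is also unchanged by the rename: tree.map with a two-argument function walks two pytrees of identical structure in lockstep, so wrapping assert_allclose in it runs the assertion on every leaf pair (the outer tree.flatten's return value is unused; the assertions fire inside tree.map itself). The same round-trip check on a toy pytree (pytree and the identity unravel below stand in for batch_ravel_pytree's inputs and outputs):

    from numpy.testing import assert_allclose
    from jax import numpy as jnp, tree

    pytree = {"a": jnp.ones((2, 3)), "b": jnp.zeros(2)}
    unravel = tree.map(lambda x: x + 0.0, pytree)  # identity stand-in

    # Value equality, leaf by leaf.
    tree.flatten(tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree))

    # Dtype equality across all leaves.
    assert all(
        tree.flatten(
            tree.map(lambda x, y: jnp.result_type(x) == jnp.result_type(y),
                     unravel, pytree)
        )[0]
    )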
