diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index 68797396e..a634d550d 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -14,7 +14,6 @@ from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic from jax.interpreters.pxla import xla_pmap_p -import jax.numpy as jnp def eval_provenance(fn, **kwargs): @@ -53,18 +52,11 @@ def eval_provenance(fn, **kwargs): # get provenances of flatten kwargs aval_kwargs = {} for n, v in kwargs.items(): - aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})}) - aval_kwargs[n] = jax.tree.map(lambda _: aval, v) - aval_args, _ = jax.tree.flatten(((), aval_kwargs)) - provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args) + aval_kwargs[n] = jax.tree.map(lambda _: frozenset({n}), v) + provenance_inputs, _ = jax.tree.flatten(((), aval_kwargs)) provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs) - out_flat = [] - for v, p in zip(avals_out, provenance_outputs): - val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p}) - out_flat.append(val) - out = jax.tree.unflatten(out_tree(), out_flat) - return jax.tree.map(lambda x: x.named_shape["provenance"], out) + return jax.tree.unflatten(out_tree(), provenance_outputs) def track_deps_jaxpr(jaxpr, provenance_inputs):