Skip to content

Commit

Permalink
simplify the provenance logic to prepare for the removal of jax named…
Browse files Browse the repository at this point in the history
…_shape (#1837)
  • Loading branch information
fehiepsi authored Jul 29, 2024
1 parent f6eb6ce commit 3e41320
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.numpy as jnp


def eval_provenance(fn, **kwargs):
Expand Down Expand Up @@ -53,18 +52,11 @@ def eval_provenance(fn, **kwargs):
# get provenances of flatten kwargs
aval_kwargs = {}
for n, v in kwargs.items():
aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})})
aval_kwargs[n] = jax.tree.map(lambda _: aval, v)
aval_args, _ = jax.tree.flatten(((), aval_kwargs))
provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args)
aval_kwargs[n] = jax.tree.map(lambda _: frozenset({n}), v)
provenance_inputs, _ = jax.tree.flatten(((), aval_kwargs))

provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs)
out_flat = []
for v, p in zip(avals_out, provenance_outputs):
val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p})
out_flat.append(val)
out = jax.tree.unflatten(out_tree(), out_flat)
return jax.tree.map(lambda x: x.named_shape["provenance"], out)
return jax.tree.unflatten(out_tree(), provenance_outputs)


def track_deps_jaxpr(jaxpr, provenance_inputs):
Expand Down

0 comments on commit 3e41320

Please sign in to comment.