Skip to content

Commit

Permalink
Replace deprecated jax.linear_util.wrap_init with jax.extend.linear_u…
Browse files Browse the repository at this point in the history
…til.wrap_init.

PiperOrigin-RevId: 562017973
  • Loading branch information
suryabhupa authored and copybara-github committed Sep 1, 2023
1 parent 568d4b1 commit f1a47a8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def to_graph(fun):
@functools.wraps(fun)
def wrapped_fun(*args):
"""See `fun`."""
f = jax.linear_util.wrap_init(fun)
f = jax.extend.linear_util.wrap_init(fun)
args_flat, in_tree = jax.tree_util.tree_flatten((args, {}))
flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree)
graph = Graph.create(title=name_or_str(fun))
Expand Down Expand Up @@ -202,7 +202,7 @@ def process_primitive(self, primitive, tracers, params):
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
fun = jax.linear_util.wrap_init(f)
fun = jax.extend.linear_util.wrap_init(f)
return self.process_call(primitive, fun, tracers, params)

inputs = [t.val for t in tracers]
Expand Down

0 comments on commit f1a47a8

Please sign in to comment.