From f1a47a811bb2b261ffbd50808e85ab007fbf713d Mon Sep 17 00:00:00 2001 From: Surya Bhupatiraju Date: Fri, 1 Sep 2023 12:17:25 -0700 Subject: [PATCH] Replace deprecated jax.linear_util.wrap_init with jax.extend.linear_util.wrap_init. PiperOrigin-RevId: 562017973 --- haiku/_src/dot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index ab961ada5..e655d9b35 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -134,7 +134,7 @@ def to_graph(fun): @functools.wraps(fun) def wrapped_fun(*args): """See `fun`.""" - f = jax.linear_util.wrap_init(fun) + f = jax.extend.linear_util.wrap_init(fun) args_flat, in_tree = jax.tree_util.tree_flatten((args, {})) flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree) graph = Graph.create(title=name_or_str(fun)) @@ -202,7 +202,7 @@ def process_primitive(self, primitive, tracers, params): if primitive is pjit.pjit_p: f = jax.core.jaxpr_as_fun(params['jaxpr']) f.__name__ = params['name'] - fun = jax.linear_util.wrap_init(f) + fun = jax.extend.linear_util.wrap_init(f) return self.process_call(primitive, fun, tracers, params) inputs = [t.val for t in tracers]