Colab error with haiku/jax dependency #5

Open
hstojic opened this issue Dec 3, 2021 · 0 comments

hstojic commented Dec 3, 2021

The following error appears when executing the Colab notebook:

AttributeError                            Traceback (most recent call last)
<ipython-input-7-4863fdbb2553> in <module>()
     13     num_ensemble=FLAGS.index_dim,
     14     prior_scale=FLAGS.prior_scale,
---> 15     seed=FLAGS.seed,
     16 )
     17 

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
    137     """Ensemble of MLPs with matched prior functions."""
    138     mlp_priors = make_mlp_ensemble_prior_fns(
--> 139         output_sizes, dummy_input, num_ensemble, seed)
    140     enn = priors.EnnWithAdditivePrior(
    141         enn=MLPEnsembleEnn(

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
     90     return hk.Sequential(layers)(x)
     91 
---> 92   transformed = hk.without_apply_rng(hk.transform(net_fn))
     93 
     94   prior_fns = []

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
    301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
    302 
--> 303   return without_state(transform_with_state(f))
    304 
    305 

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
    359   """
    360   analytics.log_once("transform_with_state")
--> 361   check_not_jax_transformed(f)
    362 
    363   unexpected_tracer_hint = (

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
    306 def check_not_jax_transformed(f):
    307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
    309     raise ValueError("A common error with Haiku is to pass an already jit "
    310                      "(or pmap) decorated function into hk.transform (e.g. "

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'
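The traceback shows Haiku's `check_not_jax_transformed` looking up `jax.xla.xe.CompiledFunction` and `jax.xla.xe.PmapFunction`, and the error reports that `jax.xla.xe` resolves to `jaxlib.xla_extension`, which has no `PmapFunction`. That points to a version mismatch between the installed `dm-haiku` and `jaxlib` (an inference from the traceback, not a confirmed diagnosis). The sketch below is a hypothetical helper, not part of Haiku or JAX. It prints the installed versions and lists which of those two names are missing from `jaxlib.xla_extension`, so the mismatch can be confirmed in the Colab runtime before changing anything.

```python
import importlib
from typing import Dict, List, Optional


def missing_attributes(module_name: str, names: List[str]) -> Optional[List[str]]:
    """Return the entries of `names` that `module_name` does not define.

    Returns None if the module itself cannot be imported.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return [name for name in names if not hasattr(module, name)]


def installed_versions(package_names: List[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its `__version__`, or None if it is
    not importable or does not expose a `__version__` attribute."""
    versions: Dict[str, Optional[str]] = {}
    for name in package_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None
            continue
        versions[name] = getattr(module, "__version__", None)
    return versions


if __name__ == "__main__":
    print(installed_versions(["jax", "jaxlib", "haiku"]))
    # Haiku's isinstance check references these two names on jax.xla.xe,
    # i.e. jaxlib.xla_extension in the failing environment.
    print(missing_attributes("jaxlib.xla_extension",
                             ["CompiledFunction", "PmapFunction"]))
```

If the second line lists `PmapFunction`, the installed Haiku is older than the installed jaxlib. Upgrading `dm-haiku`, or pinning `jax`/`jaxlib` to an older release, is the likely remedy, though the exact compatible versions are not stated in this issue.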