The following error appears when executing the Colab notebook:
```
AttributeError                            Traceback (most recent call last)
<ipython-input-7-4863fdbb2553> in <module>()
     13     num_ensemble=FLAGS.index_dim,
     14     prior_scale=FLAGS.prior_scale,
---> 15     seed=FLAGS.seed,
     16 )
     17 

4 frames
/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
    137     """Ensemble of MLPs with matched prior functions."""
    138     mlp_priors = make_mlp_ensemble_prior_fns(
--> 139         output_sizes, dummy_input, num_ensemble, seed)
    140     enn = priors.EnnWithAdditivePrior(
    141         enn=MLPEnsembleEnn(

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
     90     return hk.Sequential(layers)(x)
     91 
---> 92   transformed = hk.without_apply_rng(hk.transform(net_fn))
     93 
     94   prior_fns = []

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
    301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
    302 
--> 303   return without_state(transform_with_state(f))
    304 
    305 

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
    359   """
    360   analytics.log_once("transform_with_state")
--> 361   check_not_jax_transformed(f)
    362 
    363   unexpected_tracer_hint = (

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
    306 def check_not_jax_transformed(f):
    307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
    309     raise ValueError("A common error with Haiku is to pass an already jit "
    310                      "(or pmap) decorated function into hk.transform (e.g. "

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'
```
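The last frame shows Haiku's `check_not_jax_transformed` looking up `PmapFunction` on `jax.xla.xe`, which the error message identifies as `jaxlib.xla_extension`. This suggests the installed `jaxlib` and `dm-haiku` versions may not match each other. A minimal diagnostic sketch, assuming Python 3.8+ (for `importlib.metadata`; the Colab runtime in the traceback is Python 3.7, where the `importlib_metadata` backport would be needed instead). The helper names `report_versions` and `has_pmap_function` are hypothetical and not part of JAX, Haiku, or ENN:

```python
import importlib
import importlib.metadata as metadata


def report_versions(packages=("jax", "jaxlib", "dm-haiku")):
    """Return a dict mapping each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def has_pmap_function():
    """Return True/False depending on whether jaxlib.xla_extension exposes PmapFunction.

    Returns None if jaxlib.xla_extension cannot be imported at all.
    """
    try:
        xe = importlib.import_module("jaxlib.xla_extension")
    except ImportError:
        return None
    # This is the attribute Haiku's check_not_jax_transformed accesses via jax.xla.xe.
    return hasattr(xe, "PmapFunction")


if __name__ == "__main__":
    print(report_versions())
    print("PmapFunction available:", has_pmap_function())
```

Running this in the same notebook before the failing cell would show whether `PmapFunction` is missing in the installed `jaxlib`, which is useful detail to include when reporting the versions involved.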