When trying to build my own ENN, I got the following error:
```
AttributeError                            Traceback (most recent call last)
Cell In[5], line 25
     17 enn = networks.MLPEnsembleMatchedPrior(
     18     output_sizes=[50, 50, 1],
     19     num_ensemble=10,
     20     dummy_input=np.zeros(50)
     21 )
     23 # Loss
     24 loss_fn = losses.average_single_index_loss(
---> 25     single_loss=losses.L2LossWithBootstrap(),
     26     num_index_samples=10
     27 )
     29 # Optimizer
     30 optimizer = optax.adam(1e-3)

AttributeError: module 'enn.losses' has no attribute 'L2LossWithBootstrap'
```
Here is the code that I created for the network:
```python
from enn.loggers import TerminalLogger
from enn import losses
from enn import networks
from enn import supervised
from enn.supervised import regression_data
import optax
import numpy as np

# A small dummy dataset
dataset = regression_data.make_dataset()

# Logger
logger = TerminalLogger('supervised_regression')

# ENN
enn = networks.MLPEnsembleMatchedPrior(
    output_sizes=[50, 50, 1],
    num_ensemble=10,
    dummy_input=np.zeros(50)
)

# Loss
loss_fn = losses.average_single_index_loss(
    single_loss=losses.L2LossWithBootstrap(),
    num_index_samples=10
)

# Optimizer
optimizer = optax.adam(1e-3)

# Train the experiment
experiment = supervised.Experiment(
    enn, loss_fn, optimizer, dataset, seed=0, logger=logger)
experiment.train(FLAGS.num_batch)
```
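The `AttributeError` means the installed `enn.losses` module does not export a name called `L2LossWithBootstrap`, which suggests the example and the installed release may have drifted apart. A quick way to narrow this down is to list the public names the module does export that look similar. The helper below is a hypothetical sketch, not part of the enn API; it only uses the standard library (`dir` plus `difflib.get_close_matches`), and the `cutoff=0.4` threshold is an arbitrary choice.

```python
import difflib
import types
from typing import List


def find_similar_names(module: types.ModuleType, missing: str,
                       n: int = 5, cutoff: float = 0.4) -> List[str]:
    """Return public attribute names of `module` that resemble `missing`.

    Hypothetical helper: fuzzy-matches against dir(module) so that a class
    renamed or removed between releases is easier to spot.
    """
    # Skip private/dunder names such as __name__ or __doc__.
    public = [name for name in dir(module) if not name.startswith("_")]
    # Best matches first, at most n of them, similarity ratio >= cutoff.
    return difflib.get_close_matches(missing, public, n=n, cutoff=cutoff)
```

For example, `from enn import losses; print(find_similar_names(losses, "L2LossWithBootstrap"))` prints whatever the installed version offers in place of the missing class (possibly an empty list).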
Also note that, to get the example to work, I had to add the line `dummy_input=np.zeros(50)`; otherwise I got an error that `dummy_input` was a required positional argument.
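To check which arguments the installed constructor actually requires, rather than relying on the example, the standard-library `inspect.signature` can list the parameters that have no default. The helper name `required_params` is hypothetical, and calling it on `networks.MLPEnsembleMatchedPrior` assumes that object is an ordinary Python class or function that `inspect` can introspect.

```python
import inspect
from typing import Callable, List


def required_params(fn: Callable) -> List[str]:
    """Names of parameters of `fn` that must be supplied by the caller.

    Parameters with a default value are skipped, as are *args/**kwargs,
    which are never required. For a class, inspect.signature reports the
    __init__ parameters with `self` already removed.
    """
    sig = inspect.signature(fn)
    return [
        name
        for name, p in sig.parameters.items()
        if p.default is inspect.Parameter.empty
        and p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                           inspect.Parameter.VAR_KEYWORD)
    ]
```

If `"dummy_input"` appears in `required_params(networks.MLPEnsembleMatchedPrior)`, the installed version expects it even though the example omits it.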