No loss called: L2LossWithBootstrap #24

Open
rudibhargava opened this issue Mar 27, 2024 · 0 comments

When trying to build my own ENN, I got the following error:

AttributeError                            Traceback (most recent call last)
Cell In[5], line 25
     17 enn = networks.MLPEnsembleMatchedPrior(
     18     output_sizes=[50, 50, 1],
     19     num_ensemble=10,
     20     dummy_input=np.zeros(50)
     21 )
     23 # Loss
     24 loss_fn = losses.average_single_index_loss(
---> 25     single_loss=losses.L2LossWithBootstrap(),
     26     num_index_samples=10
     27 )
     29 # Optimizer
     30 optimizer = optax.adam(1e-3)

AttributeError: module 'enn.losses' has no attribute 'L2LossWithBootstrap'
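A quick way to check whether the class was renamed or moved in the installed version is to search the module's public names for close matches. Here is a minimal sketch that uses only the standard library. `similar_attributes` is a hypothetical helper, not part of enn, and what it returns for `enn.losses` depends entirely on the installed version.

```python
import difflib
import types
from typing import List


def similar_attributes(module: types.ModuleType, missing: str,
                       n: int = 5, cutoff: float = 0.5) -> List[str]:
    """Returns up to `n` public names in `module` that resemble `missing`.

    Handy when an AttributeError suggests a class was renamed or moved
    between library versions. Matching uses difflib's similarity ratio.
    """
    public = [name for name in dir(module) if not name.startswith("_")]
    return difflib.get_close_matches(missing, public, n=n, cutoff=cutoff)


# Usage against the installed enn (output depends on the installed version):
#   from enn import losses
#   print(hasattr(losses, "L2LossWithBootstrap"))
#   print(similar_attributes(losses, "L2LossWithBootstrap"))
```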

Here is the code that I created for the network:

from enn.loggers import TerminalLogger

from enn import losses
from enn import networks
from enn import supervised
from enn.supervised import regression_data
import optax
import numpy as np

# A small dummy dataset
dataset = regression_data.make_dataset()

# Logger
logger = TerminalLogger('supervised_regression')

# ENN
enn = networks.MLPEnsembleMatchedPrior(
    output_sizes=[50, 50, 1],
    num_ensemble=10,
    dummy_input=np.zeros(50)
)

# Loss
loss_fn = losses.average_single_index_loss(
    single_loss=losses.L2LossWithBootstrap(),
    num_index_samples=10
)

# Optimizer
optimizer = optax.adam(1e-3)

# Train the experiment
experiment = supervised.Experiment(
    enn, loss_fn, optimizer, dataset, seed=0, logger=logger)
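# FLAGS is not defined in this snippet; substitute an integer number of training batches to run it standalone.
experiment.train(FLAGS.num_batch)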
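For context on the loss named in the title: a bootstrapped L2 loss generally weights each ensemble member's squared error by random per-example weights, so different members effectively train on different resamplings of the data, and `average_single_index_loss` (judging by its name) averages a single-index loss over sampled ensemble indices. Below is a minimal plain-Python sketch of that idea under stated assumptions. All names are hypothetical, the Exp(1) weights and the divide-by-number-of-examples normalization are choices made here, and this is not enn's implementation or a drop-in replacement for `losses.L2LossWithBootstrap`.

```python
import random
from typing import Callable, List, Sequence

# Illustrative sketch only: these helpers are hypothetical and are NOT enn's API.
# `predict(k, x)` stands in for ensemble member k's output on input x.
PredictFn = Callable[[int, float], float]


def bootstrap_weights(num_ensemble: int, num_data: int,
                      rng: random.Random) -> List[List[float]]:
    """Draws one nonnegative weight per (ensemble member, example) pair.

    Exp(1) weights are an assumed choice for illustration; other schemes
    (e.g. Poisson or Bernoulli weights) are also used in practice.
    """
    return [[rng.expovariate(1.0) for _ in range(num_data)]
            for _ in range(num_ensemble)]


def l2_loss_with_bootstrap(predict: PredictFn, index: int,
                           xs: Sequence[float], ys: Sequence[float],
                           weights: Sequence[Sequence[float]]) -> float:
    """Bootstrap-weighted mean squared error for ensemble member `index`."""
    w = weights[index]
    total = sum(wi * (predict(index, x) - y) ** 2
                for wi, x, y in zip(w, xs, ys))
    return total / len(xs)


def average_over_indices(single_loss: Callable[[int], float],
                         num_ensemble: int, num_index_samples: int,
                         rng: random.Random) -> float:
    """Averages a single-index loss over uniformly sampled ensemble indices."""
    indices = [rng.randrange(num_ensemble) for _ in range(num_index_samples)]
    return sum(single_loss(k) for k in indices) / num_index_samples


if __name__ == "__main__":
    rng = random.Random(0)
    xs, ys = [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]

    def predict(k: int, x: float) -> float:
        return x + k  # toy "member k" whose output is offset by k

    weights = bootstrap_weights(num_ensemble=10, num_data=len(xs), rng=rng)
    loss = average_over_indices(
        lambda k: l2_loss_with_bootstrap(predict, k, xs, ys, weights),
        num_ensemble=10, num_index_samples=10, rng=rng)
    print(loss)
```

Dividing by the number of examples rather than by the sum of the weights keeps each member's loss on the same scale as an unweighted mean when the weights average to 1; normalizing by the weight sum is an equally reasonable alternative.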

Also note that, to get the example to work, I had to add the line

    dummy_input=np.zeros(50)

because otherwise I got an error saying that dummy_input was a required positional argument.
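To see which constructor arguments the installed version actually requires (instead of discovering them one error at a time), the signature can be inspected directly. A minimal sketch with the standard `inspect` module; `required_params` is a hypothetical helper, and `HypotheticalEnsemble` is a made-up stand-in used only to demonstrate it, not enn's real `MLPEnsembleMatchedPrior`.

```python
import inspect
from typing import Callable, List


def required_params(fn: Callable) -> List[str]:
    """Names of parameters that have no default and must be supplied.

    For a class, inspect.signature reports the constructor's parameters
    (excluding `self`). *args and **kwargs are skipped.
    """
    return [
        name for name, p in inspect.signature(fn).parameters.items()
        if p.default is inspect.Parameter.empty
        and p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                           inspect.Parameter.VAR_KEYWORD)
    ]


class HypotheticalEnsemble:
    """Made-up stand-in for a network constructor; NOT enn's real class."""

    def __init__(self, output_sizes, num_ensemble, dummy_input, seed=0):
        self.output_sizes = output_sizes
        self.num_ensemble = num_ensemble
        self.dummy_input = dummy_input
        self.seed = seed


# Against the real library (result depends on the installed enn version):
#   from enn import networks
#   print(required_params(networks.MLPEnsembleMatchedPrior))
```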
