[BUG] torchrl.objectives.SACLoss is broken when there is more than one qvalue_network #2589

Open · fmeirinhos opened this issue Nov 20, 2024 · 0 comments · Label: bug (Something isn't working)
Describe the bug

The torchrl.objectives.SACLoss module is currently broken when qvalue_network is passed as a List[TensorDictModule].

Note also the discrepancy between the docstring, which documents the argument as TensorDictModule, and the constructor signature, which accepts the union TensorDictModule | List[TensorDictModule].

The bug arises because the internal method _set_in_keys cannot extract the in_keys from a List[TensorDictModule].

NOTE: I do not know whether this is the only place where the module breaks down when there are multiple qvalue_networks.
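For reference, here is a hedged sketch of the kind of fix I would expect, based on my reading of the source; the helper name and exact attribute layout are my own and may not match the actual torchrl internals:

# Hypothetical helper, not the actual torchrl code: collect in_keys whether
# qvalue_network is a single TensorDictModule or a list of them.
def _gather_qvalue_in_keys(qvalue_network):
    if isinstance(qvalue_network, (list, tuple)):
        keys = [k for net in qvalue_network for k in net.in_keys]
        return list(dict.fromkeys(keys))  # de-duplicate, preserve order
    return list(qvalue_network.in_keys)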

To Reproduce

This is the same example given in the SACLoss docstring, but with two qvalue_networks:

import torch
from torch import nn
from torchrl.data import Bounded
from tensordict import TensorDict
from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.sac import SACLoss

_ = torch.manual_seed(42)
n_act, n_obs = 4, 3
spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
module = SafeModule(
    module=nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticActor(
    module=module,
    in_keys=["loc", "scale"],
    spec=spec,
    distribution_class=TanhNormal,
)


class ValueClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))


qvalue = ValueOperator(
    module=ValueClass(),
    in_keys=["observation", "action"],
)
value = ValueOperator(
    module=nn.Linear(n_obs, 1),
    in_keys=["observation"],
)  # carried over from the docstring example; not used by SACLoss below

# Passing a list of qvalue networks is what triggers the failure
loss = SACLoss(actor, [qvalue, qvalue], num_qvalue_nets=2)
batch = [2]
action = spec.rand(batch)
data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": action,
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    },
    batch,
)

loss(data)  # fails here; see description above
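For comparison, the single-network form from the docstring runs fine. If my reading of the constructor is right, a single network is expanded internally into num_qvalue_nets parameter sets, so passing one module is a temporary workaround, assuming identical architectures are acceptable:

# Workaround sketch: a single qvalue network avoids the _set_in_keys failure,
# since SACLoss itself handles the replication into num_qvalue_nets copies.
loss_single = SACLoss(actor, qvalue, num_qvalue_nets=2)
loss_single(data)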

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)