
[BUG] SAC loss masking #2612

Closed
matteobettini opened this issue Nov 27, 2024 · 13 comments
Labels: bug (Something isn't working)

@matteobettini
Contributor

PR #2606

Introduces indexing of the loss tensordict using done signals:

next_tensordict_select = next_tensordict[~done.squeeze(-1)]

I have multiple concerns regarding this PR:

  1. Masking using done signals is not always possible, as the done signal can have more dimensions than the loss tensordict, leading to errors in multi-agent settings (see the sketch after this list).
  2. The actor network may expect a specific input dimension; masking the input can lead to arbitrary, unsupported shapes and crashes in the actor (even outside multi-agent settings).
  3. In continuous SAC, the mask is applied to the output actions and log-probabilities (filling masked entries with arbitrary zeros). This does not seem necessary: those values will be discarded anyway, since value_estimate() reads the done flags.
  4. The same holds in discrete SAC: it seems that value_estimate() already reads done and discards next_values for done states. Moreover, the target for a done state should be the reward, so using zeros here actually introduces a further bug.
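
A minimal sketch of concern (1), assuming a multi-agent setup where the loss tensordict keeps a single batch dimension while the per-agent done carries an extra agent dimension (all names and shapes here are illustrative):

import torch
from tensordict import TensorDict

# Illustrative shapes: 8 transitions, 3 agents. The loss tensordict is kept
# with batch_size=[8], while the per-agent done has shape [8, 3, 1].
next_tensordict = TensorDict({"obs": torch.randn(8, 3, 4)}, batch_size=[8])
done = torch.zeros(8, 3, 1, dtype=torch.bool)

# done.squeeze(-1) has shape [8, 3]; boolean-indexing a tensordict whose
# batch_size is [8] with an [8, 3] mask raises a shape error.
next_tensordict_select = next_tensordict[~done.squeeze(-1)]  # raises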

In my opinion this change was not needed, as the target values of done states are already discarded in value_estimate().

Maybe I am wrong in this analysis, please let me know.

I do not think it is possible to avoid submitting inputs of done states to the policy without changing the input shape (which we should avoid as it could lead to errors)

matteobettini added the bug (Something isn't working) label on Nov 27, 2024
@matteobettini
Contributor Author

matteobettini commented Nov 27, 2024

If the problem is that the observations of done states could be NaN (which I argue they shouldn't be), we could consider replacing NaNs with 0 so that the forward pass can run.
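
For example, a minimal sketch of that NaN-to-zero idea (the tensor here is purely illustrative):

import torch

# Replace NaN entries with 0 so the network can still be queried;
# torch.nan_to_num leaves finite values untouched.
obs = torch.tensor([[1.0, 2.0], [float("nan"), 4.0]])
obs = torch.nan_to_num(obs, nan=0.0)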

@vmoens
Contributor

vmoens commented Nov 27, 2024

In my opinion this change was not needed, as the target values of done states are already discarded in value_estimate().

This change was indeed required, as some networks cannot accept the observation values that are written when the environment is done. There should be no forward pass for values that are nonsensical and whose outputs will not be used.

See #2590 for context.

If the problem is that the observations of done states could be NaN (which I argue they shouldn't be), we could consider replacing NaNs with 0 so that the forward pass can run.

The NaNs were just there as an example. In practice, the network should not be queried if the values are not used. For instance, you could have a model that implements some sort of internal state update at each query, and you wouldn't want it to be modified by values that will be discarded.
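
As a toy illustration of that kind of side effect (an assumed example, not a TorchRL module), a policy with an internal call counter would still be mutated by queries on done states whose outputs are discarded:

import torch

class CountingPolicy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, obs):
        self.calls += 1  # internal state updated on every query, even discarded ones
        return self.linear(obs)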

Re (1), we could decide not to mask with done if the shapes cannot be broadcast.
Re (2), I think this can be addressed by checking that the squeezed done has the shape of the tensordict. If so, we're only modifying the batch size and things should work OK (maybe not with RNNs?).
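
A sketch of that guard, using a hypothetical helper (not TorchRL's actual implementation):

import torch
from tensordict import TensorDict

def maybe_mask_done(next_tensordict: TensorDict, done: torch.Tensor) -> TensorDict:
    # Apply the done mask only when its squeezed shape matches the batch size.
    done_mask = done.squeeze(-1)
    if done_mask.shape == next_tensordict.batch_size:
        return next_tensordict[~done_mask]
    # Shapes cannot be broadcast (e.g. an extra agent dimension): skip masking.
    return next_tensordict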

If this can be addressed differently I'm happy to give it a look!

cc @fmeirinhos

@vmoens
Contributor

vmoens commented Nov 27, 2024

I think we also need more MARL coverage of the losses in the tests, because this could have been spotted easily if pytest test/test_cost.py -k SAC had included any MARL setting.

@matteobettini
Contributor Author

matteobettini commented Nov 27, 2024

I understand the original issue, but this seems like a very difficult pickle.

Not all networks can be queried with sparse data or data of arbitrary shape (the list is long).

@matteobettini
Contributor Author

There should be no forward pass for values that are nonsensical and whose outputs will not be used.

I am not sure it makes sense to enforce this here without introducing problems bigger than the original one.

@vmoens
Contributor

vmoens commented Nov 27, 2024

I am not sure it makes sense to enforce this here without introducing problems bigger than the original one.

You know what, I'm happy to revert that PR as soon as we can figure out what to do for people who have networks that simply cannot accept "done" observation values!

I do think it's a valid concern and it should be addressed, but obviously with a non-buggy solution.

@matteobettini
Contributor Author

matteobettini commented Nov 27, 2024

I have been checking how other libraries do it, and they seem to pass the next obs anyway.

Maybe it is just my opinion, but I don't see what is special about the observation of a done state; it should be part of the same observation space as the others.

Furthermore, for policies that have an internal state or counter, this is a bit unnatural in SAC, since the policy is called from two places anyway (actions and values), so keeping track of meaningful states is hard.

@vmoens
Contributor

vmoens commented Nov 27, 2024

I have been checking how other libraries do it, and they seem to pass the next obs anyway.

I don't think we should overfit to what other libs do; we should focus on the issues our users are facing.

Furthermore, for policies that have an internal state or counter

This is just an example. The point is that, if an error is thrown when invalid data is passed to a network, we should never reach that error (or should provide the tooling necessary to avoid it).

We could add a flag to the constructor, e.g. skip_done_states, which defaults to False.

Then we capture errors where relevant, and if the actor network raises during a call on the next data, we point users to this flag. In all other cases the behaviour is unchanged.

I gave it a shot in #2613 (without the capture of the error)
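
For illustration, here is a rough sketch of the opt-in semantics described above (the helper and names are assumed; this is not the code from #2613):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Hypothetical helper: with skip_done_states=True the actor is queried only on
# non-done entries, and done entries keep a placeholder value that the value
# estimator discards anyway.
def next_state_value(actor, next_td, done, skip_done_states=False):
    if not skip_done_states:
        # Default behaviour: the actor sees every next state.
        return actor(next_td)["state_value"]
    mask = ~done.squeeze(-1)
    values = torch.zeros(*next_td.batch_size, 1)
    values[mask] = actor(next_td[mask])["state_value"]
    return values

# Toy actor producing a "state_value" entry, just to make the sketch runnable.
actor = TensorDictModule(torch.nn.Linear(4, 1), in_keys=["obs"], out_keys=["state_value"])
next_td = TensorDict({"obs": torch.randn(8, 4)}, batch_size=[8])
done = torch.zeros(8, 1, dtype=torch.bool)
done[2] = True
print(next_state_value(actor, next_td, done, skip_done_states=True).shape)  # torch.Size([8, 1])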

@vmoens
Contributor

vmoens commented Nov 27, 2024

For the record, this is an example of a function that errors when there's a NaN

>>> import torch
>>> matrix_with_nan = torch.tensor([[1.0, 2.0], [float('nan'), 4.0]])
>>> 
>>> result = torch.linalg.cholesky(matrix_with_nan)

Note that replacing the NaN with 0s is also problematic with cholesky:

>>> torch.linalg.cholesky(torch.zeros(4, 4))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
torch._C._LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1 is not positive-definite).

@matteobettini
Contributor Author

That seems like a reasonable solution to me; happy to review.

I think we are just facing a difficult issue, as I can clearly understand where both problems come from. I also don't like padding or calls on useless data.

@matteobettini
Contributor Author

Note that replacing the NaN with 0s is also problematic with cholesky

Also, flattening the cholesky input and removing the NaN values would be problematic, no?

@vmoens
Contributor

vmoens commented Nov 27, 2024

Also, flattening the cholesky input and removing the NaN values would be problematic, no?

No, the matrix is in the feature dim, not the batch dim. It isn't flattened.
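
In other words (a minimal sketch with made-up shapes), masking with done selects along the batch dimension, so each feature matrix passed to cholesky stays intact:

import torch

batch = torch.stack([torch.eye(2), torch.eye(2) * float("nan")])  # shape [2, 2, 2]
done = torch.tensor([False, True])
valid = batch[~done]              # shape [1, 2, 2]: the NaN matrix is dropped whole
torch.linalg.cholesky(valid)      # succeeds; no matrix was flattened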

@matteobettini
Contributor Author

Solved by #2613
