-
Notifications
You must be signed in to change notification settings - Fork 286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added directclr loss #963
base: master
Are you sure you want to change the base?
Added directclr loss #963
Conversation
Codecov ReportBase: 88.95% // Head: 88.46% // Decreases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #963 +/- ##
==========================================
- Coverage 88.95% 88.46% -0.50%
==========================================
Files 108 96 -12
Lines 5023 4567 -456
==========================================
- Hits 4468 4040 -428
+ Misses 555 527 -28
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @Atharva-Phatak, thank you so much for implementing this! I have left some comments and a few commit suggestions but the overall contribution looks already great.
Regarding the tests, I'd suggest the following steps:
- create a new file under
tests/loss/test_InfoNCELoss
- add a class
class TestInfoNCELoss(unittest.TestCase)
- add functions of the form
test_xyz(self)
to verify that the loss is computed correctly (e.g. for a certain input we expect a certain output).
You can take a look at other examples for inspiration.
|
||
#Adapted from https://github.com/facebookresearch/directclr/blob/main/directclr/main.py | ||
class InfoNCELoss(nn.Module): | ||
"""Implementation of InfoNCELoss as required for DIRECTCLR""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can leave the reference to DIRECTCLR
away here.
lightly/loss/directclr_loss.py
Outdated
dim : Dimension of subvector to be used to compute InfoNCELoss. | ||
temprature: The value used to scale logits. | ||
""" | ||
self.temprature = temprature |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo, it's temperature
🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry :(
lightly/loss/directclr_loss.py
Outdated
#dimension of subvector sent to infoNCE | ||
self.dim = dim | ||
|
||
def normalize(self, x:torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would raise the question if it's necessary to put this in its own function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well technically not but I would avoid writing torch.nn.functional(x, dim = 1)
again and again :)
Co-authored-by: Philipp Wirth <[email protected]>
Co-authored-by: Philipp Wirth <[email protected]>
Co-authored-by: Philipp Wirth <[email protected]>
Co-authored-by: Philipp Wirth <[email protected]>
Co-authored-by: Philipp Wirth <[email protected]>
I will add tests ASAP :) |
Hey @Atharva-Phatak do you need more help regarding the tests? |
@philippmwirth I am sorry for the delay, I am currently busy with my mid-semesters. I have my last paper tomorrow, then I would add the tests at the EOD tomorrow. |
No, don't worry! Good luck with your mid-semesters 🙂 |
@philippmwirth Can you help me out here? I am not able to spot the mistake or Have I written the test wrong ? |
Hi @Atharva-Phatak I believe you have to call The error message indicates that some modules related attributes are missing:
You sadly have to scroll quite far back up in the unit tests output to find the actual stacktrace. |
@guarin Need help fixing this conflict. |
Hi @Atharva-Phatak, sorry for the late reply but I fixed the merge conflict. There still seems to be an issue with the loss implementation though. |
Let me check the implementation once again. |
Addded Implementation of DirectCLR loss as proposed in #781
@philippmwirth I need some help on writing tests. If you could guide me that would be amazing.