
Is the loss function calculated on the raw data? #189

Open
Neel-132 opened this issue Dec 18, 2023 · 22 comments

@Neel-132

Sorry to disturb you again, but I am confused while coding the training loop myself: is the loss calculated on the input data, rather than on the latent representations learnt by the encoder?

@jameschapman19
Owner

On the representations

@jameschapman19
Owner

(And trust me, the EY loss is going to work better! I'm almost thinking about changing the default solver in this package to that one.)

@Neel-132
Author

Neel-132 commented Dec 18, 2023

Well, I will be transparent about my use case. I have a list of three tensors, say X, Y, and Z, which I need to project into a common space using CCA. Since there are more than two views, I am using TCCA. I also need to write the training loop myself, so I had this doubt about the loss function definition in the DCCA class.

As shown in the screenshot below, def forward on line 34 returns an encoded representation of each view, collecting them into a list via a list comprehension.

However, def loss on line 41 (also shown in the screenshot) contains these lines:

```python
representations = self(batch["views"])
return {"objective": self.objective(representations)}
```

So this batch is the batch from the trainloader. My question is: this loss function does not use the encoded representations, right? Or if it does, please help me understand.

[Screenshot: the DCCA class's forward (line 34) and loss (line 41) methods]

@jameschapman19
Owner

The loss is applied to the representations, because we are calling self.objective with representations as the argument.

@Neel-132
Author

Yes, I understand that, but this representations variable is batch["views"], which is just a batch from the trainloader, right?

@jameschapman19
Owner

I actually think in future I will change the callable loss classes like TCCALoss to just be class methods of DTCCA. It's just a throwback to an older version where it was helpful to do it the current way.

@Neel-132
Author

Okay, okay. But if you could answer my question, it would really help me.

@jameschapman19
Owner

Sorry, half my response got lost, so that looked really weird, like I was just offering up something totally random 😅

The answer to your question is no! I'm assuming familiarity with PyTorch, but when we call self(arg) on an nn.Module, we call its forward method.

So representations is the output of the forward method applied to the batch.
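
As a minimal sketch of that PyTorch behaviour (a generic example, not code from this package):

```python
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

enc = Encoder()
x = torch.randn(3, 4)
# enc(x) dispatches through nn.Module.__call__ to enc.forward(x)
assert torch.equal(enc(x), enc.forward(x))
```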

@jameschapman19
Owner

jameschapman19 commented Dec 18, 2023

More generally, in Python you can use the __call__ method with classes, so loss(args) is the same as loss.__call__(args) (which is how the current loss class works).
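
For illustration, a plain-Python sketch of that mechanism (a hypothetical class, not this package's loss):

```python
class SquaredError:
    """A callable loss: instances can be invoked like functions."""

    def __call__(self, prediction, target):
        return (prediction - target) ** 2

loss_fn = SquaredError()
# loss_fn(2.0, 3.0) is syntactic sugar for loss_fn.__call__(2.0, 3.0)
assert loss_fn(2.0, 3.0) == loss_fn.__call__(2.0, 3.0) == 1.0
```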

@Neel-132
Author

Yes, exactly, but this throws an error, which is what got me confused in the first place:
[Screenshot: the resulting error traceback]

@Neel-132
Author

So it seems I would have to take the output of forward for a specific batch, store it in a dictionary with "views" as the key, and then pass that to the loss. But that is about the last thing I wanted to do, which is why I thought of raising the issue first.

Thanks

@jameschapman19
Owner

No? The output of self() is a list.

Your data loader needs to produce batches in which "views" is a key mapped to a list of arrays.
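
A minimal sketch of a dataset producing that structure (hypothetical names, assuming three equal-length views):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MultiViewDataset(Dataset):
    """Each item is a dict: {"views": [x_i, y_i, z_i]}."""

    def __init__(self, *views):
        self.views = [torch.as_tensor(v, dtype=torch.float32) for v in views]

    def __len__(self):
        return len(self.views[0])

    def __getitem__(self, idx):
        return {"views": [view[idx] for view in self.views]}

X, Y, Z = torch.randn(100, 10), torch.randn(100, 12), torch.randn(100, 8)
loader = DataLoader(MultiViewDataset(X, Y, Z), batch_size=32)
# the default collate keeps the dict structure: each batch is
# {"views": [tensor(32, 10), tensor(32, 12), tensor(32, 8)]}
```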

@jameschapman19
Owner

Yeah, looking at your code, I think you should read what, e.g., my DCCA class is doing. It's not a loss function; it's a PyTorch Lightning module that implements DCCA with a specific loss function.

If you take a look in the files, you can see the loss function that DCCA is using behind the scenes.

@Neel-132
Author

Yes, in the DCCA class, the loss method takes batch as its argument and passes batch['views'] to the respective objective. So when I call z = DCCA()(batch) and then loss(z), the error is obvious, because the loss function expects a dictionary, not a list of tensors. That's where the error is coming from.

@jameschapman19
Owner

Ahhh! I've understood our confusion now!

So my DCCA class's loss method is applied to data, not representations!

Apologies, I thought you were asking whether, in general, the DCCA loss is applied to data or to representations.

So just change your snippet to DCCA.loss(batch) and it will be fine.

@jameschapman19
Owner

(Because my DCCA loss method has a forward call inside)
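
That is, the method quoted earlier boils down to this pattern:

```python
def loss(self, batch):
    # self(...) invokes forward: raw views -> list of encoded representations
    representations = self(batch["views"])
    # the objective is then computed on those representations, not the raw data
    return {"objective": self.objective(representations)}
```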

@jameschapman19
Owner

Have just got off a 12-hour flight, so forgive me for not realising the motivation behind your Q!

@Neel-132
Author

Exactly, so this was my confusion. So the DCCA loss takes a batch from the dataloader as its argument, right?

Forgive me, I don't know a lot about the theoretical aspects of CCA, so for my case, when I use

```python
loss = loss(batch)
loss.backward()
```

how will the neural network learn the parameters if we don't pass the encoded representations to the loss?

@jameschapman19
Owner

If the loss function was instead called 'encode_and_calculate_loss', you would understand, right? The name is maybe a bit misleading, but it's also a done thing in NN code.

The function first passes the raw data through the encoders and then calculates the loss. Look at what the function is actually doing.

@Neel-132
Author

Yes, that is alright, but my question is: after a forward pass, the DCCA encodes the given batch, then calls the loss on the batch, and performs loss.backward(). Is this sequence correct? That is my question.

@jameschapman19
Owner

This:

```python
model = DCCA(...)  # instantiated once, outside the loop
# loss() returns a dict (see the quoted source above), so take its "objective" entry
loss = model.loss(batch)["objective"]  # batch is a dict: {"views": list of tensors}
loss.backward()
optimiser.step()
```
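
Putting it together, a bare-bones manual training loop might look like this (a sketch only, assuming model is an instance of the DCCA-style class above, loader yields {"views": [...]} batches as sketched earlier, and the Adam settings and num_epochs are placeholders):

```python
import torch

optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    for batch in loader:
        optimiser.zero_grad()
        # loss() encodes batch["views"] via forward, then applies the objective
        loss = model.loss(batch)["objective"]
        loss.backward()
        optimiser.step()
```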

@jameschapman19
Owner

The function that takes representations and returns a scalar loss is .objective()
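
In other words, with the names used in this thread, these two routes compute the same quantity (a sketch, assuming model and batch as above):

```python
# route 1: explicit forward pass, then the objective on the representations
representations = model(batch["views"])  # list of encoded views
objective = model.objective(representations)

# route 2: the loss() convenience method, which does both steps internally
objective_via_loss = model.loss(batch)["objective"]
```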
