Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 19, 2023
1 parent d3f99f7 commit a710bbd
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions cca_zoo/deep/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,26 +266,26 @@ def loss(self, representations, independent_representations=None):
}


# class CCA_SVDLoss(CCA_EYLoss):
# def loss(self, representations, independent_representations=None):
# C = torch.cov(torch.hstack(representations).T)
# latent_dims = representations[0].shape[1]
#
# Cxy = C[:latent_dims, latent_dims:]
# Cxx = C[:latent_dims, :latent_dims]
#
# if independent_representations is None:
# Cyy = C[latent_dims:, latent_dims:]
# else:
# Cyy = torch.cov(independent_representations[1].T)
#
# rewards = torch.trace(2 * Cxy)
# penalties = torch.trace(Cxx @ Cyy)
# return {
# "objective": -rewards + penalties, # return the negative objective value
# "rewards": rewards, # return the total rewards
# "penalties": penalties, # return the penalties matrix
# }
class CCA_SVDLoss(CCA_EYLoss):
def loss(self, representations, independent_representations=None):
C = torch.cov(torch.hstack(representations).T)
latent_dims = representations[0].shape[1]

Cxy = C[:latent_dims, latent_dims:]
Cxx = C[:latent_dims, :latent_dims]

if independent_representations is None:
Cyy = C[latent_dims:, latent_dims:]
else:
Cyy = torch.cov(independent_representations[1].T)

rewards = torch.trace(2 * Cxy)
penalties = torch.trace(Cxx @ Cyy)
return {
"objective": -rewards + penalties, # return the negative objective value
"rewards": rewards, # return the total rewards
"penalties": penalties, # return the penalties matrix
}


class PLS_EYLoss(CCA_EYLoss):
Expand Down

0 comments on commit a710bbd

Please sign in to comment.