Skip to content

Commit

Permalink
Validation set for gradient-based models
Browse files · Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Oct 18, 2023
1 parent f46c68b commit 54c289c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 21 deletions.
28 changes: 12 additions & 16 deletions cca_zoo/deep/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from tensorly.cp_tensor import cp_to_tensor
from tensorly.decomposition import parafac

from cca_zoo.utils import cross_cov


def inv_sqrtm(A, eps=1e-9):
"""Compute the inverse square-root of a positive definite matrix."""
Expand Down Expand Up @@ -231,9 +233,7 @@ def get_AB(self, representations):
if i == j:
B += torch.cov(zi.T) # add the auto-covariance of each view to B
else:
A += torch.cov(torch.hstack((zi, zj)).T)[
latent_dimensions :, : latent_dimensions
] # add the cross-covariance of each pair of representations to A
A += cross_cov(zi, zj, rowvar=False) # add the cross-covariance of each view to A
return A / len(representations), B / len(
representations
) # return the normalized matrices (divided by the number of representations)
Expand All @@ -243,10 +243,10 @@ def loss(self, representations, independent_representations=None):
A, B = self.get_AB(representations)
rewards = torch.trace(2 * A)
if independent_representations is None:
penalties = torch.trace(B @ B)
penalties = torch.trace(A.detach() @ B)
else:
independent_A, independent_B = self.get_AB(independent_representations)
penalties = torch.trace(B @ independent_B)
penalties = torch.trace(independent_A.detach() @ B)
return {
"objective": -rewards + penalties,
"rewards": rewards,
Expand All @@ -264,9 +264,9 @@ def loss(self, representations, independent_representations=None):
if independent_representations is None:
Cyy = C[latent_dims:, latent_dims:]
else:
Cyy = torch.cov(torch.hstack(independent_representations).T)[latent_dims:, latent_dims:]
Cyy = cross_cov(independent_representations[1], independent_representations[1], rowvar=False)

rewards = torch.trace(2 * Cxy)
rewards = torch.trace(2*Cxy)
penalties = torch.trace(Cxx @ Cyy)
return {
"objective": -rewards + penalties, # return the negative objective value
Expand Down Expand Up @@ -299,21 +299,17 @@ def get_AB(self, representations, weights=None):
if i == j:
B += weights[i].T @ weights[i] / n
else:
A += torch.cov(torch.hstack((zi, zj)).T)[
latent_dimensions:, :latent_dimensions
] # add the cross-covariance of each pair of representations to A
A += cross_cov(zi, zj, rowvar=False)
return A / len(representations), B / len(representations)


class PLS_SVDLoss(PLS_EYLoss):
def loss(self, representations, weights=None):
C = torch.cov(torch.hstack(representations).T)
latent_dims = representations[0].shape[1]
C = cross_cov(representations[0], representations[1], rowvar=False)

n = representations[0].shape[0]
Cxy = C[:latent_dims, latent_dims:]
Cxx = weights[0].T @ weights[0] / n
Cyy = weights[1].T @ weights[1] / n
Cxy = C
Cxx = weights[0].T @ weights[0] / representations[0].shape[0]
Cyy = weights[1].T @ weights[1] / representations[1].shape[0]

rewards = torch.trace(2 * Cxy)
penalties = torch.trace(Cxx @ Cyy)
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/linear/_gradient/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
torch.optim.Optimizer: The optimizer object.
"""
# construct optimizer using optimizer_kwargs
optimizer_name = self.optimizer_kwargs.get("optimizer", "SGD")
optimizer_name = self.optimizer_kwargs.get("optimizer", "Adam")
kwargs = self.optimizer_kwargs.copy()
kwargs.pop("optimizer", None)
optimizer = getattr(torch.optim, optimizer_name)(
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/linear/_gradient/_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ class CCA_SVD(CCA_EY):
objective = CCA_SVDLoss()


class PLS_SVD(CCA_SVD):
objective = PLS_SVDLoss()
# class PLS_SVD(CCA_SVD):
# objective = PLS_SVDLoss()
4 changes: 2 additions & 2 deletions test/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_batch_pls():
plssvd = PLS_SVD(
latent_dimensions=latent_dims,
epochs=epochs,
learning_rate=learning_rate,
learning_rate=learning_rate/2,
random_state=random_state,
).fit((X, Y))
pls_score = scale_transform(pls, X, Y)
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_batch_cca():
ccasvd = CCA_SVD(
latent_dimensions=latent_dims,
epochs=epochs,
learning_rate=learning_rate * 10,
learning_rate=learning_rate,
random_state=random_state,
).fit((X, Y))
cca_score = cca.score((X, Y))
Expand Down

0 comments on commit 54c289c

Please sign in to comment.