
Big Update
jameschapman19 committed Nov 14, 2023
1 parent ec348d7 commit 6562e2d
Showing 5 changed files with 168 additions and 125 deletions.
11 changes: 11 additions & 0 deletions cca_zoo/_utils/cross_correlation.py
@@ -2,6 +2,17 @@
import torch


def torch_cross_cov(A, B):
    """Cross-covariance of the columns of A with the columns of B, normalised by n - 1."""
A = A.T
B = B.T

A = A - A.mean(dim=1, keepdim=True)
B = B - B.mean(dim=1, keepdim=True)

C = A @ B.T
return C / (A.size(1) - 1)


def cross_corrcoef(A, B, rowvar=True):
"""Cross correlation of two matrices.
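For reference, the new torch_cross_cov(A, B) helper returns the latent_dims x latent_dims cross-covariance of the columns of A with the columns of B, normalised by n - 1. A minimal sanity check, illustrative only and not part of the commit (tensor names and shapes are arbitrary), compares it against the off-diagonal block of torch.cov on the stacked inputs:

import torch
from cca_zoo._utils.cross_correlation import torch_cross_cov

A = torch.randn(100, 3)  # n samples x latent dims
B = torch.randn(100, 3)

C = torch_cross_cov(A, B)                          # 3 x 3 cross-covariance block
C_ref = torch.cov(torch.hstack([A, B]).T)[:3, 3:]  # same block via torch.cov
assert torch.allclose(C, C_ref, atol=1e-6)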
181 changes: 105 additions & 76 deletions cca_zoo/deep/objectives.py
@@ -1,9 +1,10 @@
from typing import List, Optional

import tensorly as tl
import torch
from tensorly.cp_tensor import cp_to_tensor
from tensorly.decomposition import parafac

from cca_zoo._utils import cross_cov
from cca_zoo._utils.cross_correlation import torch_cross_cov


def inv_sqrtm(A, eps=1e-9):
@@ -209,71 +210,78 @@ def loss(self, views):
return torch.linalg.norm(M - M_hat)


class _CCAAB:
def get_AB(self, representations):
latent_dimensions = representations[0].shape[1]
A = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the cross-covariance matrix
B = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the auto-covariance matrix
for i, zi in enumerate(representations):
for j, zj in enumerate(representations):
if i == j:
B += torch.cov(zi.T) # add the auto-covariance of each view to B
A += cross_cov(
zi, zj, rowvar=False
) # add the cross-covariance of each view to A
return A / len(representations), B / len(
representations
) # return the normalized matrices (divided by the number of representations)


class _PLSAB:
def get_AB(self, representations, weights=None):
latent_dimensions = representations[0].shape[1]
A = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the cross-covariance matrix
B = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the auto-covariance matrix
for i, zi in enumerate(representations):
for j, zj in enumerate(representations):
if i == j:
B += weights[i].T @ weights[i]
else:
A += cross_cov(zi, zj, rowvar=False)
return A / len(representations), B / len(representations)


class _CCA_EYLoss(_CCAAB):
@torch.jit.script
def CCA_AB(representations: list[torch.Tensor]):
latent_dimensions = representations[0].shape[1]
A = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the cross-covariance matrix
B = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the auto-covariance matrix
for zi in representations:
B.add_(torch.cov(zi.T)) # In-place addition
for zj in representations:
A.add_(torch_cross_cov(zi, zj)) # In-place addition

A.div_(len(representations)) # In-place division
B.div_(len(representations)) # In-place division
return A, B


@torch.jit.script
def PLS_AB(representations: list[torch.Tensor], weights: list[torch.Tensor]):
latent_dimensions = representations[0].shape[1]
A = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the cross-covariance matrix
B = torch.zeros(
latent_dimensions, latent_dimensions, device=representations[0].device
) # initialize the auto-covariance matrix

for i, zi in enumerate(representations):
for j, zj in enumerate(representations):
if i == j:
B.add_(weights[i].T @ weights[i]) # In-place addition
else:
A.add_(torch_cross_cov(zi, zj)) # In-place addition

A.div_(len(representations)) # In-place division
B.div_(len(representations)) # In-place division
return A, B


class _CCA_EYLoss:
def __init__(self, eps: float = 1e-4):
self.eps = eps

def loss(self, representations, independent_representations=None):
A, B = self.get_AB(representations)
@staticmethod
@torch.jit.script
def loss(
representations: List[torch.Tensor],
independent_representations: Optional[List[torch.Tensor]] = None,
):
A, B = CCA_AB(representations)
rewards = torch.trace(2 * A)
if independent_representations is None:
penalties = torch.trace(B @ B)
else:
independent_A, independent_B = self.get_AB(independent_representations)
independent_A, independent_B = CCA_AB(independent_representations)
penalties = torch.trace(B @ independent_B)
return {
"objective": -rewards + penalties,
"rewards": rewards,
"penalties": penalties,
}

@staticmethod
def derivative(
self,
views,
representations,
independent_views=None,
independent_representations=None,
views: List[torch.Tensor],
representations: List[torch.Tensor],
independent_views: Optional[List[torch.Tensor]] = None,
independent_representations: Optional[List[torch.Tensor]] = None,
):
A, B = self.get_AB(representations)
A, B = CCA_AB(representations)
sum_representations = torch.sum(torch.stack(representations), dim=0)
n = sum_representations.shape[0]
rewards = [2 * view.T @ sum_representations / (n - 1) for view in views]
@@ -283,7 +291,7 @@ def derivative(
for view, representation in zip(views, representations)
]
else:
_, independent_B = self.get_AB(independent_representations)
_, independent_B = CCA_AB(independent_representations)
penalties = [
view.T @ representation @ B / (n - 1)
+ independent_view.T
@@ -301,14 +309,19 @@


class _CCA_GHALoss(_CCA_EYLoss):
def loss(self, representations, independent_representations=None):
A, B = self.get_AB(representations)
@staticmethod
@torch.jit.script
def loss(
representations: List[torch.Tensor],
independent_representations: Optional[List[torch.Tensor]] = None,
):
A, B = CCA_AB(representations)
rewards = torch.trace(A)
if independent_representations is None:
rewards += torch.trace(A)
penalties = torch.trace(A @ B)
else:
independent_A, independent_B = self.get_AB(independent_representations)
independent_A, independent_B = CCA_AB(independent_representations)
rewards += torch.trace(independent_A)
penalties = torch.trace(independent_A @ B)
return {
@@ -317,14 +330,14 @@ def loss(self, representations, independent_representations=None):
"penalties": penalties,
}

@staticmethod
def derivative(
self,
views,
representations,
independent_views=None,
independent_representations=None,
views: List[torch.Tensor],
representations: List[torch.Tensor],
independent_views: Optional[List[torch.Tensor]] = None,
independent_representations: Optional[List[torch.Tensor]] = None,
):
A, B = self.get_AB(representations)
A, B = CCA_AB(representations)
sum_representations = torch.sum(torch.stack(representations), dim=0)
n = sum_representations.shape[0]
if independent_representations is None:
@@ -335,7 +348,7 @@ def derivative(
for view, representation in zip(views, representations)
]
else:
independent_A, independent_B = self.get_AB(independent_representations)
independent_A, independent_B = CCA_AB(independent_representations)
rewards = [2 * view.T @ sum_representations / (n - 1) for view in views]
penalties = [
view.T @ sum_representations @ independent_B / (n - 1)
@@ -346,7 +359,12 @@


class _CCA_SVDLoss(_CCA_EYLoss):
def loss(self, representations, independent_representations=None):
@staticmethod
@torch.jit.script
def loss(
representations: List[torch.Tensor],
independent_representations: Optional[List[torch.Tensor]] = None,
):
C = torch.cov(torch.hstack(representations).T)
latent_dims = representations[0].shape[1]

@@ -366,12 +384,12 @@ def loss(self, representations, independent_representations=None):
"penalties": penalties, # return the penalties matrix
}

@staticmethod
def derivative(
self,
views,
representations,
independent_views=None,
independent_representations=None,
views: List[torch.Tensor],
representations: List[torch.Tensor],
independent_views: Optional[List[torch.Tensor]] = None,
independent_representations: Optional[List[torch.Tensor]] = None,
):
C = torch.cov(torch.hstack(representations).T)
latent_dims = representations[0].shape[1]
@@ -394,9 +412,11 @@ def derivative(
return [2 * (-reward + penalty) for reward, penalty in zip(rewards, penalties)]


class _PLS_EYLoss(_PLSAB):
def loss(self, representations, weights):
A, B = self.get_AB(representations, weights)
class _PLS_EYLoss:
@staticmethod
@torch.jit.script
def loss(representations: List[torch.Tensor], weights: List[torch.Tensor]):
A, B = PLS_AB(representations, weights)
rewards = torch.trace(2 * A)
penalties = torch.trace(B @ B)
return {
@@ -405,8 +425,14 @@ def loss(self, representations, weights):
"penalties": penalties,
}

def derivative(self, views, representations, weights):
A, B = self.get_AB(representations, weights)
@staticmethod
@torch.jit.script
def derivative(
views: List[torch.Tensor],
representations: List[torch.Tensor],
weights: List[torch.Tensor],
):
A, B = PLS_AB(representations, weights)
sum_representations = torch.sum(torch.stack(representations), dim=0)
n = sum_representations.shape[0]
rewards = [2 * view.T @ sum_representations / (n - 1) for view in views]
@@ -417,8 +443,10 @@ def derivative(self, views, representations, weights):
return [2 * (-reward + penalty) for reward, penalty in zip(rewards, penalties)]


class _PLS_PowerLoss(_PLSAB):
def loss(self, representations):
class _PLS_PowerLoss:
@staticmethod
@torch.jit.script
def loss(representations: List[torch.Tensor]):
cov = torch.cov(torch.hstack(representations).T)
return {
"objective": torch.trace(
@@ -427,6 +455,7 @@ def loss(self, representations):
}

@staticmethod
def derivative(views, representations):
@torch.jit.script
def derivative(views: List[torch.Tensor], representations: List[torch.Tensor]):
grads = [views[0].T @ representations[1], views[1].T @ representations[0]]
return grads
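
As a reading guide (my summary of the refactor, not text from the commit): CCA_AB and PLS_AB replace the old _CCAAB/_PLSAB mixins with TorchScript-compiled free functions. A averages the cross-covariances over every ordered pair of representations (including each view with itself) and B averages the per-view auto-covariances (or weights[i].T @ weights[i] in the PLS case). _CCA_EYLoss then minimises -2 * tr(A) + tr(B @ B), swapping in an independent batch's B for the penalty when one is supplied. An eager-mode sketch of the same computation for two views, illustrative only:

import torch
from cca_zoo._utils.cross_correlation import torch_cross_cov

def cca_ey_objective(z1, z2):
    # Eager restatement of _CCA_EYLoss.loss for two representation matrices (n x d).
    reps = [z1, z2]
    d = z1.shape[1]
    A = torch.zeros(d, d)
    B = torch.zeros(d, d)
    for zi in reps:
        B += torch.cov(zi.T)                # per-view auto-covariance
        for zj in reps:
            A += torch_cross_cov(zi, zj)    # cross-covariance of every ordered pair
    A, B = A / len(reps), B / len(reps)
    rewards = torch.trace(2 * A)
    penalties = torch.trace(B @ B)
    return -rewards + penalties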
24 changes: 18 additions & 6 deletions cca_zoo/linear/_gradient/_base.py
@@ -15,6 +15,7 @@
logger=False,
enable_model_summary=False,
enable_progress_bar=False,
accelerator="cpu",
)

DEFAULT_LOADER_KWARGS = dict(pin_memory=False, drop_last=True, shuffle=True)
@@ -35,9 +36,9 @@ def __init__(
epochs=1,
learning_rate=1e-2,
initialization: Union[str, callable] = "random",
trainer_kwargs=None,
optimizer_kwargs=None,
early_stopping=False,
logging=False,
):
_BaseModel.__init__(
self,
@@ -59,19 +60,29 @@ def __init__(
else:
self.initialization = initialization
self.dataloader_kwargs = dataloader_kwargs or DEFAULT_LOADER_KWARGS
self.trainer_kwargs = trainer_kwargs or DEFAULT_TRAINER_KWARGS
self.optimizer_kwargs = optimizer_kwargs or DEFAULT_OPTIMIZER_KWARGS
self.early_stopping = early_stopping
self.logging = logging

def fit(self, views: Iterable[np.ndarray], y=None, validation_views=None, **kwargs):
def fit(
self,
views: Iterable[np.ndarray],
y=None,
validation_views=None,
**trainer_kwargs
):
views = self._validate_data(views)
if validation_views is not None:
validation_views = self._validate_data(validation_views)
self._check_params()
self.weights_ = self._fit(views, validation_views=validation_views)
self.weights_ = self._fit(
views, validation_views=validation_views, **trainer_kwargs
)
return self

def _fit(self, views: Iterable[np.ndarray], validation_views=None):
def _fit(
self, views: Iterable[np.ndarray], validation_views=None, **trainer_kwargs
):
self._initialize(views)
# Set the weights_ attribute as torch parameters with gradients
self.torch_weights = torch.nn.ParameterList(
@@ -82,7 +93,8 @@ def _fit(self, views: Iterable[np.ndarray], validation_views=None):
)
trainer = pl.Trainer(
max_epochs=self.epochs,
**self.trainer_kwargs,
# user-supplied trainer_kwargs override the defaults
**{**DEFAULT_TRAINER_KWARGS, **trainer_kwargs},
)
train_dataset, val_dataset = self.get_dataset(
views, validation_views=validation_views
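
With this change, Trainer options are no longer stored on the estimator at construction time: keyword arguments passed to fit() are merged over DEFAULT_TRAINER_KWARGS and forwarded to pytorch_lightning.Trainer. A hypothetical call, assuming the gradient-based CCA_EY estimator and random data (names, import path, and settings are placeholders, not from the commit):

import numpy as np
from cca_zoo.linear import CCA_EY  # assumed import path for a gradient-based model

X = np.random.rand(256, 10)
Y = np.random.rand(256, 8)

model = CCA_EY(latent_dimensions=2, epochs=50, learning_rate=1e-2)
# Any pytorch_lightning.Trainer keyword overrides the library defaults:
model.fit((X, Y), accelerator="cpu", enable_progress_bar=True)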