diff --git a/cca_zoo/deep/_base.py b/cca_zoo/deep/_base.py index e9031f3d..7f8789ff 100644 --- a/cca_zoo/deep/_base.py +++ b/cca_zoo/deep/_base.py @@ -55,26 +55,41 @@ def loss( def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: """Performs one step of training on a batch of views.""" - loss = self.loss(batch["views"]) + loss = self.loss(batch) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"train/{k}", v, prog_bar=True, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log(f"train/{k}", + v, + on_step=False, + on_epoch=True, + batch_size=batch["views"][0].shape[0], + ) return loss["objective"] def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: """Performs one step of validation on a batch of views.""" - loss = self.loss(batch["views"]) + loss = self.loss(batch) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"val/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log(f"val/{k}", + v, + on_step=False, + on_epoch=True, + batch_size=batch["views"][0].shape[0], + ) return loss["objective"] def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: """Performs one step of testing on a batch of views.""" - loss = self.loss(batch["views"]) + loss = self.loss(batch) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"test/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log(f"test/{k}", + v, + on_step=False, + on_epoch=True, + batch_size=batch["views"][0].shape[0], + ) return loss["objective"] @torch.no_grad() diff --git a/cca_zoo/deep/_discriminative/_dcca.py b/cca_zoo/deep/_discriminative/_dcca.py index fc18742a..8d52313f 100644 --- a/cca_zoo/deep/_discriminative/_dcca.py +++ b/cca_zoo/deep/_discriminative/_dcca.py @@ -40,8 +40,8 @@ def forward(self, views, **kwargs): z = [encoder(view) for encoder, view in zip(self.encoders, views)] return z - def loss(self, 
views, **kwargs): - z = self(views) + def loss(self, batch, **kwargs): + z = self(batch['views']) return {"objective": self.objective.loss(z)} def pairwise_correlations(self, loader: torch.utils.data.DataLoader): diff --git a/cca_zoo/deep/_discriminative/_dcca_barlow_twins.py b/cca_zoo/deep/_discriminative/_dcca_barlow_twins.py index 72511c59..81308458 100644 --- a/cca_zoo/deep/_discriminative/_dcca_barlow_twins.py +++ b/cca_zoo/deep/_discriminative/_dcca_barlow_twins.py @@ -40,8 +40,8 @@ def forward(self, views, **kwargs): z.append(bn(encoder(views[i]))) # encode and normalize each view return z # return a list of normalized latent representations - def loss(self, views, **kwargs): - z = self(views) # get the latent representations + def loss(self, batch, **kwargs): + z = self(batch['views']) # get the latent representations cross_cov = ( z[0].T @ z[1] / z[0].shape[0] ) # compute the cross-covariance matrix between the two views diff --git a/cca_zoo/deep/_discriminative/_dcca_ey.py b/cca_zoo/deep/_discriminative/_dcca_ey.py index c73e0785..062a6b9d 100644 --- a/cca_zoo/deep/_discriminative/_dcca_ey.py +++ b/cca_zoo/deep/_discriminative/_dcca_ey.py @@ -19,9 +19,10 @@ def __init__(self, latent_dimensions: int, encoders=None, r: float = 0, **kwargs ) self.r = r - def loss(self, views, independent_views=None, **kwargs): + def loss(self, batch, **kwargs): # Encoding the views with the forward method - z = self(views) + z = self(batch['views']) + independent_views = batch.get("independent_views", None) # Getting A and B matrices from z A, B = self.get_AB(z) rewards = torch.trace(2 * A) diff --git a/cca_zoo/deep/_discriminative/_dcca_gha.py b/cca_zoo/deep/_discriminative/_dcca_gha.py index 6ee48295..fa7fb27d 100644 --- a/cca_zoo/deep/_discriminative/_dcca_gha.py +++ b/cca_zoo/deep/_discriminative/_dcca_gha.py @@ -22,8 +22,9 @@ def get_AB(self, z): z ) # return the normalized matrices (divided by the number of views) - def loss(self, views, independent_views=None, 
**kwargs): + def loss(self, batch, **kwargs): - z = self(views) + z = self(batch['views']) + independent_views = batch.get("independent_views", None) A, B = self.get_AB(z) rewards = torch.trace(2 * A) if independent_views is None: diff --git a/cca_zoo/deep/_discriminative/_dcca_noi.py b/cca_zoo/deep/_discriminative/_dcca_noi.py index 3c5e5e42..5a7307ca 100644 --- a/cca_zoo/deep/_discriminative/_dcca_noi.py +++ b/cca_zoo/deep/_discriminative/_dcca_noi.py @@ -43,8 +43,8 @@ def __init__( self.mse = torch.nn.MSELoss(reduction="sum") self.rand = torch.rand(N, self.latent_dimensions) - def loss(self, views, **kwargs): - z = self(views) + def loss(self, batch, **kwargs): + z = self(batch['views']) z_copy = [z_.detach().clone() for z_ in z] self._update_covariances(z_copy, train=self.training) covariance_inv = [inv_sqrtm(cov, self.eps) for cov in self.covs] diff --git a/cca_zoo/deep/_discriminative/_dcca_sdl.py b/cca_zoo/deep/_discriminative/_dcca_sdl.py index 252c172f..458f0fee 100644 --- a/cca_zoo/deep/_discriminative/_dcca_sdl.py +++ b/cca_zoo/deep/_discriminative/_dcca_sdl.py @@ -52,8 +52,8 @@ def forward(self, views, **kwargs): z.append(bn(encoder(views[i]))) return z - def loss(self, views, **kwargs): - z = self(views) + def loss(self, batch, **kwargs): + z = self(batch['views']) l2_loss = F.mse_loss(z[0], z[1]) self._update_covariances(z, train=self.training) SDL_loss = self._sdl_loss(self.covs) diff --git a/cca_zoo/deep/_discriminative/_dcca_svd.py b/cca_zoo/deep/_discriminative/_dcca_svd.py index 3733cd32..425d4377 100644 --- a/cca_zoo/deep/_discriminative/_dcca_svd.py +++ b/cca_zoo/deep/_discriminative/_dcca_svd.py @@ -22,15 +22,16 @@ def __init__(self, latent_dimensions: int, encoders=None, r: float = 0, **kwargs f"Expected 2 views, got {len(self.encoders)} views instead." ) - def loss(self, views, independent_views=None, **kwargs): + def loss(self, batch, **kwargs): # views here is a list of 'paired' views (i.e. 
[view1, view2]) - z = self(views) # get the latent representations + z = self(batch['views']) # get the latent representations C = torch.cov(torch.hstack(z).T) latent_dims = z[0].shape[1] Cxy = C[:latent_dims, latent_dims:] Cxx = C[:latent_dims, :latent_dims] + independent_views = batch.get("independent_views", None) if independent_views is None: Cyy = C[latent_dims:, latent_dims:] else: