Skip to content

Commit

Permalink
Validation set for Gradient Based Models woohoo
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Sep 21, 2023
1 parent aab530c commit 14ea8cd
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 19 deletions.
27 changes: 21 additions & 6 deletions cca_zoo/deep/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,41 @@ def loss(

def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
"""Performs one step of training on a batch of views."""
loss = self.loss(batch["views"])
loss = self.loss(batch)
for k, v in loss.items():
# Use f-string instead of concatenation
self.log(f"train/{k}", v, prog_bar=True, on_epoch=True, batch_size=batch["views"][0].shape[0])
self.log(f"train/{k}",
v,
on_step=False,
on_epoch=True,
batch_size=batch["views"][0].shape[0],
)
return loss["objective"]

def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
"""Performs one step of validation on a batch of views."""
loss = self.loss(batch["views"])
loss = self.loss(batch)
for k, v in loss.items():
# Use f-string instead of concatenation
self.log(f"val/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0])
self.log(f"val/{k}",
v,
on_step=False,
on_epoch=True,
batch_size=batch["views"][0].shape[0],
)
return loss["objective"]

def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
"""Performs one step of testing on a batch of views."""
loss = self.loss(batch["views"])
loss = self.loss(batch)
for k, v in loss.items():
# Use f-string instead of concatenation
self.log(f"test/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0])
self.log(f"test/{k}",
v,
on_step=False,
on_epoch=True,
batch_size=batch["views"][0].shape[0],
)
return loss["objective"]

@torch.no_grad()
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/deep/_discriminative/_dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def forward(self, views, **kwargs):
z = [encoder(view) for encoder, view in zip(self.encoders, views)]
return z

def loss(self, views, **kwargs):
z = self(views)
def loss(self, batch, **kwargs):
z = self(batch['views'])
return {"objective": self.objective.loss(z)}

def pairwise_correlations(self, loader: torch.utils.data.DataLoader):
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/deep/_discriminative/_dcca_barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def forward(self, views, **kwargs):
z.append(bn(encoder(views[i]))) # encode and normalize each view
return z # return a list of normalized latent representations

def loss(self, views, **kwargs):
z = self(views) # get the latent representations
def loss(self, batch, **kwargs):
z = self(batch['views']) # get the latent representations
cross_cov = (
z[0].T @ z[1] / z[0].shape[0]
) # compute the cross-covariance matrix between the two views
Expand Down
5 changes: 3 additions & 2 deletions cca_zoo/deep/_discriminative/_dcca_ey.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def __init__(self, latent_dimensions: int, encoders=None, r: float = 0, **kwargs
)
self.r = r

def loss(self, views, independent_views=None, **kwargs):
def loss(self, batch, **kwargs):
# Encoding the views with the forward method
z = self(views)
z = self(batch['views'])
independent_views = batch.get("independent_views", None)
# Getting A and B matrices from z
A, B = self.get_AB(z)
rewards = torch.trace(2 * A)
Expand Down
3 changes: 2 additions & 1 deletion cca_zoo/deep/_discriminative/_dcca_gha.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def get_AB(self, z):
z
) # return the normalized matrices (divided by the number of views)

def loss(self, views, independent_views=None, **kwargs):
def loss(self, batch, **kwargs):
z = self(views)
independent_views = batch.get("independent_views", None)
A, B = self.get_AB(z)
rewards = torch.trace(2 * A)
if independent_views is None:
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/deep/_discriminative/_dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(
self.mse = torch.nn.MSELoss(reduction="sum")
self.rand = torch.rand(N, self.latent_dimensions)

def loss(self, views, **kwargs):
z = self(views)
def loss(self, batch, **kwargs):
z = self(batch['views'])
z_copy = [z_.detach().clone() for z_ in z]
self._update_covariances(z_copy, train=self.training)
covariance_inv = [inv_sqrtm(cov, self.eps) for cov in self.covs]
Expand Down
4 changes: 2 additions & 2 deletions cca_zoo/deep/_discriminative/_dcca_sdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def forward(self, views, **kwargs):
z.append(bn(encoder(views[i])))
return z

def loss(self, views, **kwargs):
z = self(views)
def loss(self, batch, **kwargs):
z = self(batch['views'])
l2_loss = F.mse_loss(z[0], z[1])
self._update_covariances(z, train=self.training)
SDL_loss = self._sdl_loss(self.covs)
Expand Down
5 changes: 3 additions & 2 deletions cca_zoo/deep/_discriminative/_dcca_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ def __init__(self, latent_dimensions: int, encoders=None, r: float = 0, **kwargs
f"Expected 2 views, got {len(self.encoders)} views instead."
)

def loss(self, views, independent_views=None, **kwargs):
def loss(self, batch, **kwargs):
# views here is a list of 'paired' views (i.e. [view1, view2])
z = self(views) # get the latent representations
z = self(batch['views']) # get the latent representations
C = torch.cov(torch.hstack(z).T)
latent_dims = z[0].shape[1]

Cxy = C[:latent_dims, latent_dims:]
Cxx = C[:latent_dims, :latent_dims]

independent_views = batch.get("independent_views", None)
if independent_views is None:
Cyy = C[latent_dims:, latent_dims:]
else:
Expand Down

0 comments on commit 14ea8cd

Please sign in to comment.