Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Nov 13, 2023
2 parents 1894e3f + c588ecb commit ec348d7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
36 changes: 19 additions & 17 deletions cca_zoo/deep/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,11 @@ def loss(self, representations, independent_representations=None):
}

def derivative(
self,
views,
representations,
independent_views=None,
independent_representations=None,
self,
views,
representations,
independent_views=None,
independent_representations=None,
):
A, B = self.get_AB(representations)
sum_representations = torch.sum(torch.stack(representations), dim=0)
Expand Down Expand Up @@ -318,11 +318,11 @@ def loss(self, representations, independent_representations=None):
}

def derivative(
self,
views,
representations,
independent_views=None,
independent_representations=None,
self,
views,
representations,
independent_views=None,
independent_representations=None,
):
A, B = self.get_AB(representations)
sum_representations = torch.sum(torch.stack(representations), dim=0)
Expand Down Expand Up @@ -367,11 +367,11 @@ def loss(self, representations, independent_representations=None):
}

def derivative(
self,
views,
representations,
independent_views=None,
independent_representations=None,
self,
views,
representations,
independent_views=None,
independent_representations=None,
):
C = torch.cov(torch.hstack(representations).T)
latent_dims = representations[0].shape[1]
Expand Down Expand Up @@ -421,10 +421,12 @@ class _PLS_PowerLoss(_PLSAB):
def loss(self, representations):
    """Compute the PLS power-method objective for a pair of representations.

    Stacks the two views' representations column-wise, takes the covariance
    of the joint variables, and returns the trace of the off-diagonal
    (cross-covariance) block — i.e. the sum of per-dimension covariances
    between the two views' latent representations.

    Parameters
    ----------
    representations : sequence of 2 tensors, each of shape (n_samples, latent_dims)
        Latent representations of the two views.
        # assumes exactly two views of equal latent dimension — TODO confirm

    Returns
    -------
    dict
        {"objective": scalar tensor} — trace of the cross-covariance block.
    """
    # Covariance of all latent variables jointly; torch.cov expects
    # variables as rows, hence the transpose.
    cov = torch.cov(torch.hstack(representations).T)
    latent_dims = representations[0].shape[1]
    # Top-right block of the joint covariance = cross-covariance between views.
    return {"objective": torch.trace(cov[:latent_dims, latent_dims:])}

@staticmethod
def derivative( views, representations):
def derivative(views, representations):
grads = [views[0].T @ representations[1], views[1].T @ representations[0]]
return grads
4 changes: 1 addition & 3 deletions cca_zoo/linear/_gradient/_stochasticpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def training_step(self, batch, batch_idx):
on_epoch=True,
batch_size=batch["views"][0].shape[0],
)
manual_grads = self.objective.derivative(
batch["views"], representations
)
manual_grads = self.objective.derivative(batch["views"], representations)
for i, weights in enumerate(self.torch_weights):
weights.grad = manual_grads[i]
opt.step()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_batch_pls():
random_state=random_state,
trainer_kwargs=trainer_kwargs,
).fit((X, Y))
spls=PLSStochasticPower(
spls = PLSStochasticPower(
latent_dimensions=latent_dims,
epochs=epochs,
random_state=random_state,
Expand Down

0 comments on commit ec348d7

Please sign in to comment.