From c588ecbf2240cd484d532a6a34e5878298d1b86b Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Mon, 13 Nov 2023 19:02:14 +0000 Subject: [PATCH] Format code with black --- cca_zoo/deep/objectives.py | 36 ++++++++++++---------- cca_zoo/linear/_gradient/_stochasticpls.py | 4 +-- tests/test_gradient.py | 2 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/cca_zoo/deep/objectives.py b/cca_zoo/deep/objectives.py index 06ee7866..ebc6d4b3 100644 --- a/cca_zoo/deep/objectives.py +++ b/cca_zoo/deep/objectives.py @@ -267,11 +267,11 @@ def loss(self, representations, independent_representations=None): } def derivative( - self, - views, - representations, - independent_views=None, - independent_representations=None, + self, + views, + representations, + independent_views=None, + independent_representations=None, ): A, B = self.get_AB(representations) sum_representations = torch.sum(torch.stack(representations), dim=0) @@ -318,11 +318,11 @@ def loss(self, representations, independent_representations=None): } def derivative( - self, - views, - representations, - independent_views=None, - independent_representations=None, + self, + views, + representations, + independent_views=None, + independent_representations=None, ): A, B = self.get_AB(representations) sum_representations = torch.sum(torch.stack(representations), dim=0) @@ -367,11 +367,11 @@ def loss(self, representations, independent_representations=None): } def derivative( - self, - views, - representations, - independent_views=None, - independent_representations=None, + self, + views, + representations, + independent_views=None, + independent_representations=None, ): C = torch.cov(torch.hstack(representations).T) latent_dims = representations[0].shape[1] @@ -421,10 +421,12 @@ class _PLS_PowerLoss(_PLSAB): def loss(self, representations): cov = torch.cov(torch.hstack(representations).T) return { - "objective": torch.trace(cov[: representations[0].shape[1], representations[0].shape[1]:]) + "objective": torch.trace( + cov[: representations[0].shape[1], representations[0].shape[1] :] + ) } @staticmethod - def derivative( views, representations): + def derivative(views, representations): grads = [views[0].T @ representations[1], views[1].T @ representations[0]] return grads diff --git a/cca_zoo/linear/_gradient/_stochasticpls.py b/cca_zoo/linear/_gradient/_stochasticpls.py index 354c2ddc..f051845e 100644 --- a/cca_zoo/linear/_gradient/_stochasticpls.py +++ b/cca_zoo/linear/_gradient/_stochasticpls.py @@ -25,9 +25,7 @@ def training_step(self, batch, batch_idx): on_epoch=True, batch_size=batch["views"][0].shape[0], ) - manual_grads = self.objective.derivative( - batch["views"], representations - ) + manual_grads = self.objective.derivative(batch["views"], representations) for i, weights in enumerate(self.torch_weights): weights.grad = manual_grads[i] opt.step() diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 5b7a6d6c..30c8318d 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -52,7 +52,7 @@ def test_batch_pls(): random_state=random_state, trainer_kwargs=trainer_kwargs, ).fit((X, Y)) - spls=PLSStochasticPower( + spls = PLSStochasticPower( latent_dimensions=latent_dims, epochs=epochs, random_state=random_state,