From ce2f0829cb64905b3aaf2e88c13a34550dd0d71f Mon Sep 17 00:00:00 2001 From: jameschapman19 Date: Thu, 21 Sep 2023 14:41:09 +0000 Subject: [PATCH] Format code with black --- cca_zoo/deep/_base.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/cca_zoo/deep/_base.py b/cca_zoo/deep/_base.py index e9031f3d..dc8a3950 100644 --- a/cca_zoo/deep/_base.py +++ b/cca_zoo/deep/_base.py @@ -58,7 +58,13 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: loss = self.loss(batch["views"]) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"train/{k}", v, prog_bar=True, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log( + f"train/{k}", + v, + prog_bar=True, + on_epoch=True, + batch_size=batch["views"][0].shape[0], + ) return loss["objective"] def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: @@ -66,7 +72,9 @@ def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor loss = self.loss(batch["views"]) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"val/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log( + f"val/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0] + ) return loss["objective"] def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: @@ -74,7 +82,9 @@ def test_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: loss = self.loss(batch["views"]) for k, v in loss.items(): # Use f-string instead of concatenation - self.log(f"test/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0]) + self.log( + f"test/{k}", v, on_epoch=True, batch_size=batch["views"][0].shape[0] + ) return loss["objective"] @torch.no_grad()