Skip to content

Commit

Permalink
Format code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 authored and github-actions[bot] committed Sep 18, 2023
1 parent 2d96b16 commit ed038fe
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cca_zoo/deep/_discriminative/_dcca_gha.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def loss(self, views, independent_views=None, **kwargs):
rewards = torch.trace(2 * A)
if independent_views is None:
# Hebbian
penalties= torch.trace(A.detach() @ B)
penalties = torch.trace(A.detach() @ B)
# penalties = torch.trace(A @ B)
else:
independent_z = self(independent_views)
Expand Down
1 change: 1 addition & 0 deletions cca_zoo/linear/_gradient/_ey.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_dataset(self, views: Iterable[np.ndarray]):

class DoubleNumpyDataset(NumpyDataset):
random_state = np.random.RandomState(0)

def __getitem__(self, index):
views = [view[index] for view in self.views]
independent_index = self.random_state.randint(0, len(self))
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/linear/_gradient/_gha.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def loss(self, views, independent_views=None, **kwargs):
z = self(views)
# Getting A and B matrices from z
A, B = self.get_AB(z)
rewards= torch.trace(2*A)
rewards = torch.trace(2 * A)
if independent_views is None:
# Hebbian
penalties = torch.trace(A.detach() @ B)
Expand Down

0 comments on commit ed038fe

Please sign in to comment.