Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
# Conflicts:
#	test/test_deepmodels.py
  • Loading branch information
jameschapman19 committed Oct 6, 2023
2 parents 3648c90 + 963d1dc commit 20f2190
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions cca_zoo/deep/_discriminative/_dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def __init__(
device=None,
dtype=None,
) -> None:

factory_kwargs = {"device": device, "dtype": dtype}
super(BatchWhiten, self).__init__()

Expand Down Expand Up @@ -80,9 +79,13 @@ def forward(self, input: Tensor) -> Tensor:
if self.training:
bn_training = True
else:
bn_training = (self.running_covar is None)
bn_training = self.running_covar is None

running_covar=self.running_covar if not self.training or self.track_running_stats else None
running_covar = (
self.running_covar
if not self.training or self.track_running_stats
else None
)

# Calculate batch covariance
covar = torch.matmul(input.T, input) / input.shape[0]
Expand All @@ -91,7 +94,9 @@ def forward(self, input: Tensor) -> Tensor:
if bn_training:
with torch.no_grad():
if running_covar is not None:
running_covar.mul_(exponential_average_factor).add_(covar, alpha=1 - exponential_average_factor)
running_covar.mul_(exponential_average_factor).add_(
covar, alpha=1 - exponential_average_factor
)

# Calculate whitened input
if running_covar is not None:
Expand All @@ -104,6 +109,7 @@ def forward(self, input: Tensor) -> Tensor:
input = torch.matmul(input, B)
return input


def inv_sqrtm(A, eps=1e-9):
"""Compute the inverse square-root of a positive definite matrix."""
# Perform eigendecomposition of covariance matrix
Expand Down

0 comments on commit 20f2190

Please sign in to comment.