Skip to content

Commit

Permalink
Debug factorized noise
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 8, 2024
1 parent dec84b5 commit e9dd70a
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/dartsort/util/noise_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,10 +436,13 @@ def _marginal_covariance(self, channels=slice(None)):

if self.cov_kind == "factorized":
rank_root = self.rank_vt.T * self.rank_std
rank_root = operators.RootLinearOperator(rank_root)
chan_root = self.channel_vt.T * self.channel_std
chan_root = operators.RootLinearOperator(chan_root)
return torch.kron(rank_root, chan_root)
rank_cov = rank_root @ rank_root.T
# rank_root = operators.RootLinearOperator(rank_root)
chan_root = self.channel_vt.T[channels] * self.channel_std
chan_cov = chan_root @ chan_root.T
# chan_root = operators.RootLinearOperator(chan_root)
# return torch.kron(rank_root, chan_root)
return operators.KroneckerProductLinearOperator(rank_cov, chan_cov)

if self.cov_kind == "factorized_rank_diag":
rank_cov = operators.DiagLinearOperator(self.rank_std.square())
Expand Down Expand Up @@ -565,7 +568,12 @@ def estimate(cls, snippets, mean_kind="zero", cov_kind="scalar"):
channel_vt[q] = qv.T
else:
x_spatial = x_spatial.reshape(n * rank, n_channels)
cov_spatial = spiketorch.nancov(x_spatial)
valid = x_spatial.isfinite().any(0)
cov = spiketorch.nancov(x_spatial[:, valid])
cov_spatial = torch.eye(
x_spatial.shape[1], dtype=cov.dtype, device=cov.device
)
cov_spatial[valid[:, None] & valid[None, :]] = cov.view(-1)
channel_eig, channel_v = torch.linalg.eigh(cov_spatial)
channel_std = channel_eig.sqrt()
channel_vt = channel_v.T.contiguous()
Expand Down

0 comments on commit e9dd70a

Please sign in to comment.