diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index d3247321..ce5bb5d4 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -436,10 +436,13 @@ def _marginal_covariance(self, channels=slice(None)): if self.cov_kind == "factorized": rank_root = self.rank_vt.T * self.rank_std - rank_root = operators.RootLinearOperator(rank_root) - chan_root = self.channel_vt.T * self.channel_std - chan_root = operators.RootLinearOperator(chan_root) - return torch.kron(rank_root, chan_root) + rank_cov = rank_root @ rank_root.T + # rank_root = operators.RootLinearOperator(rank_root) + chan_root = self.channel_vt.T[channels] * self.channel_std + chan_cov = chan_root @ chan_root.T + # chan_root = operators.RootLinearOperator(chan_root) + # return torch.kron(rank_root, chan_root) + return operators.KroneckerProductLinearOperator(rank_cov, chan_cov) if self.cov_kind == "factorized_rank_diag": rank_cov = operators.DiagLinearOperator(self.rank_std.square()) @@ -565,7 +568,12 @@ def estimate(cls, snippets, mean_kind="zero", cov_kind="scalar"): channel_vt[q] = qv.T else: x_spatial = x_spatial.reshape(n * rank, n_channels) - cov_spatial = spiketorch.nancov(x_spatial) + valid = x_spatial.isfinite().any(0) + cov = spiketorch.nancov(x_spatial[:, valid]) + cov_spatial = torch.eye( + x_spatial.shape[1], dtype=cov.dtype, device=cov.device + ) + cov_spatial[valid[:, None] & valid[None, :]] = cov.view(-1) channel_eig, channel_v = torch.linalg.eigh(cov_spatial) channel_std = channel_eig.sqrt() channel_vt = channel_v.T.contiguous()