Skip to content

Commit

Permalink
Cov vis
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 12, 2024
1 parent 0875237 commit f148071
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 11 deletions.
135 changes: 124 additions & 11 deletions src/dartsort/vis/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from tqdm.auto import tqdm

from ..cluster import gaussian_mixture
from ..cluster import gaussian_mixture, stable_features
from ..util.multiprocessing_util import (CloudpicklePoolExecutor,
ThreadPoolExecutor, get_pool, cloudpickle)
from ..util import spiketorch
Expand Down Expand Up @@ -97,25 +97,32 @@ def draw(self, panel, gmm, unit_id):


class MStep(GMMPlot):
kind = "waveform"
kind = "mstep"
width = 5
height = 9
alpha = 0.05
n_show = 64

def __init__(self, n_waveforms_show=64, with_covs=True):
self.with_covs = with_covs
self.height = 5 + 4 * with_covs
self.n_waveforms_show = n_waveforms_show

def draw(self, panel, gmm, unit_id, axes=None):
panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[1, 1])
if self.with_covs:
panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[1, 1])
else:
panel_top = panel
ax = panel_top.subplots()
ax.axis("off")

# panel_bottom, panel_cbar = panel_bottom.subfigures(ncols=2, width_ratios=[5, 0.5])
cov_axes = panel_bottom.subplots(
nrows=3, ncols=2, sharey=True, sharex=True
)
# cax = panel_cbar.add_subplot(3, 1, 2)
if self.with_covs:
cov_axes = panel_bottom.subplots(
nrows=3, ncols=2, sharey=True, sharex=True
)
# cax = panel_cbar.add_subplot(3, 1, 2)

# get spike data and determine channel set by plotting
sp = gmm.random_spike_data(unit_id, max_size=self.n_show, with_reconstructions=True)
sp = gmm.random_spike_data(unit_id, max_size=self.n_waveforms_show, with_reconstructions=True)
maa = sp.waveforms.abs().nan_to_num().max()
geomplot_kw = dict(
max_abs_amp=maa,
Expand All @@ -136,7 +143,6 @@ def draw(self, panel, gmm, unit_id, axes=None):
sp, weights=None, n_channels=gmm.data.n_channels, storage=None
)
features_full, weights_full, count_data, weights_normalized = tup
print(f"{features_full.shape=}")
feats = features_full[:, :, chans]
n, r, c = feats.shape
emp_mean = torch.nanmean(feats, dim=0)
Expand All @@ -155,6 +161,8 @@ def draw(self, panel, gmm, unit_id, axes=None):
)
ax.axis("off")
ax.set_title("reconstructed mean and example inputs")
if not self.with_covs:
return

# covariance vis
feats = features_full[:, :, gmm.units[unit_id].channels]
Expand Down Expand Up @@ -188,6 +196,109 @@ def draw(self, panel, gmm, unit_id, axes=None):
# plt.colorbar(im, cax=cax, shrink=0.5)


class CovarianceResidual(GMMPlot):
    """Diagnostic plot of a unit's empirical covariance versus the noise model.

    Shows a 3x3 grid of covariance heatmaps (the empirical covariance, the
    noise covariance, their difference, and two rank-one corrections of the
    noise model together with what each leaves unexplained) above a single
    axis with the eigenvalue spectra of a subset of those matrices.
    """

    kind = "mstep"
    width = 7
    height = 5

    def draw(self, panel, gmm, unit_id):
        """Draw the covariance comparison for ``unit_id`` into ``panel``.

        Arguments
        ---------
        panel
            Container exposing ``subfigures``; presumably a matplotlib
            Figure/SubFigure -- confirm against the caller.
        gmm
            Mixture model object. This method uses its ``random_spike_data``,
            ``get_fit_weights`` and ``noise`` members, plus ``log_liks`` when
            that attribute exists.
        unit_id
            Identifier of the unit whose spikes are analysed.
        """
        # -- gather spikes and their fit weights
        spikes = gmm.random_spike_data(unit_id)
        n_spikes = len(spikes)
        weights = gmm.get_fit_weights(
            unit_id,
            spikes.indices,
            getattr(gmm, 'log_liks', None),
        )
        n_chans_total = gmm.noise.n_channels
        active_chans = gaussian_mixture.occupied_chans(spikes, n_chans_total)
        if weights is None:
            # no fit weights available: every spike counts equally
            weights = spikes.features.new_ones(n_spikes)

        # -- NaN-pad features (and weights) onto the occupied channel set
        padded_feats, padded_weights = stable_features.pad_to_chans(
            spikes,
            active_chans,
            n_chans_total,
            weights=weights,
            pad_value=torch.nan,
        )
        weight_totals = torch.nansum(padded_weights, 0)
        normalized_weights = padded_weights / weight_totals

        # -- weighted mean over spikes; NaN padding contributes zero
        mean = torch.linalg.vecdot(
            normalized_weights.unsqueeze(1).nan_to_num(),
            padded_feats.nan_to_num(),
            dim=0,
        )
        centered = padded_feats - mean

        # -- empirical covariance and what the noise model leaves over
        emp_cov = spiketorch.nancov(
            centered.view(n_spikes, -1),
            weights=weights,
            correction=0,
            force_posdef=True,
        )
        noise_cov = gmm.noise.marginal_covariance(active_chans).to_dense()
        residual = emp_cov - noise_cov

        # -- model 1: noise + scale * (mean mean^T), where scale is the
        #    least-squares coefficient of the residual on the outer product
        mean_outer = mean.view(-1, 1) @ mean.view(1, -1)
        scale = (mean_outer * residual).sum() / mean_outer.square().sum()
        scaled_outer = scale * mean_outer
        model = noise_cov + scaled_outer
        model_residual = emp_cov - model

        # -- model 2: noise + the last eigenpair of the residual
        #    (eigh returns eigenvalues in ascending order, so this is the top one)
        residual_eigs, residual_vecs = torch.linalg.eigh(residual)
        top_vec = residual_vecs[:, -1:]
        rank1 = (top_vec * residual_eigs[-1:]) @ top_vec.T
        rank1_model = noise_cov + rank1
        rank1_residual = emp_cov - rank1_model

        # -- spectra shown in the bottom axis, keyed like the heatmaps
        spectra = {
            "emp": torch.linalg.eigvalsh(emp_cov),
            "noise": torch.linalg.eigvalsh(noise_cov),
            "noise_resid": residual_eigs,
            "mmT_resid": torch.linalg.eigvalsh(model_residual),
            "rank1_resid": torch.linalg.eigvalsh(rank1_residual),
        }

        # -- layout: heatmap grid on top, one eigenvalue axis below.
        #    (An R^2-style second bottom axis was sketched but is disabled.)
        top, bot = panel.subfigures(nrows=2, height_ratios=[5, 2])
        axes = top.subplots(nrows=3, ncols=3, sharex=True, sharey=True)
        ax_eig = bot.subplots()

        # symmetric color limits shared by all heatmaps except the scaled
        # outer product, which gets its own limits below
        vm = emp_cov.abs().max()
        shared_imshow_kw = dict(
            vmin=-vm, vmax=vm, cmap=plt.cm.seismic, interpolation='none'
        )

        # (title, matrix, title/line color), in grid order
        heatmaps = (
            ("emp", emp_cov, "k"),
            ("noise", noise_cov, "g"),
            ("noise_resid", residual, "r"),
            ("mmT_scaled", scaled_outer, "gray"),
            ("mmT_model", model, "c"),
            ("mmT_resid", model_residual, "orange"),
            ("rank1", rank1, "gray"),
            ("rank1_model", rank1_model, "palegreen"),
            ("rank1_resid", rank1_residual, "fuchsia"),
        )

        for (name, cov, color), ax in zip(heatmaps, axes.flat):
            is_scaled_outer = name == "mmT_scaled"
            if is_scaled_outer:
                own_vm = cov.abs().max() * 0.9
                imshow_kw = dict(
                    vmin=-own_vm,
                    vmax=own_vm,
                    cmap=plt.cm.seismic,
                    interpolation='none',
                )
            else:
                imshow_kw = shared_imshow_kw

            im = ax.imshow(cov, **imshow_kw)
            cb = plt.colorbar(im, ax=ax, shrink=0.2)
            cb.outline.set_visible(False)

            if is_scaled_outer:
                title = f"{name} (scale={scale:0.2f})"
            else:
                title = name
            ax.set_title(title, color=color)

            if name in spectra:
                # flipped so the curve runs from largest to smallest
                ax_eig.plot(spectra[name].flip(0), color=color, lw=1)

        # blank any grid cells that did not receive a heatmap
        for ax in axes.flat[len(heatmaps):]:
            ax.axis("off")

        ax_eig.set_xlabel('eig index')
        ax_eig.set_ylabel('eigenvalues')
        ax_eig.axhline(0, color='k', lw=0.8)


class Likelihoods(GMMPlot):
kind = "widescatter"
width = 4
Expand Down Expand Up @@ -345,6 +456,8 @@ def draw(self, panel, gmm, unit_id, split_info=None):
chans = torch.cdist(gmm.data.prgeom[mainchan[None]], gmm.data.prgeom)
chans = chans.view(-1)
(chans,) = torch.nonzero(chans <= gmm.data.core_radius, as_tuple=True)
if len(split_ids) < len(split_info["units"]):
split_info["units"] = [split_info["units"][j] for j in split_ids]
gmm_helpers.plot_means(
mcmeans_row, gmm.data.prgeom, gmm.data.tpca, chans, split_info["units"], split_ids, title=None
)
Expand Down
1 change: 1 addition & 0 deletions src/dartsort/vis/gmm_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def plot_means(panel, prgeom, tpca, chans, units, labels, title="nearest neighbo
means.append(tpca.force_reconstruct(mean).numpy(force=True))

colors = glasbey1024[labels]
print(f"{len(means)=} {labels.shape=} {colors.shape=}")
geomplot(
np.stack(means, axis=0),
channels=chans[None]
Expand Down

0 comments on commit f148071

Please sign in to comment.