Skip to content

Commit

Permalink
Cov vis
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 9, 2024
1 parent cc1623a commit 8fddb4f
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions src/dartsort/vis/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ class MStep(GMMPlot):
n_show = 64

def draw(self, panel, gmm, unit_id, axes=None):
panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[1.5, 1])
panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[1, 1])
ax = panel_top.subplots()
ax.axis("off")

# panel_bottom, panel_cbar = panel_bottom.subfigures(ncols=2, width_ratios=[5, 0.5])
cov_axes = panel_bottom.subplots(
nrows=2, ncols=2, sharey=True, sharex=True
nrows=3, ncols=2, sharey=True, sharex=True
)
# cax = panel_cbar.add_subplot(3, 1, 2)

Expand Down Expand Up @@ -159,31 +159,31 @@ def draw(self, panel, gmm, unit_id, axes=None):
# covariance vis
feats = features_full[:, :, gmm.units[unit_id].channels]
model_mean = gmm.units[unit_id].mean[:, gmm.units[unit_id].channels]
feats = feats - model_mean
n, r, c = feats.shape
emp_cov, nobs = spiketorch.nancov(feats.view(n, r * c), return_nobs=True)
denom = nobs + gmm.units[unit_id].prior_pseudocount
emp_cov = (nobs / denom) * emp_cov
noise_cov = gmm.noise.marginal_covariance(channels=gmm.units[unit_id].channels).to_dense()
m = model_mean.abs().reshape(-1)
m = model_mean.reshape(-1)
mmt = m[:, None] @ m[None, :]
covs = (emp_cov, noise_cov, mmt)
vmax = max(c.abs().max() for c in covs)
names = ("regemp", "noise", "|temptempT|")
print(f"{feats.shape=} {gmm.units[unit_id].channels.shape=}")
print(f"{vmax=}")
print(f"{emp_cov.abs().max()=}")
print(f"{noise_cov.abs().max()=}")
print(f"{mmt.abs().max()=}")
print(f"{emp_cov.shape=}")
print(f"{noise_cov.shape=}")
print(f"{mmt.shape=}")
modelcov = gmm.units[unit_id].marginal_covariance(channels=gmm.units[unit_id].channels).to_dense()
residual = emp_cov - modelcov
covs = (emp_cov, noise_cov, mmt.abs(), mmt, modelcov, emp_cov - modelcov)
# vmax = max(c.abs().max() for c in covs)
names = ("regemp", "noise", "|temptempT|", "temptempT", "model", "resid")

for ax, cov, name in zip(cov_axes.flat, covs, names):
vmax = cov.abs().triu(diagonal=1)
vmax = vmax[vmax>0].quantile(.95)
vmax = vmax[vmax>0].quantile(.975)
im = ax.imshow(cov.numpy(force=True), vmin=-vmax, vmax=vmax, cmap=plt.cm.seismic)
ax.axis("off")
ax.set_title(name, fontsize="small")
ax.set_title(
name
+ f" max={cov.abs().max().numpy(force=True).item():.2f}, "
+ f"rms={cov.square().mean().sqrt_().numpy(force=True).item():.2f}",
fontsize="small",
)
plt.colorbar(im, ax=ax, shrink=0.5)
# plt.colorbar(im, cax=cax, shrink=0.5)

Expand Down Expand Up @@ -269,9 +269,14 @@ def __init__(self, layout="vert"):
def draw(self, panel, gmm, unit_id, split_info=None):
if split_info is None:
split_info = gmm.kmeans_split_unit(unit_id, debug=True)
if not split_info:
failed = not split_info or "reas_labels" not in split_info
if failed:
ax = panel.subplots()
ax.text(.5, .5, "no channels!", ha="center", transform=ax.transAxes)
if not split_info:
msg = "no channels!"
else:
msg = "split abandoned"
ax.text(.5, .5, msg, ha="center", transform=ax.transAxes)
ax.axis("off")
return

Expand Down

0 comments on commit 8fddb4f

Please sign in to comment.