diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py index dad80733..b296af95 100644 --- a/src/dartsort/cluster/gaussian_mixture.py +++ b/src/dartsort/cluster/gaussian_mixture.py @@ -11,7 +11,7 @@ from scipy.special import logsumexp from tqdm.auto import tqdm, trange -from ..util import more_operators, noise_util +from ..util import more_operators, noise_util, spiketorch from .cluster_util import agglomerate from .kmeans import kmeans from .modes import smoothed_dipscore_at @@ -75,6 +75,7 @@ def __init__( merge_bimodality_masked: bool = False, merge_sym_function: callable = np.minimum, em_converged_prop: float = 0.02, + em_converged_churn: float = 0.01, em_converged_atol: float = 1e-2, ): super().__init__() @@ -107,6 +108,7 @@ def __init__( self.merge_linkage = merge_linkage self.em_converged_prop = em_converged_prop self.em_converged_atol = em_converged_atol + self.em_converged_churn = em_converged_churn self.split_em_iter = split_em_iter self.split_whiten = split_whiten self.use_proportions = use_proportions @@ -117,7 +119,7 @@ def __init__( self.labels = torch.from_numpy(self.labels) # this is populated by self.m_step() - self.units = torch.nn.ModuleList() + self._units = torch.nn.ModuleDict() self.log_proportions = None # store arguments to the unit constructor in a dict @@ -151,6 +153,62 @@ def __init__( self.storage = threading.local() self.next_round_annotations = {} + # -- unit management + + def __getitem__(self, ix): + ix = self.normalize_key(ix) + if ix not in self: + raise KeyError(f"Mixture has no unit with ID {ix}") + return self._units[ix] + + def __setitem__(self, ix, value): + ix = self.normalize_key(ix) + self._units[ix] = value + + def __contains__(self, ix): + ix = self.normalize_key(ix) + return ix in self._units + + def update(self, other): + if isinstance(other, dict): + other = other.items() + for k, v in other: + self[k] = v + + def empty(self): + return not self._units + + def clear_units(self): + self._stack = None + self._units.clear() + + def unit_ids(self): + return np.array([int(k) for k in self._units.keys()]) + + def items(self): + for k, v in self._units.items(): + yield int(k), v + + def ids_and_units(self): + return self.unit_ids(), self._units.values() + + def n_units(self): + nu_u = max(self.unit_ids()) + 1 + nu_l = self.label_ids().max() + 1 + return max(nu_u, nu_l) + + def label_ids(self): + uids = torch.unique(self.labels) + return uids[uids >= 0] + + def ids(self): + return torch.arange(self.n_units()) + + def n_labels(self): + unit_ids = self.label_ids() + nu = unit_ids.max() + 1 + return nu + # -- headliners def to_sorting(self): @@ -158,10 +216,6 @@ def to_sorting(self): labels[self.data.kept_indices] = self.labels.numpy(force=True) return replace(self.data.original_sorting, labels=labels) - def unit_ids(self): - uids = torch.unique(self.labels) - return uids[uids >= 0] - def em(self, n_iter=None, show_progress=True, final_e_step=True): n_iter = self.n_em_iters if n_iter is None else n_iter if show_progress: @@ -171,62 +225,87 @@ def em(self, n_iter=None, show_progress=True, final_e_step=True): its = range(n_iter) # if we have no units, we can't E step. - if not self.units: + if self.empty(): self.m_step(show_progress=step_progress) + self.cleanup(min_count=1) + convergence_props = {} + log_liks = None for _ in its: # for convergence testing... - self.cleanup(min_count=1) - means, *_ = self.stack_units(mean_only=True) + log_liks, convergence_props = self.cleanup( + log_liks, min_count=1, clean_props=convergence_props + ) + + recompute_mask = None + if "adif" in convergence_props: + recompute_mask = convergence_props["adif"] > 0 + # recompute_mask = torch.ones(self.n_units(), dtype=bool) - reas_count, log_liks, spike_logliks = self.e_step( - show_progress=step_progress + unit_churn, reas_count, log_liks, spike_logliks = self.e_step( + show_progress=step_progress, + recompute_mask=recompute_mask, + prev_log_liks=log_liks, + ) + convergence_props["unit_churn"] = unit_churn + log_liks, convergence_props = self.cleanup( + log_liks, clean_props=convergence_props ) meanlogpx = spike_logliks.mean() # M step: fit units based on responsibilities - mres = self.m_step(log_liks, show_progress=step_progress, prev_means=means) + to_fit = convergence_props["unit_churn"] >= self.em_converged_churn + mres = self.m_step(log_liks, show_progress=step_progress, to_fit=to_fit) + convergence_props["adif"] = mres["adif"] # extra info for description + max_adif = mres["max_adif"] if show_progress: opct = (self.labels < 0).sum() / self.data.n_spikes opct = f"{100 * opct:.1f}" - nu = len(self.units) + nu = self.n_units().item() reas_prop = reas_count / self.data.n_spikes rpct = f"{100 * reas_prop:.1f}" - adif = f"{mres['adif']:.2f}" - msg = f"EM[{nu=},out={opct}%,reas={rpct}%,dmu={adif},meanlogpx={meanlogpx:g}]" + adif = f"{max_adif:.2f}" + msg = ( + f"EM[K={nu},Ka={to_fit.sum()};{opct}%fp," + f"{rpct}%reas,dmu={adif};logpx/n={meanlogpx:.1f}]" + ) its.set_description(msg) if reas_prop < self.em_converged_prop: break - if mres["adif"] < self.em_converged_atol: + if max_adif is not None and max_adif < self.em_converged_atol: break if not final_e_step: return # final e step for caller - reas_count, log_liks, spike_logliks = self.e_step(clean_units=True) + unit_churn, reas_count, log_liks, spike_logliks = self.e_step( + show_progress=step_progress + ) return log_liks - def e_step(self, show_progress=False): + def e_step(self, show_progress=False, prev_log_liks=None, recompute_mask=None): # E step: get responsibilities and update hard assignments - log_liks = self.log_likelihoods(show_progress=show_progress) + log_liks = self.log_likelihoods( + show_progress=show_progress, + previous_logliks=prev_log_liks, + recompute_mask=recompute_mask, + ) # replace log_liks by csc - reas_count, spike_logliks, log_liks = self.reassign(log_liks) - - # garbage collection -- get rid of low count labels - log_liks = self.cleanup(log_liks) - - return reas_count, log_liks, spike_logliks + unit_churn, reas_count, spike_logliks, log_liks = self.reassign(log_liks) + return unit_churn, reas_count, log_liks, spike_logliks - def m_step(self, likelihoods=None, show_progress=False, prev_means=None): + def m_step(self, likelihoods=None, show_progress=False, to_fit=None): """Beware that this flattens the labels.""" - needs_append = not self.units - unit_ids = self.unit_ids() - if not needs_append: - assert unit_ids.max() < len(self.units) + warm_start = not self.empty() + fit_ids = unit_ids = self.label_ids() + if to_fit is not None: + fit_ids = unit_ids[to_fit[unit_ids]] + if warm_start: + _, prev_means, *_ = self.stack_units(mean_only=True) if self.use_proportions and likelihoods is not None: self.update_proportions(likelihoods) @@ -236,34 +315,44 @@ def m_step(self, likelihoods=None, show_progress=False, prev_means=None): pool = Parallel(self.n_threads, backend="threading", return_as="generator") results = pool( delayed(self.fit_unit)( - j, likelihoods=likelihoods, **self.next_round_annotations.get(j, {}) + j, + likelihoods=likelihoods, + warm_start=warm_start, + **self.next_round_annotations.get(j, {}), ) - for j in unit_ids + for j in fit_ids ) if show_progress: results = tqdm( - results, desc="M step", unit="unit", total=len(unit_ids), **tqdm_kw + results, desc="M step", unit="unit", total=len(fit_ids), **tqdm_kw ) - if needs_append: - for j, unit in enumerate(zip(unit_ids, results)): - assert unit.annotations["unit_id"] == j - self.units.append(unit) - adif = None - if prev_means is not None: - nu = len(unit_ids) - new_means, *_ = self.stack_units(mean_only=True) - dmu = (prev_means - new_means).abs_().view(nu, -1) - adif = torch.max(dmu) + for uid, unit in zip(fit_ids, results): + if not warm_start: + self[uid] = unit + max_adif = adif = None self._stack = None + if warm_start: + ids, new_means, *_ = self.stack_units(mean_only=True) + dmu = (prev_means - new_means).abs_().view(len(new_means), -1) + adif_ = torch.max(dmu, dim=1).values + max_adif = adif_.max() + adif = torch.zeros(self.n_units()) + adif[ids] = adif_ self.clear_scheduled_annotations() - return dict(adif=adif) + return dict(max_adif=max_adif, adif=adif) def log_likelihoods( - self, unit_ids=None, with_noise_unit=True, use_storage=True, show_progress=False + self, + unit_ids=None, + with_noise_unit=True, + use_storage=True, + show_progress=False, + previous_logliks=None, + recompute_mask=None, ): """Noise unit last so that rows correspond to unit ids without 1 offset""" if unit_ids is None: - unit_ids = range(len(self.units)) + unit_ids = self.unit_ids() # how many units does each core neighborhood overlap with? n_cores = self.data.core_neighborhoods.n_neighborhoods @@ -277,11 +366,26 @@ def log_likelihoods( neighb_info = [] nnz = 0 for j in unit_ids: - unit = self.units[j] - neighbs, ns_unit = self.data.core_neighborhoods.subset_neighborhoods( - unit.channels, add_to_overlaps=core_overlaps - ) - neighb_info.append((j, neighbs, ns_unit)) + unit = self[j] + if recompute_mask is None or recompute_mask[j]: + covered_neighbs, neighbs, ns_unit = ( + self.data.core_neighborhoods.subset_neighborhoods( + unit.channels, add_to_overlaps=core_overlaps + ) + ) + unit.annotations["covered_neighbs"] = covered_neighbs + neighb_info.append((j, neighbs, ns_unit)) + else: + row = previous_logliks[[j]].tocoo(copy=True) + six = row.coords[1] + ns_unit = row.nnz + if "covered_neighbs" in unit.annotations: + covered_neighbs = unit.annotations["covered_neighbs"] + else: + covered_neighbs = self.data.core_neighborhoods.neighborhood_ids[six] + covered_neighbs = torch.unique(covered_neighbs) + neighb_info.append((j, six, row.data, ns_unit)) + core_overlaps[covered_neighbs] += 1 nnz += ns_unit # how many units does each spike overlap with? needed to write csc @@ -293,6 +397,17 @@ def log_likelihoods( if with_noise_unit: nnz = nnz + self.data.n_spikes + @delayed + def job(ninfo): + if len(ninfo) == 4: + j, coo, data, ns = ninfo + return coo, data + else: + assert len(ninfo) == 3 + j, neighbs, ns = ninfo + ix, ll = self.unit_log_likelihoods(unit_id=j, neighbs=neighbs, ns=ns) + return ix, ll + # get the big nnz-length csc buffers. these can be huge so we cache them. csc_indices, csc_data = get_csc_storage(nnz, self.storage, use_storage) # csc compressed indptr. spikes are columns. @@ -302,10 +417,7 @@ def log_likelihoods( # spike, we increment the spike's "write head". idea is to directly make csc write_offsets = indptr[:-1].copy() pool = Parallel(self.n_threads, backend="threading", return_as="generator") - results = pool( - delayed(self.unit_log_likelihoods)(unit_id=j, neighbs=neighbs, ns=ns) - for j, neighbs, ns in neighb_info - ) + results = pool(job(ninfo) for ninfo in neighb_info) if show_progress: results = tqdm( results, @@ -317,8 +429,9 @@ def log_likelihoods( for j, (inds, liks) in enumerate(results): if inds is None: continue - inds = inds.numpy(force=True) - liks = liks.numpy(force=True) + if torch.is_tensor(inds): + inds = inds.numpy(force=True) + liks = liks.numpy(force=True) csc_insert(j, write_offsets, inds, csc_indices, csc_data, liks) # inds = inds.numpy(force=True) # data_ixs = write_offsets[inds] @@ -333,7 +446,7 @@ def log_likelihoods( csc_indices[data_ixs] = j + 1 csc_data[data_ixs] = liks.numpy(force=True) - shape = (len(unit_ids) + with_noise_unit, self.data.n_spikes) + shape = (self.n_units() + with_noise_unit, self.data.n_spikes) log_liks = csc_array((csc_data, csc_indices, indptr), shape=shape) log_liks.has_canonical_format = True @@ -369,36 +482,63 @@ def update_proportions(self, log_liks): ) def reassign(self, log_liks): - has_noise_unit = log_liks.shape[1] > len(self.units) + n_units = self.n_units() assignments, spike_logliks, log_liks_csc = loglik_reassign( log_liks, - has_noise_unit=has_noise_unit, + has_noise_unit=True, log_proportions=self.log_proportions, ) assignments = torch.from_numpy(assignments).to(self.labels) - reassign_count = (self.labels != assignments).sum() + + # track reassignments, first globally + same = torch.zeros_like(assignments) + torch.eq(self.labels, assignments, out=same) + reassign_count = len(same) - same.sum() + + # and per unit IoU + intersection = torch.zeros(n_units, dtype=int) + kept = assignments >= 0 + spiketorch.add_at_(intersection, assignments[kept], same[kept]) + union = torch.zeros(n_units, dtype=int) + buf1 = torch.zeros(same.shape, dtype=bool) + buf2 = torch.zeros(same.shape, dtype=bool) + buf3 = torch.zeros(same.shape, dtype=bool) + for j in range(n_units): + torch.eq(assignments, j, out=buf1) + torch.eq(self.labels, j, out=buf2) + union[j] = max(1, torch.logical_or(buf1, buf2, out=buf3).sum()) + iou = intersection / union + unit_churn = 1.0 - iou + + # update labels self.labels.copy_(assignments) - return reassign_count, spike_logliks, log_liks_csc - def cleanup(self, log_liks=None, min_count=None): + return unit_churn, reassign_count, spike_logliks, log_liks_csc + + def cleanup(self, log_liks=None, min_count=None, clean_props=None): """Remove too-small units, make label space contiguous, tidy all properties""" - unit_ids, counts = torch.unique(self.labels, return_counts=True) - counts = counts[unit_ids >= 0] - unit_ids = unit_ids[unit_ids >= 0] if min_count is None: min_count = self.min_count + + label_ids, counts = torch.unique(self.labels, return_counts=True) + counts = counts[label_ids >= 0] + label_ids = label_ids[label_ids >= 0] big_enough = counts >= min_count - keep = torch.zeros(len(self.units), dtype=bool) - keep[unit_ids] = big_enough + + keep = torch.zeros(self.n_units(), dtype=bool) + keep[label_ids] = big_enough + self._stack = None + if keep.all(): - return log_liks + return log_liks, clean_props + + if clean_props: + clean_props = {k: v[keep] for k, v in clean_props.items()} keep_noise = torch.concatenate((keep, torch.ones_like(keep[:1]))) keep = keep.numpy(force=True) - if log_liks is not None: - has_noise_unit = log_liks.shape[1] > len(self.units) - kept_ids = unit_ids[big_enough] + kept_ids = label_ids[big_enough] new_ids = torch.arange(kept_ids.numel()) old2new = dict(zip(kept_ids, new_ids)) self._relabel(kept_ids) @@ -410,36 +550,37 @@ def cleanup(self, log_liks=None, min_count=None): lps -= logsumexp(lps) self.log_proportions = self.log_proportions.new_tensor(lps) - if len(self.units): - keep_units = [] + if not self.empty(): + keep_units = {} for oi, ni in zip(kept_ids, new_ids): - u = self.units[oi] - u.annotations["unit_id"] = ni - keep_units.append(u) - del self.units[:] - self.units.extend(keep_units) + u = self[oi] + keep_units[ni] = u + self.clear_units() + self.update(keep_units) if self.next_round_annotations: next_round_annotations = {} for j, nra in self.next_round_annotations.items(): if keep[j]: next_round_annotations[old2new[j]] = nra + self.next_round_annotations = next_round_annotations if log_liks is None: - return + return log_liks, clean_props - keep_ll = keep_noise.numpy(force=True) if has_noise_unit else keep + keep_ll = keep_noise.numpy(force=True) assert keep_ll.size == log_liks.shape[0] if isinstance(log_liks, coo_array): log_liks = coo_sparse_mask_rows(log_liks, keep_ll) elif isinstance(log_liks, csc_array): keep_ll = np.flatnonzero(keep_ll) + assert keep_ll.max() == log_liks.shape[0] - 1 log_liks = log_liks[keep_ll] else: assert False - return log_liks + return log_liks, clean_props def merge(self, log_liks=None, show_progress=True): distances = self.distances(show_progress=show_progress) @@ -468,9 +609,9 @@ def merge(self, log_liks=None, show_progress=True): for new_id in np.unique(new_ids): merge_parents = np.flatnonzero(new_ids == new_ids) - self.schedule_annotations(new_id, dict(merge_parents=merge_parents)) + self.schedule_annotations(new_id, merge_parents=merge_parents) - del self.units[:] + self.clear_units() if self.log_proportions is not None: log_props = self.log_proportions.numpy(force=True) @@ -484,7 +625,6 @@ def merge(self, log_liks=None, show_progress=True): self.log_proportions = torch.asarray( new_log_props, device=self.log_proportions.device ) - self._stack = None def split(self, show_progress=True): pool = Parallel( @@ -500,8 +640,7 @@ def split(self, show_progress=True): if "new_ids" in res: for nid in res["new_ids"]: self.schedule_annotations(nid, split_parent=res["parent_id"]) - del self.units[:] - self._stack = None + self.clear_units() def distances( self, kind=None, noise_normalized=None, units=None, show_progress=True @@ -513,21 +652,28 @@ def distances( noise_normalized = self.distance_noise_normalized if units is None: - units = self.units - nu = len(units) + nu = self.n_units() + ids, units = self.ids_and_units() + else: + nu = len(units) + ids = range(nu) # stack unit data into one place mean_only = kind == "noise_metric" - means, covs, logdets = self.stack_units(units, mean_only=mean_only) + ids, means, covs, logdets = self.stack_units( + nu=nu, ids=ids, units=units, mean_only=mean_only + ) # compute denominator of noised normalized distances if noise_normalized: - denom = self.noise_unit.divergence( + denom_ = self.noise_unit.divergence( means, other_covs=covs, other_logdets=logdets, kind=kind ) - denom = np.sqrt(denom.numpy(force=True)) + denom = np.zeros(nu) + denom[ids] = np.sqrt(denom_.numpy(force=True)) - dists = np.zeros((nu, nu), dtype=np.float32) + dists = np.full((nu, nu), np.inf, dtype=np.float32) + np.fill_diagonal(dists, 0.0) @delayed def dist_job(j, unit): @@ -537,12 +683,12 @@ def dist_job(j, unit): d = d.numpy(force=True).astype(dists.dtype) if noise_normalized: d /= denom * denom[j] - dists[j] = d + dists[j, ids] = d pool = Parallel( self.n_threads, backend="threading", return_as="generator_unordered" ) - results = pool(dist_job(j, u) for j, u in enumerate(units)) + results = pool(dist_job(j, u) for j, u in zip(ids, units)) if show_progress: results = tqdm(results, desc="Distances", total=nu, unit="unit", **tqdm_kw) for _ in results: @@ -570,7 +716,7 @@ def bimodalities( min_overlap = self.merge_bimodality_overlap if masked is None: masked = self.merge_bimodality_masked - nu = len(self.units) + nu = self.n_units() in_units = [ torch.nonzero(self.labels == j, as_tuple=True)[0] for j in range(nu) ] @@ -683,6 +829,7 @@ def fit_unit( weights=None, features=None, verbose=False, + warm_start=False, **unit_args, ): if features is None: @@ -696,16 +843,14 @@ def fit_unit( if verbose and weights is not None: print(f"{weights.sum()=} {weights.min()=} {weights.max()=}") unit_args = self.unit_args | unit_args - if len(self.units) > unit_id: - unit = self.units[unit_id] - assert unit.annotations.get(unit_id, unit_id) == unit_id + if warm_start: + unit = self[unit_id] unit.fit(features, weights, neighborhoods=self.data.extract_neighborhoods) else: unit = GaussianUnit.from_features( features, weights, neighborhoods=self.data.extract_neighborhoods, - unit_id=unit_id, **unit_args, ) return unit @@ -733,7 +878,7 @@ def unit_log_likelihoods( log_likelihoods """ if unit is None: - unit = self.units[unit_id] + unit = self[unit_id] if ignore_channels: core_channels = torch.arange(self.data.n_channels) else: @@ -749,9 +894,10 @@ def unit_log_likelihoods( core_channels, spike_indices ) else: - neighbs, ns = self.data.core_neighborhoods.subset_neighborhoods( - core_channels + covered_neighbs, neighbs, ns = ( + self.data.core_neighborhoods.subset_neighborhoods(core_channels) ) + unit.annotations["covered_neighbs"] = covered_neighbs if not ns: return None, None @@ -812,8 +958,8 @@ def noise_log_likelihoods(self, show_progress=False): def kmeans_split_unit(self, unit_id, debug=False): # get spike data and use interpolation to fill it out to the # unit's channel set - result = dict(parent_id=unit_id) - unit = self.units[unit_id] + result = dict(parent_id=unit_id, new_ids=[unit_id]) + unit = self[unit_id] if not unit.channels.numel(): return result @@ -861,7 +1007,8 @@ def kmeans_split_unit(self, unit_id, debug=False): if not valid.any(): return result split_ids = split_ids[valid] - assert np.array_equal(split_ids, np.arange(len(split_ids))) + if not np.array_equal(split_ids, np.arange(len(split_ids))): + raise ValueError(f"Bad {split_ids=}") if debug: result["merge_labels"] = split_labels @@ -874,7 +1021,7 @@ def kmeans_split_unit(self, unit_id, debug=False): with self.labels_lock: self.labels[indices_full] = -1 self.labels[sp.indices[split_labels >= 0]] = unit_id - return + return result # else, tack new units onto the end # we need to lock up the labels array access here because labels.max() @@ -1090,7 +1237,7 @@ def unit_group_criterion( debug=False, ): """See if a single unit explains a group as far as AIC/BIC/MDL go.""" - assert fit_type in ("avg_preexisting", "refit_all") + assert fit_type in ("avg_preexisting", "refit_all", "refit_avg") unit_ids = torch.tensor(unit_ids) # pick spikes for likelihood computation @@ -1103,7 +1250,7 @@ def unit_group_criterion( spikes_core = self.data.spike_data(in_any, neighborhood="core") n = in_any.numel() - if fit_type == "refit_all": + if fit_type.startswith("refit"): units = [] subunit_logliks = spikes_core.features.new_full( (len(unit_ids), len(in_any)), -torch.inf @@ -1117,20 +1264,25 @@ def unit_group_criterion( features=spikes_extract, ) units.append(u) - _, subunit_logliks[i] = self.unit_log_likelihoods( - unit=u, spikes=spikes_core - ) + _, sll = self.unit_log_likelihoods(unit=u, spikes=spikes_core) + if sll is not None: + subunit_logliks[i] = sll subunit_log_props = F.softmax(subunit_logliks, dim=0).mean(1).log_() # loglik per spik full_loglik = torch.logsumexp( subunit_logliks.T + subunit_log_props, dim=1 ).mean() - unit = self.fit_unit(indices=in_any, features=spikes_extract) - likelihoods = None + + if fit_type == "refit_all": + unit = self.fit_unit(indices=in_any, features=spikes_extract) + elif fit_type == "refit_avg": + unit = average_units(units, props=subunit_log_props.exp()) + else: + assert False elif fit_type == "avg_preexisting": - unit = self.units[unit_ids[0]].avg_with( - *[self.units[u] for u in unit_ids[1:]] - ) + subunit_log_props = self.log_proportions[unit_ids] + units = [self[uid] for uid in unit_ids] + unit = average_units(units, F.softmax(subunit_log_props)) if debug: subunit_logliks = likelihoods[:, in_any][unit_ids] full_loglik = marginal_loglik( @@ -1144,7 +1296,10 @@ def unit_group_criterion( # extract likelihoods... no proportions! _, unit_logliks = self.unit_log_likelihoods(unit=unit, spikes=spikes_core) - unit_loglik = unit_logliks.mean() + if unit_logliks is not None: + unit_loglik = unit_logliks.mean() + else: + unit_loglik = np.nan # parameter counting... since we use marginal likelihoods, I'm restricting # the parameter counts to just the marginal set considered for each spike. @@ -1153,8 +1308,8 @@ def unit_group_criterion( unique_nids, inverse = torch.unique(nids, return_inverse=True) unique_chans = self.data.core_neighborhoods.neighborhoods[unique_nids] unique_k_merged = unit.n_params(unique_chans) - unique_k_full = [self.units[u].n_params(unique_chans) for u in unit_ids] - unique_k_full = torch.stack(unique_k_full, dim=1) + unique_k_full = [self[u].n_params(unique_chans) for u in unit_ids] + unique_k_full = torch.stack(unique_k_full, dim=1).sum(1) k_merged = unique_k_merged[inverse] k_full = unique_k_full[inverse] @@ -1222,6 +1377,14 @@ def get_fit_weights(self, unit_id, indices, likelihoods=None): # -- gizmos + @staticmethod + def normalize_key(ix): + if torch.is_tensor(ix): + ix = ix.numpy(force=True).item() + elif isinstance(ix, np.ndarray): + ix = ix.item() + return str(ix) + @property def rg(self): # thread-local rgs since they aren't safe @@ -1233,7 +1396,7 @@ def rg(self): def _relabel(self, old_labels, new_labels=None, flat=False): """Re-label units - !! This could invalidate self.units and props. + !! This could invalidate self._units and props. Suggested to only call .cleanup(), this is its low-level helper. @@ -1264,20 +1427,29 @@ def _relabel(self, old_labels, new_labels=None, flat=False): self.labels[torch.logical_not(kept)] = -1 self._stack = None - def stack_units(self, units=None, mean_only=True, use_cache=False): - if units is None: - units = self.units + def stack_units( + self, nu=None, units=None, ids=None, mean_only=True, use_cache=False + ): + if ids is not None: + assert units is not None + elif units is not None: + ids = np.arange(len(units)) + else: + ids, units = self.ids_and_units() + if nu is None: + nu = len(ids) + if use_cache and self._stack is not None: if mean_only or self._stack[1] is not None: return self._stack - nu, rank, nc = len(units), self.data.rank, self.data.n_channels + rank, nc = self.data.rank, self.data.n_channels - means = torch.zeros((nu, rank, nc), device=self.data.device) + means = torch.full((nu, rank, nc), torch.nan, device=self.data.device) covs = logdets = None if not mean_only: - covs = means.new_zeros((nu, rank * nc, rank * nc)) - logdets = means.new_zeros((nu,)) + covs = means.new_full((nu, rank * nc, rank * nc), torch.nan) + logdets = means.new_full((nu,), torch.nan) for j, unit in enumerate(units): means[j] = unit.mean @@ -1286,9 +1458,11 @@ def stack_units(self, units=None, mean_only=True, use_cache=False): logdets[j] = unit.logdet if use_cache: - self._stack = means, covs, logdets + self._stack = ids, means, covs, logdets + else: + self._stack = None - return means, covs, logdets + return ids, means, covs, logdets def schedule_annotations(self, unit_id, **annotations): if unit_id not in self.next_round_annotations: @@ -1370,9 +1544,9 @@ def __init__( self.scale_mean = scale_mean self.scale_alpha = float(prior_pseudocount) self.scale_beta = float(prior_pseudocount) / scale_mean - self.annotations = annotations self.ppca_rank = ppca_rank self.ppca_inner_em_iter = ppca_inner_em_iter + self.annotations = annotations @classmethod def from_features( @@ -1390,6 +1564,7 @@ def from_features( prior_pseudocount=10, ppca_inner_em_iter=1, scale_mean: float = 0.1, + **annotations, ): self = cls( rank=features.features.shape[1], @@ -1404,6 +1579,7 @@ def from_features( scale_mean=scale_mean, ppca_rank=ppca_rank, ppca_inner_em_iter=ppca_inner_em_iter, + **annotations, ) self.fit(features, weights, neighborhoods=neighborhoods) self = self.to(features.features.device) @@ -1585,6 +1761,42 @@ def kl_divergence(self, other_means, other_covs, other_logdets): tqdm_kw = dict(smoothing=0, mininterval=1.0 / 24.0) +def average_units(units, proportions): + ua = units[0] + assert ua.cov_kind == "zero" + new_unit = GaussianUnit( + rank=ua.rank, + n_channels=ua.n_channels, + noise=ua.noise, + mean_kind=ua.mean_kind, + cov_kind=ua.cov_kind, + prior_type=ua.prior_type, + channels_strategy=ua.channels_strategy, + channels_strategy_snr_min=ua.channels_strategy_snr_min, + prior_pseudocount=ua.prior_pseudocount, + scale_mean=ua.scale_mean, + ) + if ua.mean_kind == "zero": + return ua + + if ua.channels_strategy == "all": + channels = ua.channels + elif ua.channels_strategy == "snr": + avg_snr = sum(pi * u.snr for pi, u in zip(proportions, units)) + new.register_buffer("snr", avg_snr) + (channels,) = torch.nonzero( + avg_snr >= new.channels_strategy_snr_min, as_tuple=True + ) + else: + assert False + new.register_buffer("channels", channels) + + new_mean = sum(pi * u.mean for pi, u in zip(proportions, units)) + new.register_buffer("mean", new_mean) + + return new + + def to_full_probe(features, weights, n_channels, storage): """ Arguments diff --git a/src/dartsort/cluster/stable_features.py b/src/dartsort/cluster/stable_features.py index 6aac97c9..dd7df86d 100644 --- a/src/dartsort/cluster/stable_features.py +++ b/src/dartsort/cluster/stable_features.py @@ -442,9 +442,7 @@ def subset_neighborhoods(self, channels, min_coverage=1.0, add_to_overlaps=None) for j in covered_ids } n_spikes = self.popcounts[covered_ids].sum() - if add_to_overlaps is not None: - add_to_overlaps[covered_ids] += 1 - return neighborhood_info, n_spikes + return covered_ids, neighborhood_info, n_spikes def spike_neighborhoods(self, channels, spike_indices, min_coverage=1.0): """Like subset_neighborhoods, but for an already chosen collection of spikes diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 7d98e3ce..405f2d41 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -386,7 +386,7 @@ def conv_to_resid( svs_a = low_rank_templates_a.singular_values[template_indices_a] svs_b = low_rank_templates_b.singular_values[template_indices_b] active_a = torch.any(spatial_a > 0, dim=1).to(svs_a) - if ignore_empty_channels: + if ignore_empty_channels and low_rank_templates_b.spike_counts_by_channel is not None: active_b = low_rank_templates_b.spike_counts_by_channel[template_indices_b] active_b = active_b > 0 active_b = torch.from_numpy(active_b).to(svs_a) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index e0fe0f29..e7e02e30 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -1,19 +1,16 @@ +import dataclasses import numpy as np +from tqdm.auto import tqdm import warnings -from spikeinterface.core import BaseRecording, BaseRecordingSegment, Templates -from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.extractors import NumpySorting from spikeinterface.generation.drift_tools import InjectDriftingTemplatesRecording, DriftingTemplates, move_dense_templates -from spikeinterface.preprocessing.basepreprocessor import ( - BasePreprocessor, BasePreprocessorSegment) from probeinterface import Probe +from scipy.spatial import KDTree +from scipy.sparse import csgraph, coo_array from ..templates import TemplateData -from .analysis import DARTsortAnalysis from .data_util import DARTsortSorting from ..config import unshifted_raw_template_config -from ..templates import TemplateData - def get_drifty_hybrid_recording( @@ -82,7 +79,7 @@ def get_drifty_hybrid_recording( displacement_vectors=[disp], displacement_sampling_frequency=displacement_sampling_frequency, displacement_unit_factor=displacement_unit_factor, - amplitude_factor=amplitude_factor + amplitude_factor=amplitude_factor, ) rec.annotate(peak_channel=peak_channels.tolist()) return rec @@ -205,6 +202,7 @@ def refractory_poisson_spike_train( return spike_samples + def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs): """ Returns a spike train with variable firing rate using refractory_poisson_spike_train(). @@ -214,13 +212,14 @@ def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwa :param binsize_samples: number of samples per bin :param **kwargs: kwargs to feed to refractory_poisson_spike_train() """ - sp_tr = np.concatenate( - [ - refractory_poisson_spike_train(r, binsize_samples, **kwargs) + bins[i] if r > 0.1 else [] - for i, r in enumerate(rates) - ] - ) - return sp_tr + st = [] + for rate, bin in zip(rates, bins): + if rate < 0.1: + continue + binst = refractory_poisson_spike_train(rate, binsize_samples, **kwargs) + st.append(bin + binst) + st = np.concatenate(st) + return st def precompute_displaced_registered_templates( @@ -255,6 +254,66 @@ def precompute_displaced_registered_templates( return ret +def closest_clustering(gt_st, peel_st, geom=None, match_dt_ms=0.1, match_radius_um=0.0, p=2.0, M=50.0): + frames_per_ms = gt_st.sampling_frequency / 1000 + delta_frames = match_dt_ms * frames_per_ms + rescale = [delta_frames] + gt_pos = gt_st.times_samples[:, None] + peel_pos = peel_st.times_samples[:, None] + if match_radius_um: + rescale = rescale + (geom.shape[1] * [match_radius_um]) + gt_pos = np.c_[gt_pos, geom[gt_st.channels]] + peel_pos = np.c_[peel_pos, geom[peel_st.channels]] + else: + gt_pos = gt_pos.astype(float) + peel_pos = peel_pos.astype(float) + gt_pos /= rescale + peel_pos /= rescale + labels = greedy_match(gt_pos, peel_pos, dx=1.0 / frames_per_ms) + labels[labels >= 0] = gt_st.labels[labels[labels >= 0]] + + return dataclasses.replace(peel_st, labels=labels) + + +def greedy_match(gt_coords, test_coords, max_val=1.0, dx=1./30, workers=-1, p=2.0): + assignments = np.full(len(test_coords), -1) + gt_unmatched = np.ones(len(gt_coords), dtype=bool) + + for j, thresh in enumerate( + tqdm(np.arange(0.0, max_val + dx + 2e-5, dx), desc="match") + ): + test_unmatched = np.flatnonzero(assignments < 0) + if not test_unmatched.size: + break + test_kdtree = KDTree(test_coords[test_unmatched]) + gt_ix = np.flatnonzero(gt_unmatched) + d, i = test_kdtree.query( + gt_coords[gt_ix], + k=1, + distance_upper_bound=min(thresh, max_val), + workers=workers, + p=p, + ) + # handle multiple gt spikes getting matched to the same peel ix + thresh_matched = i < test_kdtree.n + _, ii = np.unique(i, return_index=True) + i = i[ii] + thresh_matched = thresh_matched[ii] + + gt_ix = gt_ix[ii] + gt_ix = gt_ix[thresh_matched] + i = i[thresh_matched] + assignments[test_unmatched[i]] = gt_ix + gt_unmatched[gt_ix] = False + + if not gt_unmatched.any(): + break + if thresh > max_val: + break + + return assignments + + def sorting_from_times_labels(times, labels, recording=None, sampling_frequency=None, determine_channels=True, template_config=unshifted_raw_template_config, n_jobs=0): channels = np.zeros_like(labels) if sampling_frequency is None: @@ -273,6 +332,14 @@ def sorting_from_times_labels(times, labels, recording=None, sampling_frequency= return sorting, td -def sorting_from_spikeinterface(sorting, recording=None, determine_channels=True, template_config=unshifted_raw_template_config, n_jobs=0): +def sorting_from_spikeinterface( + sorting, + recording=None, + determine_channels=True, + template_config=unshifted_raw_template_config, + n_jobs=0, +): sv = sorting.to_spike_vector() - return sorting_from_times_labels(sv['sample_index'], sv['unit_index'], sampling_frequency=sorting.sampling_frequency, recording=recording, determine_channels=determine_channels, template_config=template_config, n_jobs=n_jobs) + return sorting_from_times_labels( + sv['sample_index'], sv['unit_index'], sampling_frequency=sorting.sampling_frequency, recording=recording, determine_channels=determine_channels, template_config=template_config, n_jobs=n_jobs + ) diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 45011a3d..9b3b0c46 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from scipy.fftpack import next_fast_len from torch.fft import irfft, rfft +import warnings def fast_nanmedian(x, axis=-1): @@ -315,9 +316,12 @@ def nancov(x, weights=None, correction=1, nan_free=False, return_nobs=False, for cov = xtx / denom if force_posdef: - vals, vecs = torch.linalg.eigh(cov) - good = vals > 0 - cov = (vecs[:, good] * vals[good]) @ vecs[:, good].T + try: + vals, vecs = torch.linalg.eigh(cov) + good = vals > 0 + cov = (vecs[:, good] * vals[good]) @ vecs[:, good].T + except Exception as e: + warnings.warn(f"Error in nancov eigh: {e}") if return_nobs: return cov, nobs diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 9a5c39e5..3c978a51 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -1,4 +1,6 @@ from pathlib import Path +import warnings +import itertools import matplotlib.pyplot as plt import numpy as np @@ -33,7 +35,7 @@ def draw(self, panel, gmm, unit_id): class ISIHistogram(GMMPlot): kind = "small" width = 2 - height = 2 + height = 1.5 def __init__(self, bin_ms=0.1, max_ms=5): self.bin_ms = bin_ms @@ -68,13 +70,13 @@ def draw(self, panel, gmm, unit_id): s = ax.scatter(*xy[unique_ixs].T, c=counts, lw=0, cmap=self.cmap) plt.colorbar(s, ax=ax, shrink=0.3, label="chan count") ax.scatter( - *xy[gmm.units[unit_id].channels.numpy(force=True)].T, + *xy[gmm[unit_id].channels.numpy(force=True)].T, color="r", lw=1, fc="none", ) ax.scatter( - *xy[np.atleast_1d(gmm.units[unit_id].snr.argmax().numpy(force=True))].T, + *xy[np.atleast_1d(gmm[unit_id].snr.argmax().numpy(force=True))].T, color="g", lw=0, ) @@ -92,6 +94,30 @@ def draw(self, panel, gmm, unit_id): nspikes = (gmm.labels == unit_id).sum() msg += f"n spikes: {nspikes}\n" + if gmm[unit_id].annotations: + msg += 'annots:\n' + for k, v in gmm[unit_id].annotations.items(): + if torch.is_tensor(k): + k = k.numpy(force=True) + if k.size == 1: + k = k.item() + if torch.is_tensor(v): + v = v.numpy(force=True) + if isinstance(v, np.ndarray): + if not v.size: + v = "[]" + elif v.size == 1: + v = v.item() + elif v.ndim == 1: + vv = [str(v[0])] + for vvv in map(str, v[1:]): + if len(vv[-1]) > 16: + vv[-1] += "\n" + vv.append(vvv) + continue + vv[-1] += "," + vvv + v = "\n".join(vv) + msg += f"{k}:\n{v}" axis.text(0, 0, msg, fontsize=6.5) @@ -103,7 +129,7 @@ class MStep(GMMPlot): def __init__(self, n_waveforms_show=64, with_covs=True): self.with_covs = with_covs - self.height = 5 + 4 * with_covs + self.height = 4 + 4 * with_covs self.n_waveforms_show = n_waveforms_show def draw(self, panel, gmm, unit_id, axes=None): @@ -147,7 +173,7 @@ def draw(self, panel, gmm, unit_id, axes=None): n, r, c = feats.shape emp_mean = torch.nanmean(feats, dim=0) emp_mean = gmm.data.tpca.force_reconstruct(emp_mean.nan_to_num_()) - model_mean = gmm.units[unit_id].mean[:, chans] + model_mean = gmm[unit_id].mean[:, chans] model_mean = gmm.data.tpca.force_reconstruct(model_mean) geomplot( @@ -165,24 +191,23 @@ def draw(self, panel, gmm, unit_id, axes=None): return # covariance vis - feats = features_full[:, :, gmm.units[unit_id].channels] - model_mean = gmm.units[unit_id].mean[:, gmm.units[unit_id].channels] + feats = features_full[:, :, gmm[unit_id].channels] + model_mean = gmm[unit_id].mean[:, gmm[unit_id].channels] feats = feats - model_mean n, r, c = feats.shape emp_cov, nobs = spiketorch.nancov(feats.view(n, r * c), return_nobs=True) - denom = nobs + gmm.units[unit_id].prior_pseudocount + denom = nobs + gmm[unit_id].prior_pseudocount emp_cov = (nobs / denom) * emp_cov noise_cov = gmm.noise.marginal_covariance( - channels=gmm.units[unit_id].channels + channels=gmm[unit_id].channels ).to_dense() m = model_mean.reshape(-1) mmt = m[:, None] @ m[None, :] modelcov = ( - gmm.units[unit_id] - .marginal_covariance(channels=gmm.units[unit_id].channels) + gmm[unit_id] + .marginal_covariance(channels=gmm[unit_id].channels) .to_dense() ) - residual = emp_cov - modelcov covs = (emp_cov, noise_cov, mmt.abs(), mmt, modelcov, emp_cov - modelcov) # vmax = max(c.abs().max() for c in covs) names = ("regemp", "noise", "|temptempT|", "temptempT", "model", "resid") @@ -244,7 +269,11 @@ def draw(self, panel, gmm, unit_id): emp_eigs = torch.linalg.eigvalsh(emp_cov) noise_eigs = torch.linalg.eigvalsh(noise_cov) residual_eigs, residual_vecs = torch.linalg.eigh(residual) - model_residual_eigs = torch.linalg.eigvalsh(model_residual) + try: + model_residual_eigs = torch.linalg.eigvalsh(model_residual) + except Exception as e: + warnings.warn(f"Model residual. {e}") + model_residual_eigs = torch.zeros(model_residual.shape[0]) rank1 = (residual_vecs[:, -1:] * residual_eigs[-1:]) @ residual_vecs[:, -1:].T rank1_model = noise_cov + rank1 @@ -256,14 +285,14 @@ def draw(self, panel, gmm, unit_id): # ax_eig, ax_r2 = bot.subplots(ncols=2) ax_eig = bot.subplots() - vm = emp_cov.abs().max() - imk = dict(vmin=-vm, vmax=vm, cmap=plt.cm.seismic, interpolation="none") + # vm = 0.9 * emp_cov.abs().max() + imk = dict(cmap=plt.cm.seismic, interpolation="none") covs = dict( emp=emp_cov, noise=noise_cov, noise_resid=residual, - mmT_scaled=scale * mmT, + mmT=scale * mmT, mmT_model=model, mmT_resid=model_residual, rank1=rank1, @@ -279,19 +308,21 @@ def draw(self, panel, gmm, unit_id): rank1_resid=rank1_residual_eigs, ) for (name, cov), ax, color in zip(covs.items(), axes.flat, colors): - if name == "mmT_scaled": - vm = cov.abs().max() * 0.9 - mimk = dict( - vmin=-vm, vmax=vm, cmap=plt.cm.seismic, interpolation="none" - ) - else: - mimk = imk + vm = cov.abs().max() * 0.9 + mimk = imk | dict(vmax=vm, vmin=-vm) + # if name == "mmT": + # vm = cov.abs().max() * 0.9 + # mimk = dict( + # vmin=-vm, vmax=vm, cmap=plt.cm.seismic, interpolation="none" + # ) + # else: + # mimk = imk im = ax.imshow(cov, **mimk) cb = plt.colorbar(im, ax=ax, shrink=0.2) cb.outline.set_visible(False) title = name - if name.startswith("mmT_sc"): - title = title + f" (scale={scale:0.2f})" + if name == "mmT": + title = title + f" (scale={scale:.2f})" ax.set_title(title, color=color) if name in eigs: ax_eig.plot(eigs[name].flip(0), color=color, lw=1) @@ -459,7 +490,7 @@ def draw(self, panel, gmm, unit_id, split_info=None): ax.set_xticks([]) ax.axhline(0, color="k", lw=0.8) sns.despine(ax=ax, left=False, right=True, bottom=True, top=True) - mainchan = gmm.units[unit_id].snr.argmax() + mainchan = gmm[unit_id].snr.argmax() for subid, subunit in zip(split_ids, split_info["units"]): subm = subunit.mean[:, mainchan] subm = gmm.data.tpca._inverse_transform_in_probe(subm[None])[0] @@ -518,11 +549,11 @@ def __init__(self, n_neighbors=5): def draw(self, panel, gmm, unit_id): neighbors = gmm_helpers.get_neighbors(gmm, unit_id) - units = [gmm.units[u] for u in reversed(neighbors)] - labels = neighbors.numpy(force=True)[::-1] + units = [gmm[u] for u in reversed(neighbors)] + labels = neighbors[::-1] # means on core channels - chans = gmm.units[unit_id].snr.argmax() + chans = gmm[unit_id].snr.argmax() chans = torch.cdist(gmm.data.prgeom[chans[None]], gmm.data.prgeom) chans = chans.view(-1) (chans,) = torch.nonzero(chans <= gmm.data.core_radius, as_tuple=True) @@ -544,12 +575,12 @@ def __init__(self, n_neighbors=5, dist_vmax=1.0): def draw(self, panel, gmm, unit_id): neighbors = gmm_helpers.get_neighbors(gmm, unit_id) distances = gmm.distances( - units=[gmm.units[u] for u in neighbors], show_progress=False + units=[gmm[u] for u in neighbors], show_progress=False ) ax = analysis_plots.distance_matrix_dendro( panel, distances, - unit_ids=neighbors.numpy(force=True), + unit_ids=neighbors, dendrogram_linkage=None, show_unit_labels=True, vmax=self.dist_vmax, @@ -561,9 +592,9 @@ def draw(self, panel, gmm, unit_id): class NeighborBimodalities(GMMPlot): - kind = "merge" + kind = "bim" width = 4 - height = 8 + height = 9 def __init__(self, n_neighbors=5): self.n_neighbors = n_neighbors @@ -583,7 +614,7 @@ def draw(self, panel, gmm, unit_id): ) kept = labels >= 0 labels_ = np.full_like(labels, -1) - labels_[kept] = neighbors[labels[kept]].numpy(force=True) + labels_[kept] = neighbors[labels[kept]] labels = labels_ others = neighbors[1:] @@ -606,7 +637,10 @@ def draw(self, panel, gmm, unit_id): bimod_ax.set_title("bimodality computation", fontsize="small") if "in_pair_kept" not in bimod_info: - scatter_ax.text(0, 0, f"too few spikes") + scatter_ax.text( + 0.5, 0.5, f"too few spikes", transform=scatter_ax.transAxes, ha='center', va='center' + ) + continue else: c = np.atleast_2d(glasbey1024[labels[bimod_info["in_pair_kept"]]]) scatter_ax.scatter(bimod_info["xi"], bimod_info["xj"], s=3, lw=0, c=c) @@ -615,7 +649,7 @@ def draw(self, panel, gmm, unit_id): if "samples" not in bimod_info: bimod_ax.text( - 0, 0, f"too-small kept prop {bimod_info['keep_prop']:.2f}" + 0.5, 0.5, f"too-small\nkept prop {bimod_info['keep_prop']:.2f}", transform=bimod_ax.transAxes, ha='center', va='center' ) bimod_ax.axis("off") continue @@ -655,30 +689,56 @@ def draw(self, panel, gmm, unit_id): class NeighborInfoCriteria(GMMPlot): - kind = "merge" - width = 3 - height = 8 + kind = "bim" + width = 4 + height = 9 - def __init__(self, n_neighbors=5): + def __init__(self, n_neighbors=5, fit_by_avg=False): self.n_neighbors = n_neighbors + self.fit_by_avg = fit_by_avg def draw(self, panel, gmm, unit_id): neighbors = gmm_helpers.get_neighbors(gmm, unit_id) assert neighbors[0] == unit_id others = neighbors[1:] axes = panel.subplots(nrows=len(others), ncols=1) - histkw = dict(density=True, histtype="step", bins=128) - cstr = "{aic_full=:0.2f} {aic_merged=:0.2f} {bic_full=:0.2f} {bic_merged=:0.2f}" + histkw = dict(density=True, histtype="step", bins=128, log=True) + astr = "AICfull/merged: {aic_full:0.1f} / {aic_merged:0.1f}" + bstr = "BICfull/merged: {bic_full:0.1f} / {bic_merged:0.1f}" + lstr = "LLfull/merged: {full_loglik:0.1f} / {unit_loglik:0.1f}" + cstr = f"{astr}\n{bstr}\n{lstr}\n" + bbox = dict(facecolor='w', alpha=0.5, edgecolor="none") for ax, other_id in zip(axes, others): uids = [unit_id, other_id] - sns.despine(ax=ax, left=True, right=True, top=True) res = gmm.unit_group_criterion(uids, gmm.log_liks, debug=True) - sll = res["subunit_logliks"].tocsr() + sll = res["subunit_logliks"] + if not torch.is_tensor(sll): + sll = sll.tocsr() for row, uid in enumerate(uids): - ax.hist(sll[[row]].data, color=glasbey1024[uid], **histkw) - s = f"other={other_id} {cstr.format(res)}" - ax.set_title(s, fontsize="small") + if not torch.is_tensor(sll): + rowsll = sll[[row]].data + else: + rowsll = sll[row] + rowsll = rowsll[torch.isfinite(rowsll)] + if rowsll.numel(): + ax.hist(rowsll, color=glasbey1024[uid], **histkw) + ull = res["unit_logliks"] + if ull is not None: + ax.hist(ull[torch.isfinite(ull)], color="k", **histkw) + s = f"other={other_id}\n" + cstr.format_map(res) + aic_merge = res['aic_merged'] < res['aic_full'] + bic_merge = res['bic_merged'] < res['bic_full'] + ll_merge = res['unit_loglik'] > res['full_loglik'] + aicdif = res['aic_full'] - res['aic_merged'] + bicdif = res['bic_full'] - res['bic_merged'] + lldif = res['full_loglik'] - res['unit_loglik'] + s += f"aic: {aicdif:0.1f}, " + ("merge!" if aic_merge else "nope.") + "\n" + s += f"bic: {bicdif:0.1f}, " + ("merge!" if bic_merge else "nope.") + "\n" + s += f"ll: {lldif:0.1f}, " + ("merge!" if ll_merge else "nope.") + ax.text(0.05, 0.95, s, transform=ax.transAxes, va="top", bbox=bbox) ax.set_xlabel("log lik") + sns.despine(ax=ax, left=True, right=True, top=True) + # ax.set_yticks([]) # -- main api @@ -686,14 +746,17 @@ def draw(self, panel, gmm, unit_id): default_gmm_plots = ( TextInfo(), ISIHistogram(), + ISIHistogram(bin_ms=1, max_ms=50), ChansHeatmap(), - MStep(), + MStep(with_covs=False), + CovarianceResidual(), Likelihoods(), Amplitudes(), KMeansSplit(), NeighborMeans(), NeighborDistances(), NeighborBimodalities(), + NeighborInfoCriteria(), ) @@ -702,7 +765,7 @@ def make_unit_gmm_summary( unit_id, plots=default_gmm_plots, max_height=9, - figsize=(14, 11), + figsize=(15, 11), hspace=0.1, figure=None, **other_global_params, @@ -731,7 +794,7 @@ def make_all_gmm_summaries( save_folder, plots=default_gmm_plots, max_height=9, - figsize=(14, 11), + figsize=(15, 11), hspace=0.1, dpi=200, image_ext="png", @@ -746,7 +809,7 @@ def make_all_gmm_summaries( ): save_folder = Path(save_folder) if unit_ids is None: - unit_ids = gmm.unit_ids().numpy(force=True) + unit_ids = gmm.unit_ids() if n_units is not None and n_units < len(unit_ids): rg = np.random.default_rng(seed) unit_ids = rg.choice(unit_ids, size=n_units, replace=False) diff --git a/src/dartsort/vis/gmm_helpers.py b/src/dartsort/vis/gmm_helpers.py index adbe25c0..8f668ec2 100644 --- a/src/dartsort/vis/gmm_helpers.py +++ b/src/dartsort/vis/gmm_helpers.py @@ -7,10 +7,10 @@ def get_neighbors(gmm, unit_id, n_neighbors=5): - means, covs, logdets = gmm.stack_units(use_cache=True) - dists = gmm.units[unit_id].divergence(means, covs, logdets, kind=gmm.distance_metric) + ids, means, covs, logdets = gmm.stack_units(use_cache=True) + dists = gmm[unit_id].divergence(means, covs, logdets, kind=gmm.distance_metric) dists = dists.view(-1) - order = torch.argsort(dists) + order = ids[torch.argsort(dists)] assert order[0] == unit_id return order[:n_neighbors + 1] @@ -60,7 +60,6 @@ def plot_means(panel, prgeom, tpca, chans, units, labels, title="nearest neighbo means.append(tpca.force_reconstruct(mean).numpy(force=True)) colors = glasbey1024[labels] - print(f"{len(means)=} {labels.shape=} {colors.shape=}") geomplot( np.stack(means, axis=0), channels=chans[None] diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index 310c20ea..5325d941 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -51,7 +51,8 @@ def draw(self, panel, sorting_analysis, unit_id): axis.axis("off") msg = f"unit {unit_id}\n" - msg += f"feature source: {sorting_analysis.hdf5_path.name}\n" + if getattr(sorting_analysis, 'hdf5_path', None): + msg += f"feature source: {sorting_analysis.hdf5_path.name}\n" nspikes = sorting_analysis.spike_counts[ sorting_analysis.unit_ids == unit_id