From 8fdfc303a031e14cd9560bf7d0f2c1e1b620427c Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Fri, 15 Nov 2024 10:21:25 -0800
Subject: [PATCH] Mid-update...

---
 src/dartsort/cluster/gaussian_mixture.py | 204 ++++++++++++++++-------
 src/dartsort/cluster/merge.py            |   8 +-
 src/dartsort/cluster/stable_features.py  |  19 ++-
 3 files changed, 159 insertions(+), 72 deletions(-)

diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py
index f3f82166..f7f15a4e 100644
--- a/src/dartsort/cluster/gaussian_mixture.py
+++ b/src/dartsort/cluster/gaussian_mixture.py
@@ -1,5 +1,4 @@
 import threading
-import warnings
 from dataclasses import replace
 
 import numba
@@ -8,7 +7,6 @@
 import torch.nn.functional as F
 from joblib import Parallel, delayed
 from linear_operator import operators
-from scipy.optimize import root_scalar
 from scipy.sparse import coo_array, csc_array
 from scipy.special import logsumexp
 from tqdm.auto import tqdm, trange
@@ -150,6 +148,7 @@ def __init__(
         self.labels_lock = threading.Lock()
         self.lock = threading.Lock()
         self.storage = threading.local()
+        self.next_round_annotations = {}
 
     # -- headliners
 
@@ -179,11 +178,10 @@ def em(self, n_iter=None, show_progress=True, final_e_step=True):
             self.cleanup(min_count=1)
             means, *_ = self.stack_units(mean_only=True)
 
-            # no need to clean units since they'll be overwritten immediately
            reas_count, log_liks, spike_logliks = self.e_step(
                 show_progress=step_progress
             )
-            logpx = logsumexp(spike_logliks)
+            meanlogpx = spike_logliks.mean()
 
             # M step: fit units based on responsibilities
             max_adif = self.m_step(
@@ -198,7 +196,7 @@ def em(self, n_iter=None, show_progress=True, final_e_step=True):
             reas_prop = reas_count / self.data.n_spikes
             rpct = f"{100 * reas_prop:.1f}"
             adif = f"{max_adif:.2f}"
-            msg = f"EM[{nu=},out={opct}%,reas={rpct}%,dmu={adif},logpx={logpx:g}]"
+            msg = f"EM[{nu=},out={opct}%,reas={rpct}%,dmu={adif},meanlogpx={meanlogpx:g}]"
             its.set_description(msg)
 
             if reas_prop < self.em_converged_prop:
@@ -213,21 +211,23 @@ def em(self, n_iter=None, show_progress=True, final_e_step=True):
             reas_count, log_liks, spike_logliks = self.e_step(clean_units=True)
         return log_liks
 
-    def e_step(self, show_progress=False, clean_units=True):
+    def e_step(self, show_progress=False):
         # E step: get responsibilities and update hard assignments
         log_liks = self.log_likelihoods(show_progress=show_progress)
         # replace log_liks by csc
         reas_count, spike_logliks, log_liks = self.reassign(log_liks)
 
         # garbage collection -- get rid of low count labels
-        log_liks = self.cleanup(log_liks, clean_units=clean_units)
+        log_liks = self.cleanup(log_liks)
 
         return reas_count, log_liks, spike_logliks
 
     def m_step(self, likelihoods=None, show_progress=False, prev_means=None):
         """Beware that this flattens the labels."""
-        del self.units[:]
+        needs_append = not self.units
         unit_ids = self.unit_ids()
+        if not needs_append:
+            assert unit_ids.max() < len(self.units)
 
         if self.use_proportions and likelihoods is not None:
             self.update_proportions(likelihoods)
@@ -236,14 +236,19 @@ def m_step(self, likelihoods=None, show_progress=False, prev_means=None):
 
         pool = Parallel(self.n_threads, backend="threading", return_as="generator")
         results = pool(
-            delayed(self.fit_unit)(j, likelihoods=likelihoods) for j in unit_ids
+            delayed(self.fit_unit)(
+                j, likelihoods=likelihoods, **self.next_round_annotations.get(j, {})
+            )
+            for j in unit_ids
         )
         if show_progress:
             results = tqdm(
                 results, desc="M step", unit="unit", total=len(unit_ids), **tqdm_kw
             )
-        for unit in results:
-            self.units.append(unit)
+        for j, unit in zip(unit_ids, results):
+            if needs_append:
+                assert unit.annotations["unit_id"] == j
+                self.units.append(unit)
         if self.log_proportions is not None:
             # this is the index of the noise unit. it's got to be larger than
             # the largest unit index
@@ -353,11 +358,18 @@ def update_proportions(self, log_liks):
             sample.sort()
             log_liks = log_liks[:, sample].tocoo()
         log_liks = coo_to_torch(log_liks, torch.float, copy_data=True)
+
+        # log proportions are added to the likelihoods
         if self.log_proportions is not None:
             log_props_vec = self.log_proportions.cpu()[log_liks.indices()[0]]
             log_liks.values().add_(log_props_vec)
-        log_resps = torch.sparse.softmax(log_liks, dim=0)
+
+        # softmax over units, logged
+        log_resps = torch.sparse.log_softmax(log_liks, dim=0)
         log_resps = coo_to_scipy(log_resps).tocsr()
+
+        # now, we want the mean of the softmaxes over the spike dim (dim=1)
+        # since we have log softmaxes, that means need to exp, then take mean, then log
         log_props = logmeanexp(log_resps)
         self.log_proportions = torch.asarray(
             log_props, dtype=torch.float, device=self.data.device
@@ -635,9 +647,9 @@ def random_spike_data(
     ):
         if indices is None:
             indices_full, indices = self.random_indices(
-                unit_id,
-                unit_ids,
-                max_size,
+                unit_id=unit_id,
+                unit_ids=unit_ids,
+                max_size=max_size,
                 indices_full=indices_full,
                 return_full_indices=True,
             )
@@ -674,12 +686,18 @@ def fit_unit(
         if verbose and weights is not None:
             print(f"{weights.sum()=} {weights.min()=} {weights.max()=}")
         unit_args = self.unit_args | unit_args
-        unit = GaussianUnit.from_features(
-            features,
-            weights,
-            neighborhoods=self.data.extract_neighborhoods,
-            **unit_args,
-        )
+        if unit_id is not None and len(self.units) > unit_id:
+            unit = self.units[unit_id]
+            assert unit.annotations.get("unit_id", unit_id) == unit_id
+            unit.fit(features, weights, neighborhoods=self.data.extract_neighborhoods)
+        else:
+            unit = GaussianUnit.from_features(
+                features,
+                weights,
+                neighborhoods=self.data.extract_neighborhoods,
+                unit_id=unit_id,
+                **unit_args,
+            )
         return unit
 
     def unit_log_likelihoods(
@@ -687,6 +705,7 @@ def unit_log_likelihoods(
         unit_id=None,
         unit=None,
         spike_indices=None,
+        spikes=None,
         neighbs=None,
         ns=None,
         show_progress=False,
@@ -709,6 +728,8 @@ def unit_log_likelihoods(
             core_channels = torch.arange(self.data.n_channels)
         else:
             core_channels = unit.channels
+        if spikes is not None:
+            spike_indices = spikes.indices
         inds_already = spike_indices is not None
         if neighbs is None or ns is None:
             if inds_already:
@@ -740,7 +761,9 @@ def unit_log_likelihoods(
         )
 
         for neighb_id, (neighb_chans, neighb_member_ixs) in jobs:
-            if inds_already:
+            if spikes is not None:
+                sp = spikes[neighb_member_ixs]
+            elif inds_already:
                 sp = self.data.spike_data(
                     spike_indices[neighb_member_ixs],
                     with_channels=False,
@@ -1047,60 +1070,94 @@ def unit_pair_bimodality(
     def unit_group_criterion(
         self,
         unit_ids,
-        likelihoods,
+        likelihoods=None,
         spikes_per_subunit=2048,
+        fit_type="refit_all",
         debug=False,
     ):
         """See if a single unit explains a group as far as AIC/BIC/MDL go."""
+        assert fit_type in ("avg_preexisting", "refit_all")
+        unit_ids = torch.tensor(unit_ids)
+
+        # pick spikes for likelihood computation
         in_subunits = [
             self.random_indices(u, max_size=spikes_per_subunit) for u in unit_ids
         ]
         in_any = torch.cat(in_subunits)
         in_any, in_order = torch.sort(in_any)
+        spikes_extract = self.data.spike_data(in_any)
+        spikes_core = self.data.spike_data(in_any, neighborhood="core")
+        n = in_any.numel()
 
-        # pick some of those spikes to fit for the new unit
-        sp = self.random_spike_data(indices_full=in_any)
-
-        # get something sorta like fitting weights for the new unit
-        weights = self.get_fit_weights(unit_ids[0], sp.indices, likelihoods)
-        for u in unit_ids[1:]:
-            uweights = self.get_fit_weights(u, sp.indices, likelihoods)
-            weights = torch.logaddexp(weights, uweights, out=weights)
-
-        # fit new unit
-        unit = self.fit_unit(features=sp, weights=weights)
+        if fit_type == "refit_all":
+            units = []
+            subunit_logliks = spikes_core.features.new_full((len(unit_ids), len(in_any)), -torch.inf)
+            full_loglik = 0.0
+            for i, k in enumerate(unit_ids):
+                u = self.fit_unit(unit_id=k, indices=in_any, likelihoods=likelihoods, features=spikes_extract)
+                units.append(u)
+                _, subunit_logliks[i] = self.unit_log_likelihoods(unit=u, spikes=spikes_core)
+            subunit_log_props = F.softmax(subunit_logliks, dim=0).mean(1).log_()
+            # loglik per spike
+            full_loglik = torch.logsumexp(subunit_logliks.T + subunit_log_props, dim=1).mean()
+            unit = self.fit_unit(indices=in_any, features=spikes_extract)
+            likelihoods = None
+        elif fit_type == "avg_preexisting":
+            unit = self.units[unit_ids[0]].avg_with(*[self.units[u] for u in unit_ids[1:]])
+            if debug:
+                subunit_logliks = likelihoods[:, in_any][unit_ids]
+            full_loglik = marginal_loglik(
+                indices=in_any.numpy(force=True),
+                log_proportions=self.log_proportions,
+                log_likelihoods=likelihoods,
+                unit_ids=unit_ids,
+            )
+        else:
+            assert False
 
         # extract likelihoods... no proportions!
-        full_loglik = marginal_loglik(
-            in_any.numpy(force=True), self.log_proportions, likelihoods, unit_ids
-        )
-        unit_logliks = self.unit_log_likelihoods(unit=unit, spike_indices=in_any)
-        unit_loglik = torch.logsumexp(unit_logliks)
+        _, unit_logliks = self.unit_log_likelihoods(unit=unit, spikes=spikes_core)
+        unit_loglik = unit_logliks.mean()
+
+        # parameter counting... since we use marginal likelihoods, I'm restricting
+        # the parameter counts to just the marginal set considered for each spike.
+        # then, aic and bic formulas are changed slightly below to match.
+        nids = self.data.core_neighborhoods.neighborhood_ids[in_any]
+        unique_nids, inverse = torch.unique(nids, return_inverse=True)
+        unique_chans = self.data.core_neighborhoods.neighborhoods[unique_nids]
+        unique_k_merged = unit.n_params(unique_chans)
+        unique_k_full = [self.units[u].n_params(unique_chans) for u in unit_ids]
+        unique_k_full = torch.stack(unique_k_full, dim=1)
+        k_merged = unique_k_merged[inverse]
+        k_full = unique_k_full[inverse]
+
+        # for aic: k is avg
+        k_merged_avg = k_merged.sum() / n
+        k_full_avg = k_full.sum() / n
+        if self.use_proportions:
+            k_full_avg += len(unit_ids) - 1
 
         # compute some criteria
-        n = in_any.numel()
-        params_unit = unit.n_params()
-        params_full = sum(self.units[u].n_params() for u in unit_ids)
-        if self.use_proportions:
-            params_full += len(unit_ids) - 1
-        aic_full = 2 * params_full - 2 * full_loglik
-        aic_unit = 2 * params_unit - 2 * unit_loglik
-        bic_full = params_full * np.log(n) - 2 * full_loglik
-        bic_unit = params_unit * np.log(n) - 2 * unit_loglik
-        # mdl is equivalent to bic here
+        # actually computing AIC/BIC per example (divide by N)
+        # logliks here are already mean log liks.
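+        # i.e., per-spike criterion = penalty(k) / n - 2 * (mean log lik), with
+        # penalty(k) = 2 * k for AIC and k * log(n) for BIC, where k here is the
+        # per-spike average parameter count computed above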
+        aic_full = (2 * k_full_avg) / n - 2 * full_loglik
+        aic_merged = (2 * k_merged_avg) / n - 2 * unit_loglik
+        bic_full = (k_full_avg * np.log(n)) / n - 2 * full_loglik
+        bic_merged = (k_merged_avg * np.log(n)) / n - 2 * unit_loglik
         res = dict(
-            aic_full=aic_full, aic_unit=aic_unit, bic_full=bic_full, bic_unit=bic_unit
+            aic_full=aic_full,
+            aic_merged=aic_merged,
+            bic_full=bic_full,
+            bic_merged=bic_merged,
         )
         if debug:
             debug_info = dict(
                 full_loglik=full_loglik,
                 unit_loglik=unit_loglik,
                 unit_logliks=unit_logliks,
-                subunit_logliks=likelihoods[:, in_any][unit_ids],
+                subunit_logliks=subunit_logliks,
                 indices=in_any,
                 unit=unit,
-                sp=sp,
             )
             res.update(debug_info)
         return res
@@ -1261,6 +1318,8 @@ def __init__(
         channels_strategy_snr_min=50.0,
         prior_pseudocount=10,
         scale_mean: float = 0.1,
+        mean=None,
+        **annotations,
     ):
         super().__init__()
         self.rank = rank
@@ -1275,6 +1334,9 @@ def __init__(
         self.scale_mean = scale_mean
         self.scale_alpha = float(prior_pseudocount)
         self.scale_beta = float(prior_pseudocount) / scale_mean
+        self.annotations = annotations
+        if mean is not None:
+            self.register_buffer("mean", mean)
 
     @classmethod
     def from_features(
@@ -1307,17 +1369,25 @@ def from_features(
         self = self.to(features.features.device)
         return self
 
-    def n_params(self, on_channels=True):
-        p = 0
-        nc = self.channels.numel() if on_channels else self.n_channels
+    def avg_with(self, *others):
+        new = self.__class__(
+            mean_kind=self.mean_kind,
+            cov_kind=self.cov_kind,
+            rank=self.rank,
+            n_channels=self.n_channels,
+            noise=self.noise,
+        )
+        new.register_buffer("mean", (self.mean + sum(o.mean for o in others)) / (1 + len(others)))
+        new.register_buffer("channels", torch.cat([self.channels, *[o.channels for o in others]]).unique())
+        assert self.cov_kind == "zero"
+        return new
 
-        # my mean
+    def n_params(self, channels=None, on_channels=True):
+        p = channels.new_zeros(len(channels))
         if self.mean_kind == "full":
-            p += nc * self.rank
-
+            p += self.rank * torch.isin(channels, self.channels).sum(1)
         # my cov
         assert self.cov_kind == "zero"
-
         return p
 
     def fit(self, features: SpikeFeatures, weights: torch.Tensor, neighborhoods=None):
@@ -1656,7 +1726,7 @@ def coo_to_scipy(coo_tensor):
     return coo_array((data, coords), shape=coo_tensor.shape)
 
 
-def marginal_loglik(indices, log_proportions, log_likelihoods, unit_ids=None):
+def marginal_loglik(indices, log_proportions, log_likelihoods, unit_ids=None, reduce="mean"):
     if unit_ids is not None:
         # renormalize log props
         log_proportions = log_proportions[unit_ids]
@@ -1667,10 +1737,19 @@ def marginal_loglik(indices, log_proportions, log_likelihoods, unit_ids=None):
         log_likelihoods = log_likelihoods[unit_ids]
 
     # .indices == row inds for a csc_array
+    # since log_proportions and log_likelihoods both were sliced by the same
+    # unit_ids, the row indices match.
     props = log_proportions[log_likelihoods.indices]
     log_liks = props + log_likelihoods.data
 
-    return logsumexp(log_liks)
+    if reduce == "mean":
+        ll = log_liks.mean()
+    elif reduce == "sum":
+        ll = log_liks.sum()
+    else:
+        assert False
+
+    return ll
 
 
 def loglik_reassign(
@@ -1689,7 +1768,7 @@ def loglik_reassign(
 
 
 def logmeanexp(x_csr):
-    """Log of mean of exp of x_csr's rows (mean over columns)
+    """Log of mean of exp in x_csr's rows (mean over columns)
 
     Sparse zeros are treated as negative infinities.
     """
@@ -1698,6 +1777,7 @@ def logmeanexp(x_csr):
     for j in range(x_csr.shape[0]):
         row = x_csr[[j]]
         # missing vals in the row are -inf, exps are 0s, so ignore in sum
+        # dividing by N is subtracting log N
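+        # e.g. a row with two stored values (a, b) among its N columns:
+        # log of its mean exp = log((exp(a) + exp(b)) / N) = logsumexp((a, b)) - log(N)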
""" @@ -1698,6 +1777,7 @@ def logmeanexp(x_csr): for j in range(x_csr.shape[0]): row = x_csr[[j]] # missing vals in the row are -inf, exps are 0s, so ignore in sum + # dividing by N is subtracting log N log_mean_exp[j] = logsumexp(row.data) - log_N return log_mean_exp diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index 11f32704..cafe71c4 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -633,9 +633,11 @@ def combine_templates(template_data_a, template_data_b): spike_counts = np.concatenate( (template_data_a.spike_counts, template_data_b.spike_counts) ) - spike_counts_by_channel = np.concatenate( - (template_data_a.spike_counts_by_channel, template_data_b.spike_counts_by_channel) - ) + spike_counts_by_channel = None + if template_data_a.spike_counts_by_channel is not None: + spike_counts_by_channel = np.concatenate( + (template_data_a.spike_counts_by_channel, template_data_b.spike_counts_by_channel) + ) template_data = TemplateData( templates=templates, unit_ids=unit_ids, diff --git a/src/dartsort/cluster/stable_features.py b/src/dartsort/cluster/stable_features.py index 0c2c7b30..9bbf7a4a 100644 --- a/src/dartsort/cluster/stable_features.py +++ b/src/dartsort/cluster/stable_features.py @@ -97,6 +97,7 @@ def from_sorting( core_radius=35.0, subsampling_rg=0, max_n_spikes=np.inf, + discard_triaged=False, interpolation_sigma=20.0, interpolation_method="kriging", motion_depth_mode="channel", @@ -111,18 +112,22 @@ def from_sorting( device = torch.device(device) # which spikes to keep? - if len(sorting) > max_n_spikes: - rg = np.random.default_rng(subsampling_rg) - kept_indices = rg.choice(len(sorting), size=max_n_spikes, replace=False) - kept_indices.sort() - keep = np.zeros(len(sorting), dtype=bool) - keep[kept_indices] = 1 - keep_select = kept_indices + if discard_triaged: + keep = sorting.labels >= 0 + keep_select = kept_indices = np.flatnonzero(keep) else: keep = np.ones(len(sorting), dtype=bool) kept_indices = np.arange(len(sorting)) keep_select = slice(None) + if kept_indices.size > max_n_spikes: + rg = np.random.default_rng(subsampling_rg) + kept_kept = rg.choice(kept_indices.size, size=max_n_spikes, replace=False) + kept_kept.sort() + keep[kept_indices] = 0 + keep[kept_indices[kept_kept]] = 1 + keep_select = kept_indices = kept_indices[kept_kept] + # load information not stored directly on the sorting with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5: geom = h5["geom"][:] extract_channel_index = h5["channel_index"][:]