From 68678e70b9bb18eac303593600bc33f5b613efb5 Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Tue, 27 Aug 2024 10:58:41 -0400
Subject: [PATCH 1/8] Channel logic

---
 src/dartsort/cluster/postprocess.py    |  89 +++++++++
 src/dartsort/peel/matching.py          |  26 ++-
 src/dartsort/peel/peel_base.py         |  10 +-
 src/dartsort/transform/temporal_pca.py |   1 +
 src/dartsort/util/spiketorch.py        |   4 +
 src/dartsort/vis/gmm.py                | 247 ++++++++++++++++++++----
 src/dartsort/vis/unit.py               |  26 ++-
 7 files changed, 336 insertions(+), 67 deletions(-)

diff --git a/src/dartsort/cluster/postprocess.py b/src/dartsort/cluster/postprocess.py
index 00e79a21..b6c2cfa1 100644
--- a/src/dartsort/cluster/postprocess.py
+++ b/src/dartsort/cluster/postprocess.py
@@ -1,6 +1,8 @@
 from dataclasses import replace
 
 import numpy as np
+import torch
+import torch.nn.functional as F
 
 from .. import config
 from ..templates import TemplateData
@@ -69,3 +71,90 @@ def realign_and_chuck_noisy_template_units(
     )
 
     return new_sorting, new_template_data
+
+
+def template_collision_scores(
+    recording,
+    template_data,
+    svd_compression_rank=20,
+    temporal_upsampling_factor=8,
+    min_channel_amplitude=0.0,
+    amplitude_scaling_variance=0.0,
+    amplitude_scaling_boundary=0.5,
+    trough_offset_samples=42,
+    device=None,
+    max_n_colliding=5,
+    threshold=None,
+    save_folder=None,
+):
+    from ..peel.matching import ObjectiveUpdateTemplateMatchingPeeler
+
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device)
+    if threshold is None or threshold < 1:
+        # None means use the default factor; a threshold below 1 is
+        # interpreted as a fraction of the smallest template's energy
+        factor = 0.9
+        if threshold is not None and threshold < 1:
+            factor = threshold
+        threshold = factor * np.square(template_data.templates).sum((1, 2)).min()
+        print(f"Using {threshold=:0.3f} for decollision")
+
+    matcher = ObjectiveUpdateTemplateMatchingPeeler(
+        recording,
+        template_data,
+        channel_index=None,
+        featurization_pipeline=None,
+        motion_est=None,
+        svd_compression_rank=svd_compression_rank,
+        temporal_upsampling_factor=temporal_upsampling_factor,
+        min_channel_amplitude=min_channel_amplitude,
+        refractory_radius_frames=10,
+        amplitude_scaling_variance=amplitude_scaling_variance,
+        amplitude_scaling_boundary=amplitude_scaling_boundary,
+        conv_ignore_threshold=0.0,
+        coarse_approx_error_threshold=0.0,
+        trough_offset_samples=template_data.trough_offset_samples,
+        threshold=threshold,
+        # chunk_length_samples=30_000,
+        # n_chunks_fit=40,
+        # max_waveforms_fit=50_000,
+        # n_waveforms_fit=20_000,
+        # fit_subsampling_random_state=0,
+        # fit_sampling="random",
+        max_iter=max_n_colliding,
+        dtype=torch.float,
+    )
+    matcher.to(device)
+    save_folder.mkdir(exist_ok=True)
+    matcher.precompute_peeling_data(save_folder)
+    matcher.to(device)
+
+    # match each unit's padded template against the *other* units' templates;
+    # scores[j] is the leftover residual energy as a fraction of the template's
+    # own energy, so small scores mean unit j is well explained by other units
+    n = len(template_data.templates)
+    scores = np.zeros(n)
+    matches = []
+    unit_mask = torch.arange(n, device=device)
+    for j in range(n):
+        mask = template_data.spike_counts_by_channel[j] > 0
+        template = template_data.templates[j][:, mask]
+
+        mask = torch.from_numpy(mask)
+        compressed_template_data = matcher.templates_at_time(0.0, spatial_mask=mask)
+        traces = F.pad(
+            torch.from_numpy(template).to(device),
+            (0, 0, *(2 * [template_data.spike_length_samples])),
+        )
+        res = matcher.match_chunk(
+            traces,
+            compressed_template_data,
+            trough_offset_samples=42,
+            left_margin=0,
+            right_margin=0,
+            threshold=30,
+            return_collisioncleaned_waveforms=False,
+            return_residual=True,
+            unit_mask=unit_mask != j,
+        )
+        resid = res["residual"][template_data.spike_length_samples:2*template_data.spike_length_samples]
+        scores[j] = resid.square().sum() / 
traces.square().sum() + matches.append(res["labels"].numpy(force=True)) + return scores, matches \ No newline at end of file diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 5d91b3d7..6dba2593 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -447,7 +447,7 @@ def peel_chunk( return match_results - def templates_at_time(self, t_s): + def templates_at_time(self, t_s, spatial_mask=None): """Handle drift -- grab the right spatial neighborhoods.""" pconvdb = self.pairwise_conv_db pitch_shifts_a = pitch_shifts_b = None @@ -457,6 +457,7 @@ def templates_at_time(self, t_s): ): pconvdb.to(self.objective_spatial_components.device) if self.is_drifting: + assert spatial_mask is None pitch_shifts_b, cur_spatial = template_util.templates_at_time( t_s, self.spatial_components, @@ -505,6 +506,9 @@ def templates_at_time(self, t_s): else: cur_spatial = self.spatial_components cur_obj_spatial = self.objective_spatial_components + if spatial_mask is not None: + cur_spatial = cur_spatial[:, :, spatial_mask] + cur_obj_spatial = cur_obj_spatial[:, :, spatial_mask] max_channels = self.registered_template_ampvecs.argmax(1) # if not pconvdb._is_torch: @@ -542,8 +546,10 @@ def match_chunk( left_margin=0, right_margin=0, threshold=30, + return_collisioncleaned_waveforms=True, return_residual=False, return_conv=False, + unit_mask=None, ): """Core peeling routine for subtraction""" # initialize residual, it needs to be padded to support our channel @@ -588,6 +594,7 @@ def match_chunk( padded_objective, refrac_mask, compressed_template_data, + unit_mask=unit_mask, ) if new_peaks is None: break @@ -627,12 +634,14 @@ def match_chunk( peaks.subset(*torch.nonzero(valid, as_tuple=True), sort=True) # extract collision-cleaned waveforms on small neighborhoods - channels, waveforms = compressed_template_data.get_collisioncleaned_waveforms( - residual_padded, - peaks, - self.channel_index, - spike_length_samples=self.spike_length_samples, - ) + channels = waveforms = None + if return_collisioncleaned_waveforms: + channels, waveforms = compressed_template_data.get_collisioncleaned_waveforms( + residual_padded, + peaks, + self.channel_index, + spike_length_samples=self.spike_length_samples, + ) res = dict( n_spikes=peaks.n_spikes, @@ -658,6 +667,7 @@ def find_peaks( padded_objective, refrac_mask, compressed_template_data, + unit_mask=None, ): # update the coarse objective torch.add( @@ -671,6 +681,8 @@ def find_peaks( objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len ] + if unit_mask is not None: + objective[torch.logical_not(unit_mask)] = -torch.inf # formerly used detect_and_deduplicate, but that was slow. objective_max, max_obj_template = objective.max(dim=0) times = argrelmax(objective_max, self.spike_length_samples, self.threshold) diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 9710b628..0b46be41 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -41,7 +41,6 @@ def __init__( fit_subsampling_random_state=0, dtype=torch.float, ): - assert recording.get_num_channels() == channel_index.shape[0] if recording.get_num_segments() > 1: raise ValueError( "Peeling does not yet support multi-segment recordings." 
@@ -56,7 +55,9 @@ def __init__( fit_subsampling_random_state ) self.dtype = dtype - self.register_buffer("channel_index", channel_index) + if channel_index is not None: + self.register_buffer("channel_index", channel_index) + assert recording.get_num_channels() == channel_index.shape[0] self.n_waveforms_fit = n_waveforms_fit self.fit_sampling = fit_sampling self.fit_max_reweighting = fit_max_reweighting @@ -70,8 +71,11 @@ def __init__( self.fixed_output_data = [ ("sampling_frequency", self.recording.get_sampling_frequency()), ("geom", self.recording.get_channel_locations()), - ("channel_index", self.channel_index.numpy(force=True).copy()), ] + if channel_index is not None: + self.fixed_output_data.append( + ("channel_index", self.channel_index.numpy(force=True).copy()), + ) # -- main functions for users to call # in practice users will interact with the functions `subtract(...)` in diff --git a/src/dartsort/transform/temporal_pca.py b/src/dartsort/transform/temporal_pca.py index fd171561..83203ada 100644 --- a/src/dartsort/transform/temporal_pca.py +++ b/src/dartsort/transform/temporal_pca.py @@ -146,6 +146,7 @@ def to_sklearn(self): pca.components_ = self.components.numpy() if hasattr(self, "whitener"): pca.explained_variance_ = np.square(self.whitener.numpy()) + pca.temporal_slice = self.temporal_slice return pca diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 95483472..1bef22e5 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -38,6 +38,10 @@ def ravel_multi_index(multi_index, dims): raveled_indices : LongTensor Indices into the flattened tensor of shape `dims` """ + if len(dims) == 1: + assert multi_index.ndim == 1 + return multi_index + assert len(multi_index) == len(dims) # collect multi indices diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 3675d0ec..9b91feda 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -36,12 +36,17 @@ def draw(self, panel, gmm, unit_id): class DPCSplitPlot(GMMPlot): kind = "triplet" width = 5 - height = 2 + height = 5 - def __init__(self, spike_kind="residual", feature="pca"): + def __init__(self, spike_kind="train", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow): self.spike_kind = spike_kind assert feature in ("pca", "spread_amp") self.feature = feature + self.inherit_chans = inherit_chans + self.common_chans = common_chans + self.dist_vmax = dist_vmax + self.cmap = cmap + self.show_values = True def draw(self, panel, gmm, unit_id): if self.feature == "pca": @@ -80,7 +85,7 @@ def draw(self, panel, gmm, unit_id): spread = (channel_norms * logs).sum(1) z = np.c_[amp.numpy(force=True), spread.numpy(force=True)] z /= mad(z, 0) - + if in_unit is None: ax = panel.subplots() ax.set_title("no features") @@ -104,19 +109,141 @@ def draw(self, panel, gmm, unit_id): return ru = np.unique(dens["labels"]) - panel, axes = analysis_plots.density_peaks_study( + panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[.6, 1]) + panel_top, axes = analysis_plots.density_peaks_study( z, dens, s=10, idx=idx, inv=inv, - fig=panel, + fig=panel_top, ) axes[-1].set_title(f"n={(ru>=0).sum()}", fontsize=8) axes[0].set_title(self.spike_kind) if self.feature == "spread_amp": axes[0].set_xlabel("spread") axes[0].set_ylabel("amp") + + axes = panel_bottom.subplots(ncols=2) + axes = {'d': axes[0], 'e': axes[1]} + labels = dens['labels'][inv] + in_unit = torch.from_numpy(in_unit) + ids = np.unique(labels) + ids = ids[ids >= 0] + 
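# Below, one InterpUnit is refit per density-peak label. With inherit_chans,
+        # the child units reuse the parent unit's channels and max_channel, so
+        # the divergence matrix computed afterward compares the refit means
+        # over a shared channel neighborhood.
+        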
new_units = [] + chans_kw = {} + if self.inherit_chans: + chans_kw = dict( + channels=gmm[unit_id].channels, + max_channel=gmm[unit_id].max_channel, + ) + for j, label in enumerate(ids): + u = spike_interp.InterpUnit( + do_interp=False, + **gmm.unit_kw, + ) + inu = in_unit[np.flatnonzero(labels == label)] + inu, train_data = gmm.get_training_data( + unit_id, + waveform_kind="original", + in_unit=inu, + sampling_method=gmm.sampling_method, + ) + u.fit_center( + **train_data, + padded_geom=gmm.data.padded_registered_geom, + show_progress=False, + **chans_kw, + ) + new_units.append(u) + + ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit] + + # plot new unit maxchan wfs and old one in black + ax = axes["d"] + ax.axhline(0, c="k", lw=0.8) + all_means = [] + for j, unit in ju: + if unit.do_interp: + times = unit.interp.grid.squeeze() + if self.fitted_only: + times = times[unit.interp.grid_fitted] + else: + times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) + times = torch.atleast_1d(times) + + chans = torch.full((times.numel(),), unit.max_channel, device=times.device) + means = unit.get_means(times).to(gmm.device) + if j > 0: + all_means.append(means.mean(0)) + means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) + means = means[..., 0] + means = gmm.data.tpca._inverse_transform_in_probe(means) + means = means.numpy(force=True) + color = glasbey1024[j] + + lines = np.stack( + (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), + axis=-1, + ) + ax.add_collection(LineCollection(lines, colors=color, lw=1)) + ax.autoscale_view() + ax.set_xticks([]) + ax.spines[["top", "right", "bottom"]].set_visible(False) + + # plot distance matrix + kind = gmm.merge_metric + min_overlap = gmm.min_overlap + subset_channel_index = None + if gmm.merge_on_waveform_radius: + subset_channel_index = gmm.data.registered_reassign_channel_index + nu = len(new_units) + divergences = torch.full((nu, nu), torch.nan) + for i, ua in ju: + # print(f"{i=} {ua.n_chans_unit=} {ua.channels.tolist()=}") + for j, ub in ju: + # print(f"{j=} {ub.n_chans_unit=} {ub.channels.tolist()=}") + if i == j: + divergences[i, j] = 0 + continue + divergences[i, j] = ua.divergence( + ub, + kind=kind, + min_overlap=min_overlap, + subset_channel_index=subset_channel_index, + common_chans=self.common_chans, + ) + dists = divergences.numpy(force=True) + + axis = axes["e"] + im = axis.imshow( + dists, + vmin=0, + vmax=self.dist_vmax, + cmap=self.cmap, + origin="lower", + interpolation="none", + ) + if self.show_values: + for (j, i), label in np.ndenumerate(dists): + axis.text( + i, + j, + f"{label:.2f}".lstrip("0"), + ha="center", + va="center", + clip_on=True, + fontsize=5, + ) + panel.colorbar(im, ax=axis, shrink=0.3) + axis.set_xticks(range(len(new_units))) + axis.set_yticks(range(len(new_units))) + for i, (tx, ty) in enumerate( + zip(axis.xaxis.get_ticklabels(), axis.yaxis.get_ticklabels()) + ): + tx.set_color(glasbey1024[i]) + ty.set_color(glasbey1024[i]) + axis.set_title(gmm.merge_metric) class ZipperSplitPlot(GMMPlot): @@ -367,6 +494,7 @@ def draw(self, panel, gmm, unit_id): ) u.fit_center( **train_data, + padded_geom=gmm.data.padded_registered_geom, show_progress=False, weights=w, **chans_kw, @@ -546,7 +674,14 @@ def draw(self, panel, gmm, unit_id): unique_ixs, counts = np.unique(ixs, return_counts=True) ax = panel.subplots() xy = gmm.data.registered_geom.numpy(force=True) - ax.scatter(*xy[unique_ixs].T, c=counts, lw=0, cmap=self.cmap) + s = ax.scatter(*xy[unique_ixs].T, c=counts, 
lw=0, cmap=self.cmap) + plt.colorbar(s, ax=ax, shrink=0.3) + ax.scatter( + *xy[gmm[unit_id].channels_valid.numpy(force=True)].T, + color="r", + lw=1, + fc="none", + ) ax.scatter( *xy[np.atleast_1d(gmm[unit_id].max_channel.numpy(force=True))].T, color="g", @@ -558,51 +693,47 @@ class AmplitudesOverTimePlot(GMMPlot): kind = "embeds" width = 5 height = 3 + + def __init__(self, kinds=("recon", 'model'), colors="bkrg"): + self.kinds = kinds + self.colors = dict(zip(kinds, colors)) def draw(self, panel, gmm, unit_id): in_unit, utd = gmm.get_training_data(unit_id, waveform_kind="reassign") - amps = utd["static_amp_vecs"].numpy(force=True) - amps = np.nanmax(amps, axis=1) + + show = {} + if "feat" in self.kinds: + amps = utd["static_amp_vecs"].numpy(force=True) + show['feat'] = np.nanmax(amps, axis=1) amps2 = utd["waveforms"] - n, r, c = amps2.shape - amps2 = gmm.data.tpca._inverse_transform_in_probe( - amps2.permute(0, 2, 1).reshape(n * c, r) - ) - amps2 = amps2.reshape(n, -1, c).permute(0, 2, 1) - amps2 = np.nan_to_num(amps2.numpy(force=True)).ptp(axis=(1, 2)) - recons = gmm[unit_id].get_means(utd["times"]) - recons = gmm[unit_id].to_waveform_channels( - recons, waveform_channels=utd["waveform_channels"] - ) - n, r, c = recons.shape - recons = gmm.data.tpca._inverse_transform_in_probe( - recons.permute(0, 2, 1).reshape(n * c, r) - ) - recons = recons.reshape(n, -1, c).permute(0, 2, 1) - - recon_amps = np.nan_to_num(recons.numpy(force=True)).ptp(axis=(1, 2)) + if 'recon' in self.kinds: + n, r, c = amps2.shape + amps2 = gmm.data.tpca._inverse_transform_in_probe( + amps2.permute(0, 2, 1).reshape(n * c, r) + ) + amps2 = amps2.reshape(n, -1, c).permute(0, 2, 1) + show['recon'] = np.nan_to_num(amps2.numpy(force=True)).ptp(axis=(1, 2)) + if 'model' in self.kinds: + recons = gmm[unit_id].get_means(utd["times"]) + recons = gmm[unit_id].to_waveform_channels( + recons, waveform_channels=utd["waveform_channels"] + ) + n, r, c = recons.shape + recons = gmm.data.tpca._inverse_transform_in_probe( + recons.permute(0, 2, 1).reshape(n * c, r) + ) + recons = recons.reshape(n, -1, c).permute(0, 2, 1) + show['model'] = np.nan_to_num(recons.numpy(force=True)).ptp(axis=(1, 2)) + if 'maxchan_energy' in self.kinds: + wfs = torch.nan_to_num(utd['waveforms']) + wfs = torch.linalg.norm(wfs, dim=1) + wfs = wfs.max(dim=1).values + show['maxchan_energy'] = wfs.numpy(force=True) ax = panel.subplots() - ax.scatter( - utd["times"].numpy(force=True), - amps, - s=3, - c="b", - lw=0, - label="observed (final)", - ) - ax.scatter( - utd["times"].numpy(force=True), - amps2, - s=3, - c="k", - lw=0, - label="observed (feat)", - ) - ax.scatter( - utd["times"].numpy(force=True), recon_amps, s=3, c="r", lw=0, label="model" - ) - # ax.legend(loc="upper left", ncols=3) + t = utd["times"].numpy(force=True) + for kind, a in show.items(): + ax.scatter(t, a, c=self.colors[kind], s=3, lw=0, label=kind) ax.set_ylabel("amplitude") @@ -705,8 +836,7 @@ def draw(self, panel, gmm, unit_id): badnesses = {k: v.numpy(force=True) for k, v in badnesses.items()} # amplitudes - amps = utd["static_amp_vecs"].numpy(force=True) - amps = np.nanmax(amps, axis=1) + amps = torch.nan_to_num(torch.linalg.norm(utd['waveforms'], dim=1)).max(dim=1).values.numpy(force=True) axes = panel.subplots(ncols=z.shape[1] + 1, sharey=True) for ax, feat, featname in zip( @@ -838,11 +968,12 @@ class InputWaveformsMultiChanPlot(GMMPlot): width = 5 height = 5 - def __init__(self, cmap=plt.cm.rainbow, max_plot=250, rg=0, time_range=None): + def __init__(self, cmap=plt.cm.rainbow, 
max_plot=250, rg=0, time_range=None, imputation_kind=None): self.cmap = cmap self.max_plot = max_plot self.rg = np.random.default_rng(0) self.time_range = time_range + self.imputation_kind = imputation_kind def draw(self, panel, gmm, unit_id): in_unit, utd = gmm.get_training_data(unit_id) @@ -857,6 +988,22 @@ def draw(self, panel, gmm, unit_id): waveform_channels = utd["waveform_channels"][wh] waveforms = utd["waveforms"][wh] n, r, c = waveforms.shape + if self.imputation_kind: + waveforms = gmm[unit_id].impute( + times, + waveforms, + waveform_channels, + waveform_channel_index=utd["waveform_channel_index"], + imputation_kind=self.imputation_kind, + padded_registered_geom=gmm.data.padded_registered_geom, + ) + waveforms = waveforms.reshape(n, r, gmm[unit_id].n_chans_unit) + waveform_channels = gmm[unit_id].channels.numpy(force=True) + waveform_channels = np.broadcast_to(waveform_channels[None], (n, *waveform_channels.shape)) + else: + waveform_channels = waveform_channels.numpy(force=True) + n, r, c = waveforms.shape + waveforms = gmm.data.tpca._inverse_transform_in_probe( waveforms.permute(0, 2, 1).reshape(n * c, r) ) @@ -870,7 +1017,7 @@ def draw(self, panel, gmm, unit_id): ax = panel.subplots() geomplot( waveforms, - channels=waveform_channels.numpy(force=True), + channels=waveform_channels, geom=gmm.data.registered_geom.numpy(force=True), max_abs_amp=maa, lw=1, @@ -1705,7 +1852,7 @@ def draw(self, panel, gmm, unit_id): colors = glasbey1024[neighbor_ids % len(glasbey1024)] axes = panel.subplots( - nrows=self.n_neighbors, + nrows=self.n_neighbors + 1, ncols=2, sharex=True, sharey=False, diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index ae13fb07..25349af4 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -416,9 +416,16 @@ def get_waveforms(self, sorting_analysis, unit_id): def draw(self, panel, sorting_analysis, unit_id, axis=None): if axis is None: axis = panel.subplots() - which, waveforms, max_chan, geom, ci = self.get_waveforms( + tslice, which, waveforms, max_chan, geom, ci = self.get_waveforms( sorting_analysis, unit_id ) + trough_offset_samples = self.trough_offset_samples + spike_length_samples = self.spike_length_samples + if tslice.start is not None: + trough_offset_samples = self.trough_offset_samples - tslice.start + spike_length_samples = self.spike_length_samples - tslice.start + if tslice.stop is not None: + spike_length_samples = tslice.stop - tslice.start max_abs_amp = None show_template = self.show_template @@ -439,8 +446,8 @@ def draw(self, panel, sorting_analysis, unit_id, axis=None): templates = trim_waveforms( templates, old_offset=sorting_analysis.coarse_template_data.trough_offset_samples, - new_offset=self.trough_offset_samples, - new_length=self.spike_length_samples, + new_offset=trough_offset_samples, + new_length=spike_length_samples, ) max_abs_amp = self.max_abs_template_scale * np.nanmax(np.abs(templates)) @@ -454,8 +461,8 @@ def draw(self, panel, sorting_analysis, unit_id, axis=None): suptemplates = trim_waveforms( suptemplates, old_offset=sorting_analysis.template_data.trough_offset_samples, - new_offset=self.trough_offset_samples, - new_length=self.spike_length_samples, + new_offset=trough_offset_samples, + new_length=spike_length_samples, ) show_superres_templates = suptemplates.shape[0] > 1 max_abs_amp = self.max_abs_template_scale * np.nanmax(np.abs(suptemplates)) @@ -544,7 +551,7 @@ class RawWaveformPlot(WaveformPlot): wfs_kind = "raw wfs" def get_waveforms(self, sorting_analysis, unit_id): - return 
sorting_analysis.unit_raw_waveforms( + return slice(None), *sorting_analysis.unit_raw_waveforms( unit_id, template_index=self.template_index, max_count=self.count, @@ -559,7 +566,12 @@ class TPCAWaveformPlot(WaveformPlot): wfs_kind = "coll.-cl. tpca wfs" def get_waveforms(self, sorting_analysis, unit_id): - return sorting_analysis.unit_tpca_waveforms( + temporal_slice = getattr( + sorting_analysis.sklearn_tpca, + "temporal_slice", + slice(None), + ) + return temporal_slice, *sorting_analysis.unit_tpca_waveforms( unit_id, template_index=self.template_index, max_count=self.count, From 6f8bf098bd0c7603587bb3a8a7e6f4e8d1171970 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 3 Sep 2024 08:30:00 -0700 Subject: [PATCH 2/8] Speed up PCA using torch's low rank stuff --- src/dartsort/transform/temporal_pca.py | 71 +++++++++++--------------- 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/src/dartsort/transform/temporal_pca.py b/src/dartsort/transform/temporal_pca.py index 83203ada..568003aa 100644 --- a/src/dartsort/transform/temporal_pca.py +++ b/src/dartsort/transform/temporal_pca.py @@ -22,6 +22,7 @@ def __init__( name=None, name_prefix="", temporal_slice=None, + n_oversamples=10, ): if fit_radius is not None: if geom is None or channel_index is None: @@ -42,8 +43,7 @@ def __init__( self.centered = centered self.whiten = whiten self.temporal_slice = temporal_slice - if whiten: - assert self.centered + self.n_oversamples = n_oversamples def fit(self, waveforms, max_channels): waveforms = self._temporal_slice(waveforms) @@ -53,48 +53,38 @@ def fit(self, waveforms, max_channels): waveforms, train_channel_index = channel_subset_by_radius( waveforms, max_channels, - self.channel_index.cpu().numpy(), - self.geom.cpu().numpy(), + self.channel_index, + self.geom, self.fit_radius, ) _, waveforms_fit = get_channels_in_probe( waveforms, max_channels, train_channel_index ) - waveforms_fit = waveforms_fit.cpu().numpy() if self.centered: - pca = PCA( - self.rank, - random_state=self.random_state, - whiten=self.whiten, - copy=False, # don't need to worry here - ) - pca.fit(waveforms_fit) - self.register_buffer( - "mean", torch.tensor(pca.mean_).to(waveforms.dtype) - ) - self.register_buffer( - "components", - torch.tensor(pca.components_).to(waveforms.dtype), - ) - self.register_buffer( - "whitener", - torch.sqrt( - torch.tensor(pca.explained_variance_).to(waveforms.dtype) - ), - ) + mean = waveforms_fit.mean(0) else: - tsvd = TruncatedSVD(self.rank, random_state=self.random_state) - tsvd.fit(waveforms_fit) - self.register_buffer( - "mean", - torch.zeros(waveforms_fit[0].shape, dtype=waveforms.dtype), - ) - self.register_buffer( - "components", - torch.tensor(tsvd.components_).to(waveforms.dtype), - ) - + mean = torch.zeros_like(waveforms_fit[0]) + + q = self.rank + self.n_oversamples + n_samples, n_times = waveforms_fit.shape + assert q < min(n_samples, n_times) + M = mean if self.centered else None + + # 7 is based on sklearn's auto choice + U, S, V = torch.svd_lowrank(waveforms_fit, q=q, M=M, niter=7) + U = U[..., :self.rank] + S = S[..., :self.rank] + V = V[..., :self.rank] + + # loadings = U * S[..., None, :] + components = V.T.contiguous() + explained_variance = (S**2) / (n_samples - 1) + whitener = torch.sqrt(explained_variance) + + self.register_buffer("mean", mean) + self.register_buffer("components", components) + self.register_buffer("whitener", whitener) self._needs_fit = False def needs_fit(self): @@ -142,11 +132,10 @@ def to_sklearn(self): random_state=self.random_state, 
whiten=self.whiten, ) - pca.mean_ = self.mean.numpy() - pca.components_ = self.components.numpy() - if hasattr(self, "whitener"): - pca.explained_variance_ = np.square(self.whitener.numpy()) - pca.temporal_slice = self.temporal_slice + pca.mean_ = self.mean.numpy(force=True) + pca.components_ = self.components.numpy(force=True) + pca.explained_variance_ = np.square(self.whitener.numpy(force=True)) + pca.temporal_slice = self.temporal_slice # this is not standard return pca From 7310853f531d7ba337ca1cb43287922659d3d113 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 3 Sep 2024 08:30:10 -0700 Subject: [PATCH 3/8] Never lock --- src/dartsort/cluster/initial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/cluster/initial.py b/src/dartsort/cluster/initial.py index c52c5892..7fdaff2c 100644 --- a/src/dartsort/cluster/initial.py +++ b/src/dartsort/cluster/initial.py @@ -60,7 +60,7 @@ def cluster_chunk( amps = getattr(sorting, amplitudes_dataset_name) if recording is None: - with h5py.File(sorting.parent_h5_path, "r") as h5: + with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5: geom = h5["geom"][:] else: geom = recording.get_channel_locations() From d7d875ff95c2bace0a9f8098ea9c4c424c26c8a5 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 3 Sep 2024 08:30:28 -0700 Subject: [PATCH 4/8] Export postprocess --- src/dartsort/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index 68bf5011..25874b2d 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -12,6 +12,6 @@ from .util.analysis import DARTsortAnalysis from .util.data_util import DARTsortSorting from .util.waveform_util import make_channel_index -from .cluster import merge +from .cluster import merge, postprocess __version__ = importlib.metadata.version("dartsort") From 2bfe1d23c83973ef26fa3d03eef7280de41ca8bb Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 3 Sep 2024 08:30:48 -0700 Subject: [PATCH 5/8] Template std devs --- src/dartsort/templates/get_templates.py | 1 + src/dartsort/templates/templates.py | 26 ++++++++++++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index 47fdf041..14255ce4 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -42,6 +42,7 @@ def get_templates( n_jobs=0, dtype=np.float32, show_progress=True, + with_std_dev=False, device=None, ): """Raw, denoised, and shifted templates diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 802f0034..72cfbb1c 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -27,7 +27,9 @@ class TemplateData: # (n_templates,) spike count for each template spike_counts: np.ndarray # (n_templates, n_registered_channels or n_channels) spike count for each channel - spike_counts_by_channel: np.ndarray + spike_counts_by_channel: Optional[np.ndarray] = None + # (n_templates, spike_length_samples, n_registered_channels or n_channels) + raw_std_dev: Optional[np.ndarray] = None registered_geom: Optional[np.ndarray] = None registered_template_depths_um: Optional[np.ndarray] = None @@ -52,6 +54,14 @@ def to_npz(self, npz_path): to_save["registered_template_depths_um"] = ( self.registered_template_depths_um ) + if self.spike_counts_by_channel is not None: + to_save["spike_counts_by_channel"] = ( + 
self.spike_counts_by_channel + ) + if self.raw_std_dev is not None: + to_save["raw_std_dev"] = ( + self.raw_std_dev + ) if not npz_path.parent.exists(): npz_path.parent.mkdir() np.savez(npz_path, **to_save) @@ -109,7 +119,7 @@ def from_config( motion_est=None, save_npz_name="template_data.npz", localizations_dataset_name="point_source_localizations", - with_locs=True, + with_locs=False, n_jobs=0, units_per_job=8, tsvd=None, @@ -212,11 +222,13 @@ def from_config( # handle registered templates if template_config.registered_templates and motion_est is not None: - registered_template_depths_um = get_template_depths( - results["templates"], - kwargs["registered_geom"], - localization_radius_um=template_config.registered_template_localization_radius_um, - ) + registered_template_depths_um = None + if with_locs: + registered_template_depths_um = get_template_depths( + results["templates"], + kwargs["registered_geom"], + localization_radius_um=template_config.registered_template_localization_radius_um, + ) obj = cls( results["templates"], unit_ids=results["unit_ids"], From 6b5127488889b9f8003533370db2e2bd6a31e70e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 3 Sep 2024 08:30:58 -0700 Subject: [PATCH 6/8] Latest vis --- src/dartsort/vis/gmm.py | 181 +++++++++++++++++++++------------------ src/dartsort/vis/unit.py | 4 +- 2 files changed, 102 insertions(+), 83 deletions(-) diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 9b91feda..70176803 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -38,7 +38,7 @@ class DPCSplitPlot(GMMPlot): width = 5 height = 5 - def __init__(self, spike_kind="train", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow): + def __init__(self, spike_kind="split", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow): self.spike_kind = spike_kind assert feature in ("pca", "spread_amp") self.feature = feature @@ -58,7 +58,7 @@ def draw(self, panel, gmm, unit_id): for sl, data in gmm.batches(in_unit): gmm[unit_id].residual_embed(**data, out=features[sl]) z = features[:, : gmm.dpc_split_kw.rank].numpy(force=True) - elif self.spike_kind == "train": + elif self.spike_kind == "split": _, in_unit, z = gmm.split_features(unit_id) elif self.spike_kind == "global": in_unit, data = gmm.get_training_data(unit_id) @@ -76,7 +76,7 @@ def draw(self, panel, gmm, unit_id): ) z = loadings.numpy(force=True) elif self.feature == "spread_amp": - assert self.spike_kind in ("train", "global") + assert self.spike_kind in ("split", "global") in_unit, data = gmm.get_training_data(unit_id) waveforms = data["waveforms"] channel_norms = torch.sqrt(torch.nan_to_num(waveforms.square().sum(1))) @@ -124,8 +124,8 @@ def draw(self, panel, gmm, unit_id): axes[0].set_xlabel("spread") axes[0].set_ylabel("amp") - axes = panel_bottom.subplots(ncols=2) - axes = {'d': axes[0], 'e': axes[1]} + axes = panel_bottom.subplots(ncols=3) + axes = {'d': axes[0], 'f': axes[1], 'e': axes[2]} labels = dens['labels'][inv] in_unit = torch.from_numpy(in_unit) ids = np.unique(labels) @@ -160,36 +160,38 @@ def draw(self, panel, gmm, unit_id): ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit] # plot new unit maxchan wfs and old one in black - ax = axes["d"] - ax.axhline(0, c="k", lw=0.8) - all_means = [] - for j, unit in ju: - if unit.do_interp: - times = unit.interp.grid.squeeze() - if self.fitted_only: - times = times[unit.interp.grid_fitted] - else: - times = torch.tensor([sum(gmm.t_bounds) / 
2]).to(gmm.device) - times = torch.atleast_1d(times) - - chans = torch.full((times.numel(),), unit.max_channel, device=times.device) - means = unit.get_means(times).to(gmm.device) - if j > 0: - all_means.append(means.mean(0)) - means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) - means = means[..., 0] - means = gmm.data.tpca._inverse_transform_in_probe(means) - means = means.numpy(force=True) - color = glasbey1024[j] - - lines = np.stack( - (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), - axis=-1, - ) - ax.add_collection(LineCollection(lines, colors=color, lw=1)) - ax.autoscale_view() - ax.set_xticks([]) - ax.spines[["top", "right", "bottom"]].set_visible(False) + gmc = gmm[unit_id].max_channel + for ax, pick in zip((axes["d"], axes["f"]), ("unit", "shared")): + ax.axhline(0, c="k", lw=0.8) + all_means = [] + for j, unit in ju: + if unit.do_interp: + times = unit.interp.grid.squeeze() + if self.fitted_only: + times = times[unit.interp.grid_fitted] + else: + times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) + times = torch.atleast_1d(times) + + chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device) + means = unit.get_means(times).to(gmm.device) + if j > 0: + all_means.append(means.mean(0)) + means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) + means = means[..., 0] + means = gmm.data.tpca._inverse_transform_in_probe(means) + means = means.numpy(force=True) + color = glasbey1024[j] + + lines = np.stack( + (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), + axis=-1, + ) + ax.add_collection(LineCollection(lines, colors=color, lw=1)) + ax.autoscale_view() + ax.set_xticks([]) + ax.spines[["top", "right", "bottom"]].set_visible(False) + ax.set_title(pick) # plot distance matrix kind = gmm.merge_metric @@ -424,6 +426,7 @@ def __init__( n_iter=50, common_chans=True, inherit_chans=True, + impute_before_center=False, min_overlap=0.0, ): self.cmap = cmap @@ -439,6 +442,7 @@ def __init__( self.n_iter = n_iter self.common_chans = common_chans self.inherit_chans = inherit_chans + self.impute_before_center = impute_before_center self.min_overlap = min_overlap self.merge_on_waveform_radius = merge_on_waveform_radius if self.scaled: @@ -458,7 +462,7 @@ def draw(self, panel, gmm, unit_id): if ids.size > 1: top, bottom = panel.subfigures(nrows=2) ax_top = top.subplots() - axes = bottom.subplot_mosaic("de") + axes = bottom.subplot_mosaic("dfe") else: ax_top = panel.subplots() @@ -482,7 +486,7 @@ def draw(self, panel, gmm, unit_id): for j, label in enumerate(ids): u = spike_interp.InterpUnit( do_interp=False, - **gmm.unit_kw, + **gmm.unit_kw | dict(impute_before_center=self.impute_before_center), ) inu = in_unit[np.flatnonzero(labels == label)] w = None if weights is None else weights[labels == label, j] @@ -504,36 +508,38 @@ def draw(self, panel, gmm, unit_id): ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit] # plot new unit maxchan wfs and old one in black - ax = axes["d"] - ax.axhline(0, c="k", lw=0.8) - all_means = [] - for j, unit in ju: - if unit.do_interp: - times = unit.interp.grid.squeeze() - if self.fitted_only: - times = times[unit.interp.grid_fitted] - else: - times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) - times = torch.atleast_1d(times) - - chans = torch.full((times.numel(),), unit.max_channel, device=times.device) - means = unit.get_means(times).to(gmm.device) - if j > 0: - all_means.append(means.mean(0)) - means = 
unit.to_waveform_channels(means, waveform_channels=chans[:, None]) - means = means[..., 0] - means = gmm.data.tpca._inverse_transform_in_probe(means) - means = means.numpy(force=True) - color = glasbey1024[j] - - lines = np.stack( - (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), - axis=-1, - ) - ax.add_collection(LineCollection(lines, colors=color, lw=1)) - ax.autoscale_view() - ax.set_xticks([]) - ax.spines[["top", "right", "bottom"]].set_visible(False) + gmc = gmm[unit_id].max_channel + for ax, pick in zip((axes["d"], axes["f"]), ("unit", "shared")): + ax.axhline(0, c="k", lw=0.8) + all_means = [] + for j, unit in ju: + if unit.do_interp: + times = unit.interp.grid.squeeze() + if self.fitted_only: + times = times[unit.interp.grid_fitted] + else: + times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) + times = torch.atleast_1d(times) + + chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device) + means = unit.get_means(times).to(gmm.device) + if j > 0: + all_means.append(means.mean(0)) + means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) + means = means[..., 0] + means = gmm.data.tpca._inverse_transform_in_probe(means) + means = means.numpy(force=True) + color = glasbey1024[j] + + lines = np.stack( + (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), + axis=-1, + ) + ax.add_collection(LineCollection(lines, colors=color, lw=1)) + ax.autoscale_view() + ax.set_xticks([]) + ax.spines[["top", "right", "bottom"]].set_visible(False) + ax.set_title(pick) # plot distance matrix kind = gmm.merge_metric @@ -597,7 +603,7 @@ class HDBScanSplitPlot(GMMPlot): width = 2 height = 2 - def __init__(self, spike_kind="train"): + def __init__(self, spike_kind="split"): self.spike_kind = spike_kind def draw(self, panel, gmm, unit_id): @@ -609,7 +615,7 @@ def draw(self, panel, gmm, unit_id): for sl, data in gmm.batches(in_unit): gmm[unit_id].residual_embed(**data, out=features[sl]) z = features[:, : gmm.dpc_split_kw.rank].numpy(force=True) - elif self.spike_kind == "train": + elif self.spike_kind == "split": _, in_unit, z = gmm.split_features(unit_id) elif self.spike_kind == "global": in_unit, data = gmm.get_training_data(unit_id) @@ -694,7 +700,7 @@ class AmplitudesOverTimePlot(GMMPlot): width = 5 height = 3 - def __init__(self, kinds=("recon", 'model'), colors="bkrg"): + def __init__(self, kinds=("recon", "maxchan_energy", 'model'), colors="bkrg"): self.kinds = kinds self.colors = dict(zip(kinds, colors)) @@ -734,6 +740,7 @@ def draw(self, panel, gmm, unit_id): t = utd["times"].numpy(force=True) for kind, a in show.items(): ax.scatter(t, a, c=self.colors[kind], s=3, lw=0, label=kind) + ax.legend(loc="upper left") ax.set_ylabel("amplitude") @@ -788,19 +795,28 @@ def draw(self, panel, gmm, unit_id): overlaps = overlaps.numpy(force=True) times = utd["times"][spike_ix].numpy(force=True) - ax = panel.subplots() + ax, ay = panel.subplots(ncols=2, width_ratios=[2, 1]) + ay.grid(True) for j, (kind, b) in enumerate(badnesses.items()): + b = b.numpy(force=True) + c = self.colors[j] ax.scatter( times, - b.numpy(force=True), + b, alpha=overlaps, s=3, - c=self.colors[j], + c=c, lw=0, label=kind, ) + ay.ecdf(b, lw=1, color=c) + ay.axvline(np.mean(b), color=c, lw=1) + ay.axvline(np.median(b), color=c, ls="--", lw=1) + ay.set_xlabel(kind) + ay.set_ylabel("cdf") ax.legend(loc="upper left") ax.set_ylabel("badness") + ay.set_yticks([0, 0.25, 0.5, 0.75, 1]) class FeaturesVsBadnessesPlot(GMMPlot): @@ -1352,9 
+1368,9 @@ def __init__(self, n_neighbors=5): self.n_neighbors = n_neighbors def get_neighbors(self, gmm, unit_id, reversed=False): - unit_dists = gmm.central_divergences(units_a=[unit_id])[0] + unit_dists = gmm.central_divergences(units_a=torch.tensor([unit_id]))[0] unit_ids = gmm.unit_ids() - neighbors = torch.argsort(unit_dists) + neighbors = torch.argsort((unit_ids != unit_id).to(unit_dists) + unit_dists) assert unit_ids[neighbors[0]] == unit_id neighbors = neighbors[: self.n_neighbors + 1] neighbors = neighbors[torch.isfinite(unit_dists[neighbors])] @@ -1561,12 +1577,13 @@ class NeighborBimodality(GMMMergePlot): width = 3 height = 10 - def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0): + def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False): self.n_neighbors = n_neighbors self.badness_kind = badness_kind self.do_reg = do_reg self.masked = masked self.mask_radius_s = mask_radius_s + self.impute_missing = impute_missing def draw(self, panel, gmm, unit_id): from isosplit import isocut, dipscore_at @@ -1625,6 +1642,7 @@ def draw(self, panel, gmm, unit_id): unit_ids=[unit_id, u], show_progress=False, kind=self.badness_kind, + impute_missing=self.impute_missing, ) a = np.full(badness.shape, np.inf) a[badness.coords] = badness.data @@ -1666,7 +1684,8 @@ def draw(self, panel, gmm, unit_id): ds_ud = f"{ds_ud:0.3f}".lstrip("0").rstrip("0") ds_udw = f"{ds_udw:0.3f}".lstrip("0").rstrip("0") mstr = "masked " if self.masked else "" - row[1].set_title(f"{mstr} u{ds_ud} uw{ds_udw}", fontsize=7) + istr = "imp " if self.impute_missing else "" + row[1].set_title(f"{mstr}{istr} u{ds_ud} uw{ds_udw}", fontsize=7) if self.do_reg: sns.regplot( @@ -1864,7 +1883,7 @@ def draw(self, panel, gmm, unit_id): unit.bar(axes[0, 1], alags, acg, fill=True, fc=colors[0]) axes[0, 0].axis("off") axes[0, 1].set_ylabel(f"my acg {neighbor_ids[0]}") - + j = 0 for j, ub in enumerate(neighbor_ids[1:], start=1): their_st = gmm.data.times_samples[gmm.labels == ub] @@ -1893,7 +1912,7 @@ def draw(self, panel, gmm, unit_id): # HDBScanSplitPlot(spike_kind="residual_full"), # HDBScanSplitPlot(), # ZipperSplitPlot(), - KMeansPPSPlitPlot(), + KMeansPPSPlitPlot(inherit_chans=True, n_clust=10, impute_before_center=True), GridMeansSingleChanPlot(), InputWaveformsSingleChanPlot(), # InputWaveformsSingleChanOverTimePlot(channel="unit"), @@ -1903,7 +1922,7 @@ def draw(self, panel, gmm, unit_id): BadnessesOverTimePlot(), EmbedsOverTimePlot(), # DPCSplitPlot(spike_kind="residual_full"), - DPCSplitPlot(spike_kind="train"), + DPCSplitPlot(spike_kind="split"), # DPCSplitPlot(spike_kind="global"), # DPCSplitPlot(spike_kind="global", feature="spread_amp"), FeaturesVsBadnessesPlot(), @@ -1928,8 +1947,8 @@ def draw(self, panel, gmm, unit_id): ISICorner(bin_ms=0.25), ISICorner(bin_ms=0.5, max_ms=8, tick_step=2), # NeighborBimodality(), - NeighborBimodality(badness_kind="diagz", masked=True), NeighborBimodality(badness_kind="1-r^2", masked=True), + NeighborBimodality(badness_kind="1-r^2", masked=True, impute_missing=True), CCGColumn(), # NeighborBimodality(badness_kind="1-scaledr^2", masked=True), ) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index 25349af4..227d7cf4 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -421,10 +421,10 @@ def draw(self, panel, sorting_analysis, unit_id, axis=None): ) trough_offset_samples = self.trough_offset_samples spike_length_samples = 
self.spike_length_samples - if tslice.start is not None: + if tslice is not None and tslice.start is not None: trough_offset_samples = self.trough_offset_samples - tslice.start spike_length_samples = self.spike_length_samples - tslice.start - if tslice.stop is not None: + if tslice is not None and tslice.stop is not None: spike_length_samples = tslice.stop - tslice.start max_abs_amp = None From a6307cc5a32633a54ccf168ad8c58160f729c043 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 18 Sep 2024 11:17:02 -0700 Subject: [PATCH 7/8] Debug + truncnorm + pearson --- src/dartsort/cluster/modes.py | 80 ++++++++++++++- src/dartsort/detect/detect.py | 24 +++-- src/dartsort/peel/subtract.py | 26 ++++- src/dartsort/vis/gmm.py | 188 ++++++++++++++++++++++++++-------- src/dartsort/vis/unit.py | 5 + 5 files changed, 269 insertions(+), 54 deletions(-) diff --git a/src/dartsort/cluster/modes.py b/src/dartsort/cluster/modes.py index c7ac87eb..72879438 100644 --- a/src/dartsort/cluster/modes.py +++ b/src/dartsort/cluster/modes.py @@ -1,4 +1,6 @@ import numpy as np +from scipy.signal import find_peaks +from scipy.stats import norm from . import density # todo: replace all isosplit stuff with things based on scipy's isotonic regression. @@ -35,6 +37,44 @@ def fit_unimodal_right(x, f, weights=None, cut=0, hard=False): return out +def fit_truncnorm_right(x, f, weights=None, cut=0, hard=False, n_iter=10): + """Like above, but fits truncated normal to xs > cut with MoM.""" + if weights is None: + weights = np.ones_like(f) + + # figure out where cut lands and what f(cut) should be + cuti_left = np.searchsorted(x, cut) + if cuti_left >= len(x) - 2: + # everything is to the left... return uniform + return np.full_like(f, 1/len(f)) + cuti_right = np.searchsorted(x, cut, side="right") - 1 + if cuti_right <= 1: + # everything is to the right... fit normal! + mean = np.average(x, weights=weights) + var = np.average((x - mean) ** 2, weights=weights) + return norm.pdf(x, loc=mean, scale=np.sqrt(var)) + # else, in the middle somewhere... 
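# An illustrative check of the invariant asserted next: for sorted
+    # x = [-2, -1, 1, 3] and cut = 0, np.searchsorted(x, cut) == 2 while
+    # np.searchsorted(x, cut, side="right") - 1 == 1, so cuti_left equals
+    # cuti_right + 1; a sample lying exactly at cut makes the two equal.
+    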
assert cuti_left in (cuti_right, cuti_right + 1)
+
+    xcut = x[cuti_right:]
+    mean = np.average(xcut, weights=weights[cuti_right:])
+    var = np.average((xcut - mean) ** 2, weights=weights[cuti_right:])
+    mu = mean
+    sigma = std = np.sqrt(var)
+    for i in range(n_iter):
+        alpha = (cut - mu) / sigma
+        # phi_alpha is the standard normal pdf at alpha, and Z = 1 - Phi(alpha)
+        # is the mass of the fitted normal above the truncation point
+        phi_alpha = norm.pdf(alpha)
+        Z = 1.0 - norm.cdf(alpha)
+        # we have: mean ~ mu + sigma*phi(alpha)/Z
+        # and var ~ sigma^2(1+alpha phi(alpha) /Z - (phi(alpha)/Z)^2)
+        # implies mu = mean - sigma*phi(alpha)/Z
+        # sigma^2 = var/(1+alpha phi(alpha) /Z - (phi(alpha)/Z)^2)
+        sigma = np.sqrt(var / (1 + alpha * phi_alpha / Z - (phi_alpha / Z) ** 2))
+        mu = mean - sigma * phi_alpha / Z
+
+    return norm.pdf(x, loc=mu, scale=sigma)
+
+
 def fit_bimodal_at(x, f, weights=None, cut=0):
     from isosplit import up_down_isotonic_regression
     if weights is None:
@@ -49,7 +89,16 @@ def fit_bimodal_at(x, f, weights=None, cut=0):
     return out
 
 
-def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", dipscore_only=False, score_kind="tv"):
+def smoothed_dipscore_at(
+    cut,
+    samples,
+    sample_weights,
+    alternative="smoothed",
+    dipscore_only=False,
+    score_kind="tv",
+    cut_relmax_order=3,
+    kind="isotonic",
+):
     if sample_weights is None:
         sample_weights = np.ones_like(samples)
     densities = density.get_smoothed_densities(
@@ -63,9 +112,27 @@ def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", d
     )
     spacings = np.diff(samples)
     spacings = np.concatenate((spacings[:1], 0.5 * (spacings[1:] + spacings[:-1]), spacings[-1:]))
+    densities /= (densities * spacings).sum()
+
+    if cut is None:
+        # closest maxes left + right of 0
+        maxers, _ = find_peaks(densities, distance=cut_relmax_order)
+        coords = samples[maxers]
+        left_cut = right_cut = 0
+        if (coords < 0).any():
+            left_cut = coords[coords < 0].max()
+        if (coords > 0).any():
+            right_cut = coords[coords > 0].min()
+        candidates = np.logical_and(samples >= left_cut, samples <= right_cut)
+        cut = 0
+        if candidates.any():
+            cut = samples[candidates][np.argmin(densities[candidates])]
+
     if alternative == "bimodal":
         densities = fit_bimodal_at(samples, densities, weights=sample_weights, cut=cut)
-    densities /= (densities * spacings).sum()
+        densities /= (densities * spacings).sum()
+    else:
+        assert alternative == "smoothed"
 
     score = 0
     best_dens_err = np.inf
@@ -83,7 +150,12 @@ def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", d
         s = np.ascontiguousarray(samples[order])
         d = np.ascontiguousarray(densities[order])
         w = np.ascontiguousarray(sample_weights[order])
-        dens = fit_unimodal_right(sign * s, d, weights=w, hard=hard)
+        if kind == "isotonic":
+            dens = fit_unimodal_right(sign * s, d, weights=w, cut=sign * cut, hard=hard)
+        elif kind == "truncnorm":
+            dens = fit_truncnorm_right(sign * s, d, weights=w, cut=sign * cut, hard=hard)
+        else:
+            assert False
         dens = dens[order]
         dens /= (dens * spacings).sum()
         dens_err = (np.abs(dens - densities) * spacings).sum()
@@ -100,4 +172,4 @@ def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", d
 
     if dipscore_only:
         return score
-    return score, samples, densities, best_uni
+    return score, samples, densities, best_uni, cut
diff --git a/src/dartsort/detect/detect.py b/src/dartsort/detect/detect.py
index 455ec0b8..688064bd 100644
--- a/src/dartsort/detect/detect.py
+++ b/src/dartsort/detect/detect.py
@@ -12,6 +12,7 @@ def detect_and_deduplicate(
     spatial_dedup_batch_size=512,
     exclude_edges=True,
     return_energies=False,
+    detection_mask=None,
 ):
     """Detect and deduplicate peaks
@@ -34,6 +35,9 @@
     dedup_temporal_radius : int
         Only the largest peak within this sliding radius
        will be kept
+    detection_mask : tensor
+        If supplied, this float tensor of ones and zeros gates detection:
+        peaks at positions where the mask is zero are suppressed
 
     Returns
     -------
@@ -80,6 +84,10 @@
     # remove peaks smaller than our threshold
     F.threshold_(energies, threshold, 0.0)
 
+    # optionally remove censored peaks
+    if detection_mask is not None:
+        energies.mul_(detection_mask)
+
     # -- temporal deduplication
     if dedup_temporal_radius > 0:
         max_energies = F.max_pool2d(
@@ -97,23 +105,25 @@
         max_energies = max_energies[0, 0]
 
     # -- spatial deduplication
-    # we would like to max pool again on the other axis,
-    # but that doesn't support any old radial neighborhood
+    # this is max pooling within the channel index's neighborhoods
     if all_dedup:
-        max_energies[:] = max_energies.max(dim=1, keepdim=True).values
+        max_energies = max_energies.max(dim=1, keepdim=True).values
     elif dedup_channel_index is not None:
         # pad channel axis with extra chan of 0s
         max_energies = F.pad(max_energies, (0, 1))
         for batch_start in range(0, nsamples, spatial_dedup_batch_size):
             batch_end = batch_start + spatial_dedup_batch_size
-            max_energies[batch_start:batch_end, :nchans] = torch.max(
-                max_energies[batch_start:batch_end, dedup_channel_index], dim=2
-            ).values
+            torch.amax(
+                max_energies[batch_start:batch_end, dedup_channel_index],
+                dim=2,
+                out=max_energies[batch_start:batch_end, :nchans],
+            )
         max_energies = max_energies[:, :nchans]
 
     # if temporal/spatial max made you grow, you were not a peak!
     if (dedup_temporal_radius > 0) or (dedup_channel_index is not None):
-        max_energies[max_energies > energies] = 0.0
+        # max_energies[max_energies > energies] = 0.0
+        max_energies.masked_fill_(max_energies > energies, 0.0)
 
     # sparsify and return
     if exclude_edges:
diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py
index 7c14092e..2a8b7298 100644
--- a/src/dartsort/peel/subtract.py
+++ b/src/dartsort/peel/subtract.py
@@ -38,6 +38,7 @@ def __init__(
         fit_subsampling_random_state=0,
         fit_sampling="random",
         residnorm_decrease_threshold=3.162,
+        persist_deduplication=True,
         dtype=torch.float,
     ):
         super().__init__(
@@ -57,6 +58,7 @@ def __init__(
         self.trough_offset_samples = trough_offset_samples
         self.spike_length_samples = spike_length_samples
         self.peak_sign = peak_sign
+        self.persist_deduplication = persist_deduplication
         if subtract_channel_index is None:
             subtract_channel_index = channel_index.clone().detach()
         self.register_buffer("subtract_channel_index", subtract_channel_index)
@@ -208,6 +210,7 @@ def peel_chunk(
             peak_sign=self.peak_sign,
             spatial_dedup_channel_index=self.spatial_dedup_channel_index,
             residnorm_decrease_threshold=self.residnorm_decrease_threshold,
+            persist_deduplication=self.persist_deduplication,
         )
 
         # add in chunk_start_samples
@@ -343,6 +346,9 @@ def subtract_chunk(
     peak_sign="both",
     spatial_dedup_channel_index=None,
     residnorm_decrease_threshold=3.162,  # sqrt(10)
+    persist_deduplication=True,
+    relative_peak_radius=5,
+    dedup_temporal_radius=7,
 ):
     """Core peeling routine for subtraction"""
     # validate arguments to avoid confusing error messages later
@@ -374,15 +380,26 @@ def subtract_chunk(
     spike_times = []
     spike_channels = []
     spike_features = []
+    if persist_deduplication:
+        detection_mask = torch.ones_like(residual)
+        dedup_temporal_ix = torch.arange(
+            -dedup_temporal_radius, dedup_temporal_radius, device=residual.device
+        )
 
-    for threshold in 
detection_thresholds: + for j, threshold in enumerate(detection_thresholds): # -- detect and extract waveforms # detection has more args which we don't expose right now + step_mask = None + if persist_deduplication and j > 0: + step_mask = detection_mask[:, :-1] times_samples, channels = detect_and_deduplicate( residual[:, :-1], threshold, dedup_channel_index=spatial_dedup_channel_index, peak_sign=peak_sign, + detection_mask=step_mask, + relative_peak_radius=relative_peak_radius, + dedup_temporal_radius=dedup_temporal_radius, ) if not times_samples.numel(): continue @@ -444,6 +461,13 @@ def subtract_chunk( already_padded=True, in_place=True, ) + if persist_deduplication: + time_ix = times_samples.unsqueeze(1) + dedup_temporal_ix.unsqueeze(0) + if spatial_dedup_channel_index is not None: + chan_ix = spatial_dedup_channel_index[channels] + else: + chan_ix = channels.unsqueeze(1) + detection_mask[time_ix[:, :, None], chan_ix[:, None, :]] = 0.0 del times_samples, channels, waveforms, features # check if we got no spikes diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 70176803..5b237b14 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -7,6 +7,7 @@ import torch from tqdm.auto import tqdm from scipy.spatial import KDTree +from scipy.stats import pearsonr from ..cluster import density from dartsort.cluster.modes import smoothed_dipscore_at @@ -422,12 +423,16 @@ def __init__( merge_on_waveform_radius=True, dist_vmax=1.0, show_values=True, - n_clust=5, - n_iter=50, - common_chans=True, + n_clust=None, + n_iter=None, + common_chans=False, inherit_chans=True, - impute_before_center=False, + impute_before_center=True, + with_proportions=False, min_overlap=0.0, + zip_metric=None, + verbose=False, + equilibrium_kmeans_alpha=0.0, ): self.cmap = cmap self.fitted_only = fitted_only @@ -445,12 +450,23 @@ def __init__( self.impute_before_center = impute_before_center self.min_overlap = min_overlap self.merge_on_waveform_radius = merge_on_waveform_radius + self.with_proportions = with_proportions + self.zip_metric = zip_metric + self.verbose = verbose + self.equilibrium_kmeans_alpha = equilibrium_kmeans_alpha if self.scaled: self.title = f"scaled {self.title}" def draw(self, panel, gmm, unit_id): + n_clust = self.n_clust or gmm.kmeans_nclust + n_iter = self.n_iter or gmm.kmeans_niter in_unit, labels, weights = gmm.kmeanspp( - unit_id, n_clust=self.n_clust, n_iter=self.n_iter + unit_id, + n_clust=n_clust, + n_iter=n_iter, + with_proportions=self.with_proportions, + verbose=self.verbose, + equilibrium_kmeans_alpha=self.equilibrium_kmeans_alpha, ) times = gmm.data.times_seconds[in_unit].numpy(force=True) @@ -458,9 +474,22 @@ def draw(self, panel, gmm, unit_id): labels = labels.numpy(force=True) ids = np.unique(labels) ids = ids[ids >= 0] + if self.verbose: print(f"gmm.KMeansPPSPlitPlot {ids=} {np.unique(labels, return_counts=True)=}") + + zip_metric = self.zip_metric or gmm.zip_metric + zip_threshold = gmm.zip_threshold + if gmm.zip_by_quantile: + in_unit, utd = gmm.get_training_data(unit_id, in_unit=in_unit, waveform_kind="reassign") + spike_ix, overlaps, badnesses = gmm[unit_id].spike_badnesses( + utd["times"], + utd["waveforms"], + utd["waveform_channels"], + kinds=(zip_metric,) + ) + zip_threshold = torch.quantile(badnesses[zip_metric], zip_threshold) if ids.size > 1: - top, bottom = panel.subfigures(nrows=2) + top, bottom, below = panel.subfigures(nrows=3) ax_top = top.subplots() axes = bottom.subplot_mosaic("dfe") else: @@ -483,6 +512,7 @@ def draw(self, panel, gmm, 
unit_id):
+        n_clust = self.n_clust or gmm.kmeans_nclust
+        n_iter = self.n_iter or gmm.kmeans_niter
         in_unit, labels, weights = gmm.kmeanspp(
-            unit_id, n_clust=self.n_clust, n_iter=self.n_iter
+            unit_id,
+            n_clust=n_clust,
+            n_iter=n_iter,
+            with_proportions=self.with_proportions,
+            verbose=self.verbose,
+            equilibrium_kmeans_alpha=self.equilibrium_kmeans_alpha,
         )
 
         times = gmm.data.times_seconds[in_unit].numpy(force=True)
@@ -458,9 +474,22 @@ def draw(self, panel, gmm, unit_id):
         labels = labels.numpy(force=True)
         ids = np.unique(labels)
         ids = ids[ids >= 0]
+        if self.verbose: print(f"gmm.KMeansPPSPlitPlot {ids=} {np.unique(labels, return_counts=True)=}")
+
+        zip_metric = self.zip_metric or gmm.zip_metric
+        zip_threshold = gmm.zip_threshold
+        if gmm.zip_by_quantile:
+            in_unit, utd = gmm.get_training_data(unit_id, in_unit=in_unit, waveform_kind="reassign")
+            spike_ix, overlaps, badnesses = gmm[unit_id].spike_badnesses(
+                utd["times"],
+                utd["waveforms"],
+                utd["waveform_channels"],
+                kinds=(zip_metric,)
+            )
+            zip_threshold = torch.quantile(badnesses[zip_metric], zip_threshold)
 
         if ids.size > 1:
-            top, bottom = panel.subfigures(nrows=2)
+            top, bottom, below = panel.subfigures(nrows=3)
             ax_top = top.subplots()
             axes = bottom.subplot_mosaic("dfe")
         else:
             ax_top = panel.subplots()
@@ -483,6 +512,7 @@ def draw(self, panel, gmm, unit_id):
             channels=gmm[unit_id].channels,
             max_channel=gmm[unit_id].max_channel,
         )
+        if self.verbose: print(f"{self.impute_before_center=} {self.common_chans=}")
         for j, label in enumerate(ids):
             u = spike_interp.InterpUnit(
                 do_interp=False,
@@ -542,7 +572,7 @@ def draw(self, panel, gmm, unit_id):
             ax.set_title(pick)
 
         # plot distance matrix
-        kind = gmm.merge_metric
+        kind = gmm.zip_metric
         min_overlap = self.min_overlap
         if self.min_overlap is None:
             min_overlap = gmm.min_overlap
@@ -551,6 +581,7 @@ def draw(self, panel, gmm, unit_id):
             subset_channel_index = gmm.data.registered_reassign_channel_index
         nu = len(new_units)
         divergences = torch.full((nu, nu), torch.nan)
+        if self.verbose: print(f"{self.inherit_chans=} {zip_metric=} {subset_channel_index is None=}")
         for i, ua in ju:
             # print(f"{i=} {ua.n_chans_unit=} {ua.channels.tolist()=}")
             for j, ub in ju:
@@ -595,7 +626,66 @@ def draw(self, panel, gmm, unit_id):
         ):
             tx.set_color(glasbey1024[i])
             ty.set_color(glasbey1024[i])
-        axis.set_title(gmm.merge_metric)
+        axis.set_title(f"{kind} {zip_threshold}")
+
+        dbads, bimodalities = gmm.bimodalities_prefit(
+            new_units,
+            which_spikes=in_unit,
+            return_dbad=True,
+            common_chans=self.common_chans,
+            impute_missing=self.impute_before_center,
+        )
+        f0 = below
+        # f0, f1 = below.subfigures(ncols=2, width_ratios=(2, 1))
+        # ax1 = f1.subplots()
+        # im = ax1.imshow(
+        #     bimodalities,
+        #     vmin=0,
+        #     vmax=self.dist_vmax,
+        #     cmap=self.cmap,
+        #     origin="lower",
+        #     interpolation="none",
+        # )
+        # if self.show_values:
+        #     for (j, i), label in np.ndenumerate(bimodalities):
+        #         ax1.text(
+        #             i,
+        #             j,
+        #             f"{label:.2f}".lstrip("0"),
+        #             ha="center",
+        #             va="center",
+        #             clip_on=True,
+        #             fontsize=5,
+        #         )
+        # panel.colorbar(im, ax=ax1, shrink=0.3)
+        # ax1.set_xticks(range(len(new_units)))
+        # ax1.set_yticks(range(len(new_units)))
+        # for i, (tx, ty) in enumerate(
+        #     zip(ax1.xaxis.get_ticklabels(), ax1.yaxis.get_ticklabels())
+        # ):
+        #     tx.set_color(glasbey1024[i])
+        #     ty.set_color(glasbey1024[i])
+
+        axes0 = f0.subplots(nrows=len(dbads), ncols=len(dbads), sharex=True, sharey=True)
+        for i in range(len(dbads)):
+            for j in range(len(dbads)):
+                if j >= i:
+                    axes0[i, j].set_visible(False)
+                    continue
+
+                axes0[i, j].scatter(
+                    dbads[i],
+                    dbads[j],
+                    c=np.stack([glasbey1024[i], glasbey1024[j]], axis=0)[(dbads[i] < dbads[j]).astype(int)],
+                    s=3,
+                    lw=0,
+                )
+                axes0[i, j].set_xticks([])
+                axes0[i, j].set_yticks([])
+                axes0[i, j].set_title(f"{bimodalities[i, j]:0.2f}", fontsize=5)
+
+        # ax0.hist(dbads.todense(), bins=32, density=True, histtype="step")
+
 
 
 class HDBScanSplitPlot(GMMPlot):
@@ -1577,16 +1667,22 @@ class NeighborBimodality(GMMMergePlot):
     width = 3
     height = 10
 
-    def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False):
+    def __init__(self, n_neighbors=5, badness_kind=None, do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False, max_spikes=2048, cut=None, kind="isotonic"):
         self.n_neighbors = n_neighbors
         self.badness_kind = badness_kind
         self.do_reg = do_reg
         self.masked = masked
         self.mask_radius_s = mask_radius_s
         self.impute_missing = impute_missing
+        self.max_spikes = max_spikes
+        self.cut = cut
+        self.kind = kind
 
     def draw(self, panel, gmm, unit_id):
         from isosplit import isocut, dipscore_at
+        kind = self.badness_kind
+        if kind is None:
+            kind = gmm.reassign_metric
         (in_self,) = torch.nonzero(gmm.labels == unit_id, as_tuple=True)
         neighbors = self.get_neighbors(gmm, unit_id)
         # remove self
@@ -1618,6 +1714,15 @@ def draw(self, panel, gmm, 
unit_id):
         else:
             in_self_local = in_self.numpy(force=True)
             inu_local = inu.numpy(force=True)
+
+            rg = np.random.default_rng(0)
+            if inu_local.size > self.max_spikes:
+                inu_local = rg.choice(inu_local, size=self.max_spikes, replace=False)
+                inu_local.sort()
+            if in_self_local.size > self.max_spikes:
+                in_self_local = rg.choice(in_self_local, size=self.max_spikes, replace=False)
+                in_self_local.sort()
+
             nu = inu_local.size
             ns = in_self_local.size
 
@@ -1641,13 +1746,14 @@ def draw(self, panel, gmm, unit_id):
                 which_spikes=torch.from_numpy(which).to(gmm.labels),
                 unit_ids=[unit_id, u],
                 show_progress=False,
-                kind=self.badness_kind,
+                kind=kind,
                 impute_missing=self.impute_missing,
             )
             a = np.full(badness.shape, np.inf)
             a[badness.coords] = badness.data
-            a = np.nan_to_num(a, nan=1.0, posinf=1.0, copy=False)
+            a = np.nan_to_num(a, nan=gmm.match_threshold, posinf=gmm.match_threshold, copy=False)
             self_badness, u_badness = a
+            pear = pearsonr(self_badness, u_badness)
 
             row[0].scatter(
                 u_badness,
@@ -1657,30 +1763,32 @@ def draw(self, panel, gmm, unit_id):
                 lw=0,
                 alpha=0.5,
             )
-            row[0].set_title(f"{ns=} {nu=}", fontsize=6)
+            row[0].set_title(f"{ns=} {nu=} 1-rho={1.0 - pear.statistic:0.3f} p={pear.pvalue:0.2g}", fontsize=6)
             row[0].set_xlabel(f"{u}: {self.badness_kind}")
             row[0].set_ylabel(f"{unit_id}: {self.badness_kind}")
 
             # closer to me = self_badness < u_badness = u_badness - self_badness > 0 = to the right
             dbad = u_badness - self_badness
             unique_dbad, inverse, counts = np.unique(dbad, return_counts=True, return_inverse=True)
-            weights = counts.copy().astype(float)
+            # weights = counts.copy().astype(float)
             weights = np.zeros(counts.shape)
             np.add.at(weights, inverse, sample_weights)
 
-            row[1].axvline(0, lw=1, color="k")
-            n, bins, patches = row[1].hist(dbad, bins=64, histtype="step", color="b", density=True)
-            row[1].hist(unique_dbad, bins=bins, histtype="step", color="b", linestyle=":", density=True)
-            row[1].hist(unique_dbad, weights=weights, bins=bins, histtype="step", color="r", density=True)
+            # editor's note: the removed hist call used to define the bin edges,
+            # so compute the same 64 bins over dbad explicitly
+            bins = np.histogram_bin_edges(dbad, bins=64)
+            hist, _, _ = row[1].hist(unique_dbad, bins=bins, histtype="step", color="b", linestyle=":", density=True)
+            histw, _, _ = row[1].hist(unique_dbad, weights=weights, bins=bins, histtype="step", color="r", density=True)
+            bc = 0.5 * (bins[1:] + bins[:-1])
 
-            ds_ud, x, m, m_ud = smoothed_dipscore_at(0, unique_dbad, sample_weights=counts.astype(float))
+            ds_ud, x, m, m_ud, cut_ud = smoothed_dipscore_at(self.cut, unique_dbad, sample_weights=counts.astype(float), kind=self.kind)
             row[1].plot(x, m, color="g", lw=1)
             row[1].plot(x, m_ud, color="g", lw=1, ls=(0, (0.5, 0.5)))
-            ds_udw, xw, mw, m_udw = smoothed_dipscore_at(0, unique_dbad, sample_weights=weights)
+            row[1].axvline(cut_ud, lw=0.8, color="g")
+            ds_udw, xw, mw, m_udw, cut_udw = smoothed_dipscore_at(self.cut, unique_dbad, sample_weights=weights, kind=self.kind)
             row[1].plot(xw, mw, color="orange", lw=1)
             row[1].plot(xw, m_udw, color="orange", lw=1, ls=(0, (0.5, 0.5)))
+            row[1].axvline(cut_udw, lw=0.8, color="orange")
+            dbin = np.diff(bc).mean()
 
             ds_ud = f"{ds_ud:0.3f}".lstrip("0").rstrip("0")
             ds_udw = f"{ds_udw:0.3f}".lstrip("0").rstrip("0")
             mstr = "masked " if self.masked else ""
@@ -1697,7 +1805,7 @@ def draw(self, panel, gmm, unit_id):
                     ci=None,
                 )
 
-        
+
 class NearbyDivergencesMatrix(GMMMergePlot):
     kind = "amatrix"
     width = 2
     height = 2
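(Editor's sketch: NeighborBimodality.draw above weights the dbad histogram by reassignment sample weights accumulated onto the unique dbad values. Here is the np.unique/np.add.at pattern it relies on, run standalone with invented numbers:)

    import numpy as np

    dbad = np.array([0.2, 0.2, -0.1, 0.4, -0.1])
    sample_weights = np.array([1.0, 0.5, 2.0, 1.0, 1.0])

    unique_dbad, inverse, counts = np.unique(dbad, return_counts=True, return_inverse=True)
    weights = np.zeros(counts.shape)
    # scatter-add each spike's weight onto the bin of its unique value
    np.add.at(weights, inverse, sample_weights)
    # unique_dbad == [-0.1, 0.2, 0.4]; weights == [3.0, 1.5, 1.0]
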
@@ -1723,27 +1831,16 @@ def draw(self, panel, gmm, unit_id):
         kind = self.badness_kind
         if kind is None:
             kind = gmm.merge_metric
-        min_overlap = gmm.min_overlap
-        subset_channel_index = None
-        if self.merge_on_waveform_radius:
-            subset_channel_index = gmm.data.registered_reassign_channel_index
         neighbors = self.get_neighbors(gmm, unit_id)
-
-        nu = len(neighbors)
-        ju = list(enumerate(gmm[u] for u in neighbors))
-        divergences = torch.full((nu, nu), torch.nan)
-        for i, ua in ju:
-            for j, ub in ju:
-                if i == j:
-                    divergences[i, j] = 0
-                    continue
-                divergences[i, j] = ua.divergence(
-                    ub,
-                    kind=kind,
-                    min_overlap=min_overlap,
-                    subset_channel_index=subset_channel_index,
-                )
+        nu = neighbors.size
+        divergences = gmm.central_divergences(
+            units_a=torch.from_numpy(neighbors),
+            units_b=torch.from_numpy(neighbors),
+            kind=kind,
+            allow_chan_subset=self.merge_on_waveform_radius,
+        )
         dists = divergences.numpy(force=True)
 
         axis = panel.subplots()
@@ -1912,7 +2009,7 @@ def draw(self, panel, gmm, unit_id):
     # HDBScanSplitPlot(spike_kind="residual_full"),
     # HDBScanSplitPlot(),
     # ZipperSplitPlot(),
-    KMeansPPSPlitPlot(inherit_chans=True, n_clust=10, impute_before_center=True),
+    KMeansPPSPlitPlot(inherit_chans=True, impute_before_center=True),
     GridMeansSingleChanPlot(),
     InputWaveformsSingleChanPlot(),
     # InputWaveformsSingleChanOverTimePlot(channel="unit"),
@@ -1938,7 +2035,9 @@ def draw(self, panel, gmm, unit_id):
     ViolatorTimesVAmps(),
     NearbyTimesVAmps(),
     NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="diagz"),
+    NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="diagz"),
     NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="1-r^2"),
+    NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="cos"),
     # NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="1-scaledr^2"),
     # NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="diagz"),
     # NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="1-scaledr^2"),
@@ -1947,8 +2046,8 @@ def draw(self, panel, gmm, unit_id):
     ISICorner(bin_ms=0.25),
     ISICorner(bin_ms=0.5, max_ms=8, tick_step=2),
     # NeighborBimodality(),
-    NeighborBimodality(badness_kind="1-r^2", masked=True),
-    NeighborBimodality(badness_kind="1-r^2", masked=True, impute_missing=True),
+    # NeighborBimodality(badness_kind=None, masked=True, impute_missing=True, kind="truncnorm"),
+    NeighborBimodality(badness_kind=None, masked=True, impute_missing=True),
     CCGColumn(),
     # NeighborBimodality(badness_kind="1-scaledr^2", masked=True),
 )
@@ -2012,11 +2111,16 @@ def make_all_gmm_summaries(
     overwrite=False,
     unit_ids=None,
     use_threads=False,
+    n_units=None,
+    seed=0,
     **other_global_params,
 ):
     save_folder = Path(save_folder)
     if unit_ids is None:
         unit_ids = gmm.unit_ids().numpy(force=True)
+    if n_units is not None and n_units < len(unit_ids):
+        rg = np.random.default_rng(seed)
+        unit_ids = rg.choice(unit_ids, size=n_units, replace=False)
 
     if not overwrite and all_summaries_done(unit_ids, save_folder, ext=image_ext):
         return
diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py
index 227d7cf4..d2648b66 100644
--- a/src/dartsort/vis/unit.py
+++ b/src/dartsort/vis/unit.py
@@ -1026,11 +1026,16 @@ def make_all_summaries(
     overwrite=False,
     unit_ids=None,
     gizmo_name="sorting_analysis",
+    n_units=None,
+    seed=0,
     **other_global_params,
 ):
     save_folder = Path(save_folder)
     if unit_ids is None:
         unit_ids = sorting_analysis.unit_ids
+    if n_units is not None and n_units < len(unit_ids):
+        rg = np.random.default_rng(seed)
+        unit_ids = rg.choice(unit_ids, size=n_units, replace=False)
 
     if not overwrite and all_summaries_done(
         unit_ids, save_folder, ext=image_ext
     ):
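(Editor's note between patches: the persist_deduplication change at the top of this series suppresses re-detection at later, lower thresholds by zeroing a (time, channel) mask over a broadcast neighborhood around each accepted spike. A minimal standalone sketch of that indexing, with invented shapes and neighborhoods:)

    import torch

    T, C = 100, 8  # padded chunk length and channel count (made up)
    detection_mask = torch.ones(T, C)

    times_samples = torch.tensor([10, 50])   # accepted spike times
    dedup_temporal_ix = torch.arange(-3, 4)  # +/- 3 sample dedup radius
    # each spike's spatial dedup neighborhood, playing the role of
    # spatial_dedup_channel_index[channels] in the patch
    chan_ix = torch.tensor([[1, 2, 3], [4, 5, 6]])

    # (n_spikes, n_times, 1) x (n_spikes, 1, n_chans) broadcast scatter of zeros
    time_ix = times_samples.unsqueeze(1) + dedup_temporal_ix.unsqueeze(0)
    detection_mask[time_ix[:, :, None], chan_ix[:, None, :]] = 0.0

    assert detection_mask[13, 3] == 0  # inside spike 0's window and neighborhood
    assert detection_mask[10, 7] == 1  # channel 7 is outside both neighborhoods
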
Thu, 19 Sep 2024 08:14:45 -0700
Subject: [PATCH 8/8] Setting stuff up for multi-channel model fitting

---
 src/dartsort/transform/transform_base.py |  16 +++
 src/dartsort/util/waveform_util.py       | 122 +++++++++++++++++++++
 2 files changed, 138 insertions(+)

diff --git a/src/dartsort/transform/transform_base.py b/src/dartsort/transform/transform_base.py
index 355e36de..a5f577ef 100644
--- a/src/dartsort/transform/transform_base.py
+++ b/src/dartsort/transform/transform_base.py
@@ -36,6 +36,20 @@ def precompute(self):
         pass
 
 
+class RegularChannelsTransformerMixin:
+
+    def needs_precompute(self):
+        return hasattr(self, "regular_channel_index")
+
+    def precompute(self):
+        # create the regular channel index here, assuming that the subclass
+        # exposes a self.radius parameter. still a stub.
+        ...
+
+
 class BaseWaveformDenoiser(BaseWaveformModule):
     is_denoiser = True
 
@@ -77,6 +91,8 @@ def spike_datasets(self):
         return (dataset,)
 
 
+
+
 class BaseWaveformAutoencoder(BaseWaveformDenoiser, BaseWaveformFeaturizer):
     pass
 
diff --git a/src/dartsort/util/waveform_util.py b/src/dartsort/util/waveform_util.py
index d4f48f6c..1e16e481 100644
--- a/src/dartsort/util/waveform_util.py
+++ b/src/dartsort/util/waveform_util.py
@@ -85,6 +85,65 @@ def fill_geom_holes(geom):
     return filled_geom, is_original
 
 
+def regularize_geom(geom, radius=0):
+    """Re-order, fill holes, and optionally expand geometry to make it 'regular'
+
+    Used in make_regular_channel_index. That docstring has some info about what's
+    going on here.
+    """
+    eps = pdist(geom).min() / 2.0
+
+    rgeom = geom.copy()
+    for j in range(geom.shape[1]):
+        # skip empty dims
+        if geom[:, j].ptp() < eps:
+            continue
+        rgeom = _regularize_1d(rgeom, radius=max(eps, radius), eps=eps, dim=j)
+
+    # order regularized geom by depth and then x
+    order = np.lexsort(rgeom.T)
+    rgeom = rgeom[order]
+
+    return rgeom, eps
+
+
+def _regularize_1d(geom, radius, eps, dim=1):
+    total = geom[:, dim].ptp()
+    dim_pitch = get_pitch(geom, direction=dim)
+    steps = int(np.ceil(total / dim_pitch))
+
+    min_pos = geom[:, dim].min() - radius
+    max_pos = geom[:, dim].max() + radius
+
+    all_positions = []
+    offset = np.zeros(geom.shape[1])
+    for step in range(-steps, steps + 1):
+        offset[dim] = step * dim_pitch
+        # add positions within the radius
+        offset_geom = geom + offset
+        keepers = offset_geom[:, dim].clip(min_pos, max_pos) == offset_geom[:, dim]
+        all_positions.append(offset_geom[keepers])
+
+    all_positions = np.concatenate(all_positions)
+    all_positions = np.unique(all_positions, axis=0)
+
+    # deal with fp tolerance
+    dists = squareform(pdist(all_positions))
+    A = dists < eps
+    n_neighbs = A.sum(0)
+    if n_neighbs.max() == 1:
+        # no near-duplicates: every position only neighbors itself
+        return all_positions
+
+    # merge positions which coincide up to eps: single-linkage cluster the
+    # positions at tolerance eps and keep one representative per cluster
+    from scipy.cluster.hierarchy import linkage, fcluster
+    Z = linkage(all_positions, method="single")
+    labels = fcluster(Z, eps, criterion="distance")
+    _, representatives = np.unique(labels, return_index=True)
+    return all_positions[representatives]
+
+
 # -- channel index creation
 
 
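(Editor's sketch: the floating-point dedup at the end of _regularize_1d merges positions that coincide up to eps. Here it is run standalone on a toy point set; the coordinates and eps are invented, and the second point is a near-duplicate of the first:)

    import numpy as np
    from scipy.cluster.hierarchy import linkage, fcluster

    pts = np.array([[0.0, 0.0], [0.0, 1e-9], [0.0, 20.0], [32.0, 0.0]])
    eps = 1.0

    # single linkage cut at distance eps groups near-duplicates together
    Z = linkage(pts, method="single")
    labels = fcluster(Z, eps, criterion="distance")
    _, representatives = np.unique(labels, return_index=True)
    deduped = pts[representatives]
    # deduped has 3 rows: one per distinct site
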
@@ -164,6 +223,69 @@ def make_filled_channel_index(geom, radius, p=2, pad_val=None, to_torch=False):
     return channel_index
 
 
+def make_regular_channel_index(geom, radius, p=2, to_torch=False):
+    """Channel index for multi-channel models
+
+    In this channel index, the layout of channels around the max channel is
+    always consistent -- but this is achieved by dummy channels. This makes
+    for a consistent layout to input into multi-channel feature learners, at
+    least relative to the detection channel. However, those learners have to
+    deal with masked out dummy channels.
+
+    Example:
+
+    Let's say the probe looks like:
+        o o
+        o o
+        o
+        o o
+    And, let's say that our channel index's radius is twice the vertical
+    spacing (and say this is the same as the horizontal spacing). Then
+    the top left channel might have (eyeballing it) 4 neighbors (excluding
+    itself), the second row left channel maybe 5, the second row right channel
+    only 4 due to the hole.
+
+    This is padded out to a dummy probe (masked chans as xs) with holes filled in:
+        x x x x x x
+        x x x x x x
+        x x o o x x
+        x x o o x x
+        x x o x x x
+        x x o o x x
+        x x x x x x
+        x x x x x x
+    Now, all of the real channels have the same number of neighbors and in fact
+    exist in the same spatial relationship with their channel neighborhood. It's
+    just that some of those neighbors are fake. But, extracting a radial channel
+    index here will lead to a consistent layout -- as long as the channels are
+    ordered right (!!).
+
+    Note that for probes which don't have a regular layout, this just doesn't
+    make sense at all. I'm thinking of ones that have weird layouts where one channel
+    is way bigger than the others and serves as a reference, etc -- throw those
+    away!
+    """
+    rgeom, eps = regularize_geom(geom, radius=radius)
+
+    # determine original geom's position in the regularized one, and which
+    # channels are fake chans (they are unmatched in the query)
+    kdt = KDTree(geom)
+    dists, reg2orig = kdt.query(rgeom, k=1, distance_upper_bound=eps)
+    # unmatched (fake) positions get index kdt.n == len(geom). append that same
+    # sentinel so that make_channel_index's pad value len(rgeom) also maps to
+    # len(geom) in the lookup below
+    reg2orig = np.concatenate([reg2orig, [len(geom)]])
+
+    # make regularized channel index...
+    rci = make_channel_index(rgeom, radius, p=p)
+
+    # subset it to non-fake chans and replace regularized channel indices with
+    # the corresponding original channel index (or len(geom) if fake)
+    real_reg = np.flatnonzero(reg2orig < kdt.n)
+    real_reg_ix = reg2orig[real_reg]
+    ordered_real_reg = real_reg[np.argsort(real_reg_ix)]
+    channel_index = reg2orig[rci[ordered_real_reg]]
+
+    if to_torch:
+        channel_index = torch.LongTensor(channel_index)
+
+    return channel_index
+
+
 def make_contiguous_channel_index(n_channels, n_neighbors=40):
     """Channel index with linear neighborhoods in channel id space"""
     channel_index = []