From 3effa7f7c9790941eeb5414cbc6f0369268f08f3 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 10 Oct 2024 12:52:48 -0700 Subject: [PATCH] Refactor ccg vis --- src/dartsort/vis/analysis_plots.py | 50 +++++ src/dartsort/vis/gmm.py | 344 ++++++++++++++++++++++------- src/dartsort/vis/unit.py | 47 +--- 3 files changed, 323 insertions(+), 118 deletions(-) diff --git a/src/dartsort/vis/analysis_plots.py b/src/dartsort/vis/analysis_plots.py index 9a1292b7..da6b5919 100644 --- a/src/dartsort/vis/analysis_plots.py +++ b/src/dartsort/vis/analysis_plots.py @@ -166,3 +166,53 @@ def density_peaks_study(X, density_result, dims=[0, 1], fig=None, axes=None, idx *X[missed][:, dims].T, c="gray", **scatter_kw ) return fig, axes + + +def isi_hist(times_s, axis, max_ms=5, bin_ms=0.1, color="k", label=None, histtype="bar", alpha=1.0): + dt_ms = np.diff(times_s) * 1000 + bin_edges = np.arange( + 0, + max_ms + bin_ms, + bin_ms, + ) + # counts, _ = np.histogram(dt_ms, bin_edges) + # bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1]) + # axis.bar(bin_centers, counts) + axis.hist(dt_ms, bin_edges, color=color, label=label, histtype=histtype, alpha=alpha) + axis.set_xlabel("isi (ms)") + axis.set_ylabel(f"count (out of {dt_ms.size} total isis)") + + +def correlogram(times_a, times_b=None, max_lag=50): + lags = np.arange(-max_lag, max_lag + 1) + ccg = np.zeros(len(lags), dtype=int) + + times_a = np.sort(times_a) + auto = times_b is None + if auto: + times_b = times_a + else: + times_b = np.sort(times_b) + + for i, lag in enumerate(lags): + lagged_b = times_b + lag + insertion_inds = np.searchsorted(times_a, lagged_b) + found = insertion_inds < len(times_a) + ccg[i] = np.sum(times_a[insertion_inds[found]] == lagged_b[found]) + + if auto: + ccg[lags == 0] = 0 + + return lags, ccg + + +def bar(ax, x, y, **kwargs): + dx = np.diff(x).min() + x0 = np.concatenate((x - dx, x[-1:] + dx)) + return ax.stairs(y, x0, **kwargs) + + +def plot_correlogram(axis, times_a, times_b=None, max_lag=50, color="k", fill=True, **stairs_kwargs): + lags, ccg = correlogram(times_a, times_b=times_b, max_lag=max_lag) + axis.set_xlabel("lag (samples)") + return bar(axis, lags, ccg, fill=fill, color=color, **stairs_kwargs) diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 04ba0e68..6231f8c3 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -14,6 +14,7 @@ from .colors import glasbey1024 from . import analysis_plots, layout, unit from ..util.multiprocessing_util import CloudpicklePoolExecutor, get_pool, ThreadPoolExecutor +from ..util import spikeio from .waveforms import geomplot try: @@ -124,7 +125,7 @@ def draw(self, panel, gmm, unit_id): if self.feature == "spread_amp": axes[0].set_xlabel("spread") axes[0].set_ylabel("amp") - + axes = panel_bottom.subplots(ncols=3) axes = {'d': axes[0], 'f': axes[1], 'e': axes[2]} labels = dens['labels'][inv] @@ -173,7 +174,7 @@ def draw(self, panel, gmm, unit_id): else: times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) times = torch.atleast_1d(times) - + chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device) means = unit.get_means(times).to(gmm.device) if j > 0: @@ -408,16 +409,74 @@ def draw(self, panel, gmm, unit_id): axis.set_title(gmm.merge_metric) +class MStep(GMMPlot): + kind = "single" + width = 4 + height = 5 + alpha = 0.05 + + def draw(self, panel, gmm, unit_id, unit=None, in_unit=None): + if unit is None: + unit = gmm[unit_id] + in_unit, utd = gmm.get_training_data(unit_id, in_unit=in_unit, waveform_kind="original") + times = utd['times'] + waveforms = utd['waveforms'] + waveform_channels = utd['waveform_channels'] + padded_geom = gmm.data.padded_registered_geom + + if unit.impute_before_center: + waveforms_rel = unit.impute( + times, + waveforms, + waveform_channels, + padded_registered_geom=padded_geom, + centered=False, + ) + else: + waveforms_rel = unit.to_unit_channels( + waveforms, + times, + waveform_channels=waveform_channels, + fill_mode="constant", + constant_value=torch.nan, + ) + n, r, c = waveforms_rel.shape + x = waveforms_rel.permute(0, 2, 1).reshape(n, -1).numpy(force=True) + + mean = unit.mean.reshape(r, c).T.numpy(force=True).ravel() + std = unit.std.reshape(r, c).T.numpy(force=True).ravel() + + color = glasbey1024[unit_id % len(glasbey1024)] + + ax, ay = panel.subplots(nrows=2, sharex=True) + ax.plot(x.T, color="k", alpha=self.alpha) + ax.fill_between( + np.arange(r * c), + mean - std, + mean + std, + color=color, + alpha=0.5, + lw=0, + zorder=11, + ) + ax.plot(mean, lw=1, color=color) + ay.plot(np.abs(mean), color=color, lw=1, label='fitted |mean|') + ay.plot(std, color=color, ls="--", lw=1, label='fitted std') + ay.plot(np.abs(np.nanmean(x, axis=0)), color='k', ls="--", lw=1, label='emp |mean|') + ay.plot(np.nanstd(x, axis=0), color='k', ls=":", lw=1, label='emp std') + ay.legend(loc='upper left', frameon=False, fancybox=False) + ay.set_xlabel("channel-major feature index") + + class KMeansPPSPlitPlot(GMMPlot): kind = "triplet" - width = 5 - height = 5 + width = 6 + height = 7 def __init__( self, cmap=plt.cm.rainbow, fitted_only=True, - scaled=True, amplitude_scaling_std=np.sqrt(0.001), amplitude_scaling_limit=1.2, merge_on_waveform_radius=True, @@ -426,17 +485,16 @@ def __init__( n_clust=None, n_iter=None, common_chans=False, - inherit_chans=True, - impute_before_center=True, + inherit_chans=False, + impute_before_center=False, with_proportions=False, min_overlap=0.0, zip_metric=None, + by_quantile=False, verbose=False, - equilibrium_kmeans_alpha=0.0, ): self.cmap = cmap self.fitted_only = fitted_only - self.scaled = scaled self.inv_lambda = 1.0 / (amplitude_scaling_std**2) self.scale_clip_low = 1.0 / amplitude_scaling_limit self.scale_clip_high = amplitude_scaling_limit @@ -453,11 +511,9 @@ def __init__( self.with_proportions = with_proportions self.zip_metric = zip_metric self.verbose = verbose - self.equilibrium_kmeans_alpha = equilibrium_kmeans_alpha - if self.scaled: - self.title = f"scaled {self.title}" + self.by_quantile = by_quantile - def draw(self, panel, gmm, unit_id): + def draw(self, panel, gmm, unit_id, impose_labels=None): n_clust = self.n_clust or gmm.kmeans_nclust n_iter = self.n_iter or gmm.kmeans_niter in_unit, labels, weights = gmm.kmeanspp( @@ -466,8 +522,10 @@ def draw(self, panel, gmm, unit_id): n_iter=n_iter, with_proportions=self.with_proportions, verbose=self.verbose, - equilibrium_kmeans_alpha=self.equilibrium_kmeans_alpha, ) + if impose_labels is not None: + labels = impose_labels + weights = torch.ones_like(weights) times = gmm.data.times_seconds[in_unit].numpy(force=True) amps = np.nanmax(gmm.data.amp_vecs[in_unit], 1) @@ -478,19 +536,21 @@ def draw(self, panel, gmm, unit_id): zip_metric = self.zip_metric or gmm.zip_metric zip_threshold = gmm.zip_threshold + assert in_unit.shape == labels.shape + in_unit, utd = gmm.get_training_data(unit_id, in_unit=in_unit, waveform_kind="reassign") + assert in_unit.shape == labels.shape + spike_ix, overlaps, badnesses = gmm[unit_id].spike_badnesses( + utd["times"], + utd["waveforms"], + utd["waveform_channels"], + kinds=(zip_metric, gmm.reassign_metric) + ) if gmm.zip_by_quantile: - in_unit, utd = gmm.get_training_data(unit_id, in_unit=in_unit, waveform_kind="reassign") - spike_ix, overlaps, badnesses = gmm[unit_id].spike_badnesses( - utd["times"], - utd["waveforms"], - utd["waveform_channels"], - kinds=(zip_metric,) - ) zip_threshold = torch.quantile(badnesses[zip_metric], zip_threshold) if ids.size > 1: top, bottom, below = panel.subfigures(nrows=3) - ax_top = top.subplots() + ax_top, ax_top2 = top.subplots(ncols=2, width_ratios=[3, 1]) axes = bottom.subplot_mosaic("dfe") else: ax_top = panel.subplots() @@ -505,6 +565,18 @@ def draw(self, panel, gmm, unit_id): if ids.size <= 1: return + bads = badnesses[gmm.reassign_metric].numpy(force=True) + spike_ix = spike_ix.numpy(force=True) + assert in_unit[spike_ix].shape == bads.shape == labels[spike_ix].shape + ax_top2.hist( + [bads[labels[spike_ix] == l] for l in ids], + color=glasbey1024[ids], + histtype="step", + ) + if self.verbose: + print(f"assigning self.bads {bads.min()=} {bads.mean()=} {bads.max()=}") + self.bads = bads + new_units = [] chans_kw = {} if self.inherit_chans: @@ -513,6 +585,7 @@ def draw(self, panel, gmm, unit_id): max_channel=gmm[unit_id].max_channel, ) if self.verbose: print(f"{self.impute_before_center=} {self.common_chans=}") + chans_subunit = [] for j, label in enumerate(ids): u = spike_interp.InterpUnit( do_interp=False, @@ -526,16 +599,27 @@ def draw(self, panel, gmm, unit_id): in_unit=inu, sampling_method=gmm.sampling_method, ) - u.fit_center( - **train_data, - padded_geom=gmm.data.padded_registered_geom, - show_progress=False, - weights=w, - **chans_kw, - ) - new_units.append(u) - + chans_subunit.append(train_data['waveform_channels']) + try: + u.fit_center( + **train_data, + padded_geom=gmm.data.padded_registered_geom, + show_progress=False, + weights=w, + **chans_kw, + ) + if gmm.cov_kind == "global": + if self.verbose: print(f"overwrite {u.var.min()=} {u.var.max()=} {gmm.var=}") + u.var.fill_(gmm.var) + new_units.append(u) + except ValueError: + continue ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit] + if self.verbose: + print("assigning self.ju,in_unit,split_labels") + self.ju = ju + self.in_unit = in_unit + self.split_labels = labels # plot new unit maxchan wfs and old one in black gmc = gmm[unit_id].max_channel @@ -550,7 +634,7 @@ def draw(self, panel, gmm, unit_id): else: times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) times = torch.atleast_1d(times) - + chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device) means = unit.get_means(times).to(gmm.device) if j > 0: @@ -560,7 +644,7 @@ def draw(self, panel, gmm, unit_id): means = gmm.data.tpca._inverse_transform_in_probe(means) means = means.numpy(force=True) color = glasbey1024[j] - + lines = np.stack( (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), axis=-1, @@ -628,15 +712,14 @@ def draw(self, panel, gmm, unit_id): ty.set_color(glasbey1024[i]) axis.set_title(f"{kind} {zip_threshold}") - dbads, bimodalities = gmm.bimodalities_prefit( + dbads, bimodalities, reas_labels = gmm.bimodalities_prefit( new_units, which_spikes=in_unit, - return_dbad=True, common_chans=self.common_chans, impute_missing=self.impute_before_center, ) - f0 = below - # f0, f1 = below.subfigures(ncols=2, width_ratios=(2, 1)) + # f0 = below + f0, f1 = below.subfigures(ncols=2, width_ratios=(2, 1)) # ax1 = f1.subplots() # im = ax1.imshow( # bimodalities, @@ -666,26 +749,39 @@ def draw(self, panel, gmm, unit_id): # tx.set_color(glasbey1024[i]) # ty.set_color(glasbey1024[i]) - axes0 = f0.subplots(nrows=len(dbads), ncols=len(dbads), sharex=True, sharey=True) - for i in range(len(dbads)): - for j in range(len(dbads)): - if j >= i: - axes0[i, j].set_visible(False) - continue - - axes0[i, j].scatter( - dbads[i], - dbads[j], - c=np.stack([glasbey1024[i], glasbey1024[j]], axis=0)[(dbads[i] < dbads[j]).astype(int)], - s=3, - lw=0, - ) - axes0[i, j].set_xticks([]) - axes0[i, j].set_yticks([]) - axes0[i, j].set_title(f"{bimodalities[i, j]:0.2f}", fontsize=5) - - # ax0.hist(dbads.todense(), bins=32, density=True, histtype="step") - + if len(dbads) > 1: + axes0 = f0.subplots(nrows=len(dbads) - 1, ncols=len(dbads) - 1, sharex=True, sharey=True, squeeze=False) + for i in range(1, len(dbads)): + for j in range(len(dbads) - 1): + ax = axes0[i - 1, j] + if j >= i: + ax.set_visible(False) + continue + + ax.scatter( + dbads[i], + dbads[j], + c=np.stack([glasbey1024[i], glasbey1024[j]], axis=0)[(dbads[i] < dbads[j]).astype(int)], + s=3, + lw=0, + ) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(f"{bimodalities[i, j]:0.2f}", fontsize=5) + + # ax0.hist(dbads.todense(), bins=32, density=True, histtype="step") + ax1 = f1.subplots() + # chans histogram... + mn = np.inf + mx = -np.inf + cs = [] + for j, ucc in enumerate(chans_subunit): + ucc = ucc[ucc < gmm.data.n_chans_full].numpy(force=True) + chans_subunit[j] = ucc + mn = min(ucc.min(), mn) + mx = max(ucc.max(), mx) + cs.append(glasbey1024[j]) + ax1.hist(chans_subunit, histtype="bar", bins=np.arange(mn, mx + 1), color=cs, stacked=True) class HDBScanSplitPlot(GMMPlot): @@ -740,7 +836,7 @@ def draw(self, panel, gmm, unit_id): class EmbedsOverTimePlot(GMMPlot): kind = "embeds" - width = 5 + width = 4 height = 3 def __init__(self, colors="gbr"): @@ -787,7 +883,7 @@ def draw(self, panel, gmm, unit_id): class AmplitudesOverTimePlot(GMMPlot): kind = "embeds" - width = 5 + width = 4 height = 3 def __init__(self, kinds=("recon", "maxchan_energy", 'model'), colors="bkrg"): @@ -863,8 +959,8 @@ def draw(self, panel, gmm, unit_id): class BadnessesOverTimePlot(GMMPlot): kind = "embeds" - width = 5 - height = 3 + width = 4 + height = 2 def __init__(self, colors="rgb", kinds=None): self.colors = colors @@ -910,8 +1006,8 @@ def draw(self, panel, gmm, unit_id): class FeaturesVsBadnessesPlot(GMMPlot): - kind = "triplet" - width = 5 + kind = "embeds" + width = 4 height = 2 def __init__(self, colors="rgb", kinds=("1-r^2", "1-scaledr^2")): @@ -1667,7 +1763,7 @@ class NeighborBimodality(GMMMergePlot): width = 3 height = 10 - def __init__(self, n_neighbors=5, badness_kind=None, do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False, max_spikes=2048, cut=None, kind="isotonic"): + def __init__(self, n_neighbors=5, badness_kind=None, do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False, max_spikes=2048, cut=None, kind="isotonic", verbose=False): self.n_neighbors = n_neighbors self.badness_kind = badness_kind self.do_reg = do_reg @@ -1677,12 +1773,14 @@ def __init__(self, n_neighbors=5, badness_kind=None, do_reg=False, masked=False, self.max_spikes = max_spikes self.cut = None self.kind = kind + self.verbose = verbose def draw(self, panel, gmm, unit_id): - from isosplit import isocut, dipscore_at kind = self.badness_kind if kind is None: kind = gmm.reassign_metric + if self.verbose: + print(f"{kind=} {self.masked=} {self.impute_missing=}") (in_self,) = torch.nonzero(gmm.labels == unit_id, as_tuple=True) neighbors = self.get_neighbors(gmm, unit_id) # remove self @@ -1818,7 +1916,7 @@ def __init__( merge_on_waveform_radius=True, n_neighbors=5, show_values=True, - badness_kind="1-r^2", + badness_kind=None, ): self.cmap = cmap self.dist_vmax = dist_vmax @@ -1839,7 +1937,6 @@ def draw(self, panel, gmm, unit_id): units_b=torch.from_numpy(neighbors), kind=kind, allow_chan_subset=self.merge_on_waveform_radius, - ) dists = divergences.numpy(force=True) @@ -2006,20 +2103,21 @@ def draw(self, panel, gmm, unit_id): default_gmm_plots = ( ISIHistogram(), ChansHeatmap(), + MStep(), # HDBScanSplitPlot(spike_kind="residual_full"), # HDBScanSplitPlot(), # ZipperSplitPlot(), KMeansPPSPlitPlot(inherit_chans=True, impute_before_center=True), - GridMeansSingleChanPlot(), + # GridMeansSingleChanPlot(), InputWaveformsSingleChanPlot(), # InputWaveformsSingleChanOverTimePlot(channel="unit"), - InputWaveformsSingleChanOverTimePlot(channel="natural"), - ResidualsSingleChanPlot(), + # InputWaveformsSingleChanOverTimePlot(channel="natural"), + # ResidualsSingleChanPlot(), AmplitudesOverTimePlot(), BadnessesOverTimePlot(), - EmbedsOverTimePlot(), + # EmbedsOverTimePlot(), # DPCSplitPlot(spike_kind="residual_full"), - DPCSplitPlot(spike_kind="split"), + # DPCSplitPlot(spike_kind="split"), # DPCSplitPlot(spike_kind="global"), # DPCSplitPlot(spike_kind="global", feature="spread_amp"), FeaturesVsBadnessesPlot(), @@ -2035,9 +2133,10 @@ def draw(self, panel, gmm, unit_id): ViolatorTimesVAmps(), NearbyTimesVAmps(), NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="diagz"), - NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="diagz"), - NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="1-r^2"), + # NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="diagz"), + NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="max1-r^2"), NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="cos"), + NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="l2normeucsq"), # NearbyDivergencesMatrix(merge_on_waveform_radius=True, badness_kind="1-scaledr^2"), # NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="diagz"), # NearbyDivergencesMatrix(merge_on_waveform_radius=False, badness_kind="1-scaledr^2"), @@ -2052,6 +2151,15 @@ def draw(self, panel, gmm, unit_id): # NeighborBimodality(badness_kind="1-scaledr^2", masked=True), ) +gmm_selected_plots = ( + NearbyMeansMultiChan(), + NearbyDivergencesMatrix(merge_on_waveform_radius=True), + MStep(), + ChansHeatmap(), + KMeansPPSPlitPlot(inherit_chans=True, impute_before_center=True), + AmplitudesOverTimePlot(), +) + def make_unit_gmm_summary( gmm, @@ -2262,3 +2370,91 @@ def mad(x, axis=None): x = x - np.median(x, axis=axis, keepdims=True) np.abs(x, out=x) return np.median(x, axis=axis) + + +def plot_input_waveforms( + gmm, + which, + waveform_kind="original", + max_abs_amp=None, + lw=1, + color=None, + colors=None, + show_zero=False, + subar=True, + msbar=False, + zlim="tight", + ax=None, + **more_geomplot_kwargs, +): + data = gmm.spike_data(which, waveform_kind=waveform_kind) + geom = gmm.data.registered_geom + + waveforms = data['waveforms'] + waveform_channels = data['waveform_channels'] + # times = data['times'] + + n, r, c = waveforms.shape + waveforms = waveforms.mT.reshape(n * c, r) + waveforms = gmm.data.tpca._inverse_transform_in_probe(waveforms) + waveforms = waveforms.view(n, c, -1).mT + + geomplot( + waveforms.numpy(force=True), + channels=waveform_channels.numpy(force=True), + geom=geom.numpy(force=True), + max_abs_amp=max_abs_amp, + lw=lw, + color=color, + colors=colors, + show_zero=show_zero, + subar=subar, + msbar=msbar, + zlim=zlim, + ax=ax, + **more_geomplot_kwargs, + ) + + +def plot_raw_waveforms( + gmm, + original_sorting, + rec, + which, + waveform_kind="original", + max_abs_amp=None, + lw=1, + color=None, + colors=None, + show_zero=False, + subar=True, + msbar=False, + zlim="tight", + ax=None, + **more_geomplot_kwargs, +): + main_channels = original_sorting.channels[gmm.data.keepers[which]] + + waveforms = spikeio.read_waveforms_channel_index( + rec, + original_sorting.times_samples[gmm.data.keepers[which]], + main_channels=main_channels, + channel_index=gmm.data.original_channel_index, + ) + + geomplot( + waveforms, + max_channels=main_channels, + channel_index=gmm.data.original_channel_index, + geom=rec.get_channel_locations(), + max_abs_amp=max_abs_amp, + lw=lw, + color=color, + colors=colors, + show_zero=show_zero, + subar=subar, + msbar=msbar, + zlim=zlim, + ax=ax, + **more_geomplot_kwargs, + ) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index d2648b66..727b21d4 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -22,6 +22,7 @@ from ..util.analysis import DARTsortAnalysis from ..util.multiprocessing_util import CloudpicklePoolExecutor, get_pool from . import layout +from .analysis_plots import isi_hist, correlogram, plot_correlogram, bar; from .colors import glasbey1024 from .waveforms import geomplot @@ -87,9 +88,7 @@ def draw(self, panel, sorting_analysis, unit_id): times_samples = sorting_analysis.times_samples( which=sorting_analysis.in_unit(unit_id) ) - lags, acg = correlogram(times_samples, max_lag=self.max_lag) - bar(axis, lags, acg, fill=True, color="k") - axis.set_xlabel("lag (samples)") + plot_correlogram(axis, times_samples, max_lag=self.max_lag) axis.set_ylabel("acg") @@ -108,18 +107,7 @@ def draw(self, panel, sorting_analysis, unit_id, axis=None, color="k", label=Non times_s = sorting_analysis.times_seconds( which=sorting_analysis.in_unit(unit_id) ) - dt_ms = np.diff(times_s) * 1000 - bin_edges = np.arange( - 0, - self.max_ms + self.bin_ms, - self.bin_ms, - ) - # counts, _ = np.histogram(dt_ms, bin_edges) - # bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1]) - # axis.bar(bin_centers, counts) - plt.hist(dt_ms, bin_edges, color=color, label=label) - axis.set_xlabel("isi (ms)") - axis.set_ylabel(f"count (out of {dt_ms.size} total isis)") + isi_hist(times_s, axis, bin_ms=self.bin_ms, max_ms=self.max_ms, color=color, label=label) class XZScatter(UnitPlot): @@ -1087,29 +1075,6 @@ def make_all_summaries( # -- utilities -def correlogram(times_a, times_b=None, max_lag=50): - lags = np.arange(-max_lag, max_lag + 1) - ccg = np.zeros(len(lags), dtype=int) - - times_a = np.sort(times_a) - auto = times_b is None - if auto: - times_b = times_a - else: - times_b = np.sort(times_b) - - for i, lag in enumerate(lags): - lagged_b = times_b + lag - insertion_inds = np.searchsorted(times_a, lagged_b) - found = insertion_inds < len(times_a) - ccg[i] = np.sum(times_a[insertion_inds[found]] == lagged_b[found]) - - if auto: - ccg[lags == 0] = 0 - - return lags, ccg - - def trim_waveforms(waveforms, old_offset=42, new_offset=42, new_length=121): if waveforms.shape[1] == new_length and old_offset == new_offset: return waveforms @@ -1202,9 +1167,3 @@ def _summary_job(unit_id): fig.savefig(tmp_out, dpi=_summary_job_context.dpi) tmp_out.rename(final_out) plt.close(fig) - - -def bar(ax, x, y, **kwargs): - dx = np.diff(x).min() - x0 = np.concatenate((x - dx, x[-1:] + dx)) - ax.stairs(y, x0, **kwargs)