diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 6dba2593..289ae47b 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -685,7 +685,7 @@ def find_peaks( objective[torch.logical_not(unit_mask)] = -torch.inf # formerly used detect_and_deduplicate, but that was slow. objective_max, max_obj_template = objective.max(dim=0) - times = argrelmax(objective_max, self.spike_length_samples, self.threshold) + times = spiketorch.argrelmax(objective_max, self.spike_length_samples, self.threshold) obj_template_indices = max_obj_template[times] # remove peaks inside the padding if not times.numel(): @@ -816,39 +816,13 @@ def __post_init__(self): def convolve(self, traces, padding=0, out=None): """Convolve the objective templates with traces.""" - out_len = traces.shape[0] + 2 * padding - self.spike_length_samples + 1 - if out is None: - out = torch.empty( - (self.objective_n_templates, out_len), - dtype=traces.dtype, - device=traces.device, - ) - else: - assert out.shape == (self.objective_n_templates, out_len) - - for q in range(self.rank): - # units x time - rec_spatial = self.objective_spatial_singular[:, q, :] @ traces.T - # convolve with temporal components -- units x time - temporal = self.objective_temporal_components[:, :, q] - # temporalf = self.objective_temporalf[:, :, q] - # conv1d with groups! only convolve each unit with its own temporal filter. - conv = F.conv1d( - rec_spatial[None], - temporal[:, None, :], - groups=self.objective_n_templates, - padding=padding, - )[0] - # conv = spiketorch.depthwise_oaconv1d( - # rec_spatial, temporal, padding=padding, f2=temporalf - # ) - if q: - out += conv - else: - out.copy_(conv) - - # back to units x time (remove extra dim used for conv1d) - return out + return spiketorch.convolve_lowrank( + traces, + self.objective_spatial_singular, + self.objective_temporal_components, + padding=padding, + out=out, + ) def subtract_conv( self, @@ -1238,18 +1212,3 @@ def _grow_buffer(x, old_length, new_size): new = torch.empty(new_size, dtype=x.dtype, device=x.device) new[:old_length] = x[:old_length] return new - - -def argrelmax(x, radius, threshold, exclude_edge=True): - x1 = F.max_pool1d( - x[None, None], - kernel_size=2 * radius + 1, - padding=radius, - stride=1, - )[0, 0] - x1[x < x1] = 0 - F.threshold_(x1, threshold, 0.0) - ix = torch.nonzero(x1)[:, 0] - if exclude_edge: - return ix[(ix > 0) & (ix < x.numel() - 1)] - return ix diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index ddad49b3..26cf0c5b 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -217,9 +217,74 @@ def reduce_at_(dest, ix, src, reduce, include_self=True): ) +def argrelmax(x, radius, threshold, exclude_edge=True): + x1 = F.max_pool1d( + x[None, None], + kernel_size=2 * radius + 1, + padding=radius, + stride=1, + )[0, 0] + x1[x < x1] = 0 + F.threshold_(x1, threshold, 0.0) + ix = torch.nonzero(x1)[:, 0] + if exclude_edge: + return ix[(ix > 0) & (ix < x.numel() - 1)] + return ix + + _cdtypes = {torch.float32: torch.complex64, torch.float64: torch.complex128} +def convolve_lowrank( + traces, + spatial_singular, + temporal_components, + padding=0, + out=None, +): + """Depthwise convolution of traces with templates""" + n_templates, spike_length_samples, rank = temporal_components.shape + + out_len = traces.shape[0] + 2 * padding - spike_length_samples + 1 + if out is None: + out = torch.empty( + (n_templates, out_len), + dtype=traces.dtype, + device=traces.device, + ) + else: + assert out.shape == (n_templates, out_len) + + for q in range(rank): + # units x time + rec_spatial = spatial_singular[:, q, :] @ traces + + # convolve with temporal components -- units x time + temporal = temporal_components[:, :, q] + + # temporalf = temporalf[:, :, q] + # conv1d with groups! only convolve each unit with its own temporal filter + conv = F.conv1d( + rec_spatial[None], + temporal[:, None, :], + groups=n_templates, + padding=padding, + )[0] + + # o-a turns out not to be helpful, sadly + # conv = depthwise_oaconv1d( + # rec_spatial, temporal, padding=padding, f2=temporalf + # ) + + if q: + out += conv + else: + out.copy_(conv) + + # back to units x time (remove extra dim used for conv1d) + return out + + def real_resample(x, num, dim=0): """torch version of a special case of scipy.signal.resample