
Commit

Merge branch 'main' of github.com:cwindolf/spike-psvae
cwindolf committed Sep 19, 2024
2 parents f0883b7 + 295d51d commit 6602361
Showing 16 changed files with 856 additions and 213 deletions.
2 changes: 1 addition & 1 deletion src/dartsort/__init__.py
@@ -12,6 +12,6 @@
from .util.analysis import DARTsortAnalysis
from .util.data_util import DARTsortSorting
from .util.waveform_util import make_channel_index
from .cluster import merge
from .cluster import merge, postprocess

__version__ = importlib.metadata.version("dartsort")
2 changes: 1 addition & 1 deletion src/dartsort/cluster/initial.py
@@ -60,7 +60,7 @@ def cluster_chunk(
amps = getattr(sorting, amplitudes_dataset_name)

if recording is None:
with h5py.File(sorting.parent_h5_path, "r") as h5:
with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5:
geom = h5["geom"][:]
else:
geom = recording.get_channel_locations()
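
The `locking=False` flag opens the HDF5 file read-only without taking the HDF5 file lock, which avoids lock contention when several processes read the same file or when it sits on a networked filesystem. A minimal sketch of the same read pattern, assuming a recent h5py (the `locking` keyword needs h5py 3.5 or newer); the file path is a placeholder:

import h5py

# Read-only access with HDF5 file locking disabled; the "geom" dataset name
# matches the one read in cluster_chunk above.
with h5py.File("subtraction.h5", "r", locking=False) as h5:
    geom = h5["geom"][:]  # probe channel locations, loaded fully into memory
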
80 changes: 76 additions & 4 deletions src/dartsort/cluster/modes.py
@@ -1,4 +1,6 @@
import numpy as np
from scipy.signal import find_peaks
from scipy.stats import norm
from . import density
# todo: replace all isosplit stuff with things based on scipy's isotonic regression.

@@ -35,6 +37,44 @@ def fit_unimodal_right(x, f, weights=None, cut=0, hard=False):
return out


def fit_truncnorm_right(x, f, weights=None, cut=0, hard=False, n_iter=10):
"""Like above, but fits truncated normal to xs > cut with MoM."""
if weights is None:
weights = np.ones_like(f)

# figure out where cut lands and what f(cut) should be
cuti_left = np.searchsorted(x, cut)
if cuti_left >= len(x) - 2:
# everything is to the left... return uniform
return np.full_like(f, 1/len(f))
cuti_right = np.searchsorted(x, cut, side="right") - 1
if cuti_right <= 1:
# everything is to the right... fit normal!
mean = np.average(x, weights=weights)
var = np.average((x - mean) ** 2, weights=weights)
return norm.pdf(x, loc=mean, scale=np.sqrt(var))
# else, in the middle somewhere...
assert cuti_left in (cuti_right, cuti_right + 1)

xcut = x[cuti_right:]
mean = np.average(xcut, weights=weights[cuti_right:])
var = np.average((xcut - mean) ** 2, weights=weights[cuti_right:])
mu = mean
sigma = std = np.sqrt(var)
for i in range(n_iter):
alpha = (cut - mu) / sigma
phi_alpha = norm.pdf(alpha)
Z = 1.0 - norm.cdf(alpha)
# we have: mean ~ mu + sigma*phi(alpha)/Z
# and var ~ sigma^2(1+alpha phi(alpha) /Z - (phi(alpha)/Z)^2)
# implies mu = mean - sigma*phi(alpha)/Z
# sigma^2 = var/(1+alpha phi(alpha) /Z - (phi(alpha)/Z)^2)
sigma = np.sqrt(var / (1 + alpha * phi_alpha / Z - (phi_alpha / Z) ** 2))
mu = mean - sigma * phi_alpha / Z

return norm.pdf(x, loc=mu, scale=sigma)
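
For context, fit_truncnorm_right estimates the mean and scale of a normal that is only observed above `cut` by iterating the truncated-normal moment equations stated in the comments above. A standalone sketch of that method-of-moments fixed point, written out with the standard normal pdf and survival function; the helper name and the synthetic check are illustrative and not part of this commit:

import numpy as np
from scipy.stats import norm

def truncnorm_moment_fit(samples, cut, n_iter=100):
    # Moment equations for X ~ N(mu, sigma^2) observed only on [cut, inf):
    #   E[X]   = mu + sigma * lam,                    lam = pdf(alpha) / (1 - cdf(alpha))
    #   Var[X] = sigma^2 * (1 + alpha * lam - lam^2), alpha = (cut - mu) / sigma
    # Solve for (mu, sigma) from the truncated sample mean/variance by fixed point.
    m, v = samples.mean(), samples.var()
    mu, sigma = m, np.sqrt(v)
    for _ in range(n_iter):
        alpha = (cut - mu) / sigma
        lam = norm.pdf(alpha) / norm.sf(alpha)
        sigma = np.sqrt(v / (1.0 + alpha * lam - lam**2))
        mu = m - sigma * lam
    return mu, sigma

rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=2.0, size=200_000)
x = x[x > 0.5]                                # keep only samples above the cut
print(truncnorm_moment_fit(x, cut=0.5))       # converges near (1.0, 2.0)
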


def fit_bimodal_at(x, f, weights=None, cut=0):
from isosplit import up_down_isotonic_regression
if weights is None:
@@ -49,7 +89,16 @@ def fit_bimodal_at(x, f, weights=None, cut=0):
return out


def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", dipscore_only=False, score_kind="tv"):
def smoothed_dipscore_at(
cut,
samples,
sample_weights,
alternative="smoothed",
dipscore_only=False,
score_kind="tv",
cut_relmax_order=3,
kind="isotonic",
):
if sample_weights is None:
sample_weights = np.ones_like(samples)
densities = density.get_smoothed_densities(
@@ -63,9 +112,27 @@ def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", d
)
spacings = np.diff(samples)
spacings = np.concatenate((spacings[:1], 0.5 * (spacings[1:] + spacings[:-1]), spacings[-1:]))
densities /= (densities * spacings).sum()

if cut is None:
# closest maxes left + right of 0
maxers, _ = find_peaks(densities, distance=cut_relmax_order)
coords = samples[maxers]
left_cut = right_cut = 0
if (coords < 0).any():
left_cut = coords[coords < 0].max()
if (coords > 0).any():
right_cut = coords[coords > 0].min()
candidates = np.logical_and(samples >= left_cut, samples <= right_cut)
cut = 0
if candidates.any():
cut = samples[candidates][np.argmin(densities[candidates])]

if alternative == "bimodal":
densities = fit_bimodal_at(samples, densities, weights=sample_weights, cut=cut)
densities /= (densities * spacings).sum()
densities /= (densities * spacings).sum()
else:
assert alternative == "smoothed"

score = 0
best_dens_err = np.inf
@@ -83,7 +150,12 @@ def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", d
s = np.ascontiguousarray(samples[order])
d = np.ascontiguousarray(densities[order])
w = np.ascontiguousarray(sample_weights[order])
dens = fit_unimodal_right(sign * s, d, weights=w, hard=hard)
if kind == "isotonic":
dens = fit_unimodal_right(sign * s, d, weights=w, cut=sign * cut, hard=hard)
elif kind == "truncnorm":
dens = fit_truncnorm_right(sign * s, d, weights=w, cut=sign * cut, hard=hard)
else:
assert False
dens = dens[order]
dens /= (dens * spacings).sum()
dens_err = (np.abs(dens - densities) * spacings).sum()
@@ -100,4 +172,4 @@ def smoothed_dipscore_at(cut, samples, sample_weights, alternative="smoothed", d
if dipscore_only:
return score

return score, samples, densities, best_uni
return score, samples, densities, best_uni, cut
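
Two things are new in smoothed_dipscore_at: the unimodal alternative can be fit either by isotonic regression (kind="isotonic") or by the truncated-normal fit above (kind="truncnorm"), and when `cut is None` the cut point is now chosen automatically as the density minimum between the peaks closest to zero on either side. A small standalone sketch of that valley-finding step, mirroring the find_peaks logic above; the helper name and toy data are illustrative:

import numpy as np
from scipy.signal import find_peaks

def valley_between_flanking_peaks(samples, densities, relmax_order=3):
    # `samples` are sorted 1D positions, `densities` the smoothed density at each one.
    peaks, _ = find_peaks(densities, distance=relmax_order)
    coords = samples[peaks]
    left_cut = right_cut = 0.0
    if (coords < 0).any():
        left_cut = coords[coords < 0].max()    # nearest peak to the left of zero
    if (coords > 0).any():
        right_cut = coords[coords > 0].min()   # nearest peak to the right of zero
    candidates = (samples >= left_cut) & (samples <= right_cut)
    if not candidates.any():
        return 0.0
    return samples[candidates][np.argmin(densities[candidates])]

# toy bimodal density with modes near -2 and +2, so the valley lands at 0
x = np.linspace(-4, 4, 401)
dens = np.exp(-0.5 * ((x + 2) / 0.5) ** 2) + np.exp(-0.5 * ((x - 2) / 0.5) ** 2)
print(valley_between_flanking_peaks(x, dens))
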
89 changes: 89 additions & 0 deletions src/dartsort/cluster/postprocess.py
@@ -1,6 +1,8 @@
from dataclasses import replace

import numpy as np
import torch
import torch.nn.functional as F

from .. import config
from ..templates import TemplateData
@@ -69,3 +71,90 @@ def realign_and_chuck_noisy_template_units(
)

return new_sorting, new_template_data


def template_collision_scores(
recording,
template_data,
svd_compression_rank=20,
temporal_upsampling_factor=8,
min_channel_amplitude=0.0,
amplitude_scaling_variance=0.0,
amplitude_scaling_boundary=0.5,
trough_offset_samples=42,
device=None,
max_n_colliding=5,
threshold=None,
save_folder=None,
):
from ..peel.matching import ObjectiveUpdateTemplateMatchingPeeler

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
if threshold is None or threshold < 1:
factor = 0.9
if threshold is not None and threshold < 1:
factor = threshold
threshold = factor * np.square(template_data.templates).sum((1, 2)).min()
print(f"Using {threshold=:0.3f} for decollision")

matcher = ObjectiveUpdateTemplateMatchingPeeler(
recording,
template_data,
channel_index=None,
featurization_pipeline=None,
motion_est=None,
svd_compression_rank=svd_compression_rank,
temporal_upsampling_factor=temporal_upsampling_factor,
min_channel_amplitude=min_channel_amplitude,
refractory_radius_frames=10,
amplitude_scaling_variance=amplitude_scaling_variance,
amplitude_scaling_boundary=amplitude_scaling_boundary,
conv_ignore_threshold=0.0,
coarse_approx_error_threshold=0.0,
trough_offset_samples=template_data.trough_offset_samples,
threshold=threshold,
# chunk_length_samples=30_000,
# n_chunks_fit=40,
# max_waveforms_fit=50_000,
# n_waveforms_fit=20_000,
# fit_subsampling_random_state=0,
# fit_sampling="random",
max_iter=max_n_colliding,
dtype=torch.float,
)
matcher.to(device)
save_folder.mkdir(exist_ok=True)
matcher.precompute_peeling_data(save_folder)
matcher.to(device)

n = len(template_data.templates)
scores = np.zeros(n)
matches = []
unit_mask = torch.arange(n, device=device)
for j in range(n):
mask = template_data.spike_counts_by_channel[j] > 0
template = template_data.templates[j][:, mask]

mask = torch.from_numpy(mask)
compressed_template_data = matcher.templates_at_time(0.0, spatial_mask=mask)
traces = F.pad(
torch.from_numpy(template).to(device),
(0, 0, *(2 * [template_data.spike_length_samples])),
)
res = matcher.match_chunk(
traces,
compressed_template_data,
trough_offset_samples=42,
left_margin=0,
right_margin=0,
threshold=30,
return_collisioncleaned_waveforms=False,
return_residual=True,
unit_mask=unit_mask != j,
)
resid = res["residual"][template_data.spike_length_samples:2*template_data.spike_length_samples]
scores[j] = resid.square().sum() / traces.square().sum()
matches.append(res["labels"].numpy(force=True))
return scores, matches
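
template_collision_scores matches each unit's template against all the other templates (the unit itself is excluded through unit_mask) and records how much of the template's energy survives in the residual. Scores near 0 therefore flag templates that are almost entirely explained by other units, i.e. likely collision artifacts. A hedged usage sketch; `recording` and `template_data` are assumed to exist already, and the scratch folder and the 0.25 threshold are arbitrary illustrations:

from pathlib import Path

import numpy as np

from dartsort.cluster.postprocess import template_collision_scores

# recording: the preprocessed recording the templates were computed from
# template_data: a dartsort TemplateData object
scores, matches = template_collision_scores(
    recording,
    template_data,
    save_folder=Path("collision_scratch"),   # matcher precomputation is cached here
)

# residual-energy ratio per unit: small means the unit is well explained
# by combinations of the other templates
likely_collisions = np.flatnonzero(scores < 0.25)
print(f"{likely_collisions.size} of {scores.size} units look like collisions")
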
24 changes: 17 additions & 7 deletions src/dartsort/detect/detect.py
@@ -12,6 +12,7 @@ def detect_and_deduplicate(
spatial_dedup_batch_size=512,
exclude_edges=True,
return_energies=False,
detection_mask=None,
):
"""Detect and deduplicate peaks
@@ -34,6 +35,9 @@
dedup_temporal_radius : int
Only the largest peak within this sliding radius
will be kept
detection_mask : tensor
If supplied, this floating tensor of 1s or 0s will act
as a gate to suppress detections in the 0s
Returns
-------
@@ -80,6 +84,10 @@
# remove peaks smaller than our threshold
F.threshold_(energies, threshold, 0.0)

# optionally remove censored peaks
if detection_mask is not None:
energies.mul_(detection_mask)

# -- temporal deduplication
if dedup_temporal_radius > 0:
max_energies = F.max_pool2d(
@@ -97,23 +105,25 @@
max_energies = max_energies[0, 0]

# -- spatial deduplication
# we would like to max pool again on the other axis,
# but that doesn't support any old radial neighborhood
# this is max pooling within the channel index's neighborhood's
if all_dedup:
max_energies[:] = max_energies.max(dim=1, keepdim=True).values
max_energies = max_energies.max(dim=1, keepdim=True).values
elif dedup_channel_index is not None:
# pad channel axis with extra chan of 0s
max_energies = F.pad(max_energies, (0, 1))
for batch_start in range(0, nsamples, spatial_dedup_batch_size):
batch_end = batch_start + spatial_dedup_batch_size
max_energies[batch_start:batch_end, :nchans] = torch.max(
max_energies[batch_start:batch_end, dedup_channel_index], dim=2
).values
torch.amax(
max_energies[batch_start:batch_end, dedup_channel_index],
dim=2,
out=max_energies[batch_start:batch_end, :nchans],
)
max_energies = max_energies[:, :nchans]

# if temporal/spatial max made you grow, you were not a peak!
if (dedup_temporal_radius > 0) or (dedup_channel_index is not None):
max_energies[max_energies > energies] = 0.0
# max_energies[max_energies > energies] = 0.0
max_energies.masked_fill_(max_energies > energies, 0.0)

# sparsify and return
if exclude_edges:
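
The new detection_mask argument gives callers a way to veto detections before deduplication, for instance during a stimulation artifact or on channels known to be dead. Because it is multiplied into the thresholded energies, it should be a float tensor of 1s and 0s matching the energies tensor, which the surrounding code indexes as time by channels. A small sketch of building such a gate; the sizes, the artifact window, and the dead-channel list are assumptions for illustration:

import torch

n_samples, n_channels = 30_000, 384

# 1 = detections allowed, 0 = suppressed
detection_mask = torch.ones(n_samples, n_channels)
detection_mask[10_000:10_500, :] = 0.0     # hypothetical artifact window, in samples
detection_mask[:, [7, 191]] = 0.0          # hypothetical dead channels

# then pass detection_mask=detection_mask to detect_and_deduplicate along with
# its usual traces / threshold / dedup arguments
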
26 changes: 19 additions & 7 deletions src/dartsort/peel/matching.py
@@ -447,7 +447,7 @@ def peel_chunk(

return match_results

def templates_at_time(self, t_s):
def templates_at_time(self, t_s, spatial_mask=None):
"""Handle drift -- grab the right spatial neighborhoods."""
pconvdb = self.pairwise_conv_db
pitch_shifts_a = pitch_shifts_b = None
@@ -457,6 +457,7 @@ def templates_at_time(self, t_s):
):
pconvdb.to(self.objective_spatial_components.device)
if self.is_drifting:
assert spatial_mask is None
pitch_shifts_b, cur_spatial = template_util.templates_at_time(
t_s,
self.spatial_components,
@@ -505,6 +506,9 @@ else:
else:
cur_spatial = self.spatial_components
cur_obj_spatial = self.objective_spatial_components
if spatial_mask is not None:
cur_spatial = cur_spatial[:, :, spatial_mask]
cur_obj_spatial = cur_obj_spatial[:, :, spatial_mask]
max_channels = self.registered_template_ampvecs.argmax(1)

# if not pconvdb._is_torch:
@@ -542,8 +546,10 @@ def match_chunk(
left_margin=0,
right_margin=0,
threshold=30,
return_collisioncleaned_waveforms=True,
return_residual=False,
return_conv=False,
unit_mask=None,
):
"""Core peeling routine for subtraction"""
# initialize residual, it needs to be padded to support our channel
@@ -588,6 +594,7 @@
padded_objective,
refrac_mask,
compressed_template_data,
unit_mask=unit_mask,
)
if new_peaks is None:
break
@@ -627,12 +634,14 @@
peaks.subset(*torch.nonzero(valid, as_tuple=True), sort=True)

# extract collision-cleaned waveforms on small neighborhoods
channels, waveforms = compressed_template_data.get_collisioncleaned_waveforms(
residual_padded,
peaks,
self.channel_index,
spike_length_samples=self.spike_length_samples,
)
channels = waveforms = None
if return_collisioncleaned_waveforms:
channels, waveforms = compressed_template_data.get_collisioncleaned_waveforms(
residual_padded,
peaks,
self.channel_index,
spike_length_samples=self.spike_length_samples,
)

res = dict(
n_spikes=peaks.n_spikes,
@@ -658,6 +667,7 @@ def find_peaks(
padded_objective,
refrac_mask,
compressed_template_data,
unit_mask=None,
):
# update the coarse objective
torch.add(
@@ -671,6 +681,8 @@
objective = (padded_objective + refrac_mask)[
:-1, self.obj_pad_len : -self.obj_pad_len
]
if unit_mask is not None:
objective[torch.logical_not(unit_mask)] = -torch.inf
# formerly used detect_and_deduplicate, but that was slow.
objective_max, max_obj_template = objective.max(dim=0)
times = argrelmax(objective_max, self.spike_length_samples, self.threshold)
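
The unit_mask threading through match_chunk and find_peaks is what template_collision_scores relies on, together with the new spatial_mask option in templates_at_time: units whose mask entry is False have their rows of the objective set to -inf before the peak search, so they can never be selected, while matching proceeds normally for everything else. A toy sketch of just that masking step; the tensor sizes and the excluded unit are illustrative:

import torch

n_units, n_times = 5, 1_000
objective = torch.randn(n_units, n_times)

# True = unit may be matched, False = unit is excluded from the objective.
# Excluding unit 3 mirrors how each unit is scored against all the others.
unit_mask = torch.arange(n_units) != 3
objective[torch.logical_not(unit_mask)] = -torch.inf

objective_max, best_unit = objective.max(dim=0)
assert not (best_unit == 3).any()   # the masked-out unit is never chosen
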
(The remaining 10 changed files are not shown.)
