Merge commit 64721b4

cwindolf committed Nov 21, 2024
2 parents 3a8cd1e + 2f5358d
Showing 8 changed files with 555 additions and 211 deletions.
474 changes: 343 additions & 131 deletions src/dartsort/cluster/gaussian_mixture.py

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions src/dartsort/cluster/stable_features.py
@@ -442,9 +442,7 @@ def subset_neighborhoods(self, channels, min_coverage=1.0, add_to_overlaps=None)
             for j in covered_ids
         }
         n_spikes = self.popcounts[covered_ids].sum()
-        if add_to_overlaps is not None:
-            add_to_overlaps[covered_ids] += 1
-        return neighborhood_info, n_spikes
+        return covered_ids, neighborhood_info, n_spikes

def spike_neighborhoods(self, channels, spike_indices, min_coverage=1.0):
"""Like subset_neighborhoods, but for an already chosen collection of spikes
2 changes: 1 addition & 1 deletion src/dartsort/templates/pairwise_util.py
@@ -386,7 +386,7 @@ def conv_to_resid(
     svs_a = low_rank_templates_a.singular_values[template_indices_a]
     svs_b = low_rank_templates_b.singular_values[template_indices_b]
     active_a = torch.any(spatial_a > 0, dim=1).to(svs_a)
-    if ignore_empty_channels:
+    if ignore_empty_channels and low_rank_templates_b.spike_counts_by_channel is not None:
         active_b = low_rank_templates_b.spike_counts_by_channel[template_indices_b]
         active_b = active_b > 0
         active_b = torch.from_numpy(active_b).to(svs_a)
101 changes: 84 additions & 17 deletions src/dartsort/util/hybrid_util.py
@@ -1,19 +1,16 @@
import dataclasses
import numpy as np
from tqdm.auto import tqdm
import warnings
from spikeinterface.core import BaseRecording, BaseRecordingSegment, Templates
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.extractors import NumpySorting
from spikeinterface.generation.drift_tools import InjectDriftingTemplatesRecording, DriftingTemplates, move_dense_templates
from spikeinterface.preprocessing.basepreprocessor import (
BasePreprocessor, BasePreprocessorSegment)
from probeinterface import Probe
from scipy.spatial import KDTree
from scipy.sparse import csgraph, coo_array

from ..templates import TemplateData
from .analysis import DARTsortAnalysis
from .data_util import DARTsortSorting
from ..config import unshifted_raw_template_config
from ..templates import TemplateData



def get_drifty_hybrid_recording(
Expand Down Expand Up @@ -82,7 +79,7 @@ def get_drifty_hybrid_recording(
         displacement_vectors=[disp],
         displacement_sampling_frequency=displacement_sampling_frequency,
         displacement_unit_factor=displacement_unit_factor,
-        amplitude_factor=amplitude_factor
+        amplitude_factor=amplitude_factor,
     )
     rec.annotate(peak_channel=peak_channels.tolist())
     return rec
@@ -205,6 +202,7 @@ def refractory_poisson_spike_train(

return spike_samples


def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs):
"""
Returns a spike train with variable firing rate using refractory_poisson_spike_train().
@@ -214,13 +212,14 @@ def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs):
    :param binsize_samples: number of samples per bin
    :param **kwargs: kwargs to feed to refractory_poisson_spike_train()
    """
-    sp_tr = np.concatenate(
-        [
-            refractory_poisson_spike_train(r, binsize_samples, **kwargs) + bins[i] if r > 0.1 else []
-            for i, r in enumerate(rates)
-        ]
-    )
-    return sp_tr
+    st = []
+    for rate, bin in zip(rates, bins):
+        if rate < 0.1:
+            continue
+        binst = refractory_poisson_spike_train(rate, binsize_samples, **kwargs)
+        st.append(bin + binst)
+    st = np.concatenate(st)
+    return st
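
For readers who want to see the piecewise-rate idea in isolation, here is a minimal, self-contained sketch. It is not the repository's refractory_poisson_spike_train; the helper names, the 30 kHz sampling rate, the 2 ms dead time, and the 0.1 Hz cutoff are illustrative assumptions only (the cutoff mirrors the rate < 0.1 check above).

import numpy as np

def sketch_refractory_poisson(rate_hz, n_samples, fs=30_000, refractory_ms=2.0):
    # Exponential inter-spike intervals with a hard dead time after each spike.
    rng = np.random.default_rng()
    dead = refractory_ms * fs / 1000
    t, spikes = 0.0, []
    while True:
        t += rng.exponential(fs / rate_hz) + dead
        if t >= n_samples:
            return np.array(spikes, dtype=int)
        spikes.append(int(t))

def sketch_piecewise(rates, bins, binsize_samples, **kwargs):
    # Same loop shape as above: skip near-silent bins, offset each train by its bin start.
    out = [start + sketch_refractory_poisson(rate, binsize_samples, **kwargs)
           for rate, start in zip(rates, bins) if rate >= 0.1]
    return np.concatenate(out) if out else np.array([], dtype=int)

train = sketch_piecewise([5.0, 0.0, 20.0], bins=[0, 30_000, 60_000], binsize_samples=30_000)
print(len(train), "spikes;", train[:5])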


def precompute_displaced_registered_templates(
@@ -255,6 +254,66 @@ def precompute_displaced_registered_templates(
return ret


def closest_clustering(gt_st, peel_st, geom=None, match_dt_ms=0.1, match_radius_um=0.0, p=2.0, M=50.0):
    frames_per_ms = gt_st.sampling_frequency / 1000
    delta_frames = match_dt_ms * frames_per_ms
    rescale = [delta_frames]
    gt_pos = gt_st.times_samples[:, None]
    peel_pos = peel_st.times_samples[:, None]
    if match_radius_um:
        rescale = rescale + (geom.shape[1] * [match_radius_um])
        gt_pos = np.c_[gt_pos, geom[gt_st.channels]]
        peel_pos = np.c_[peel_pos, geom[peel_st.channels]]
    else:
        gt_pos = gt_pos.astype(float)
        peel_pos = peel_pos.astype(float)
    gt_pos /= rescale
    peel_pos /= rescale
    labels = greedy_match(gt_pos, peel_pos, dx=1.0 / frames_per_ms)
    labels[labels >= 0] = gt_st.labels[labels[labels >= 0]]

    return dataclasses.replace(peel_st, labels=labels)


def greedy_match(gt_coords, test_coords, max_val=1.0, dx=1./30, workers=-1, p=2.0):
    assignments = np.full(len(test_coords), -1)
    gt_unmatched = np.ones(len(gt_coords), dtype=bool)

    for j, thresh in enumerate(
        tqdm(np.arange(0.0, max_val + dx + 2e-5, dx), desc="match")
    ):
        test_unmatched = np.flatnonzero(assignments < 0)
        if not test_unmatched.size:
            break
        test_kdtree = KDTree(test_coords[test_unmatched])
        gt_ix = np.flatnonzero(gt_unmatched)
        d, i = test_kdtree.query(
            gt_coords[gt_ix],
            k=1,
            distance_upper_bound=min(thresh, max_val),
            workers=workers,
            p=p,
        )
        # handle multiple gt spikes getting matched to the same peel ix
        thresh_matched = i < test_kdtree.n
        _, ii = np.unique(i, return_index=True)
        i = i[ii]
        thresh_matched = thresh_matched[ii]

        gt_ix = gt_ix[ii]
        gt_ix = gt_ix[thresh_matched]
        i = i[thresh_matched]
        assignments[test_unmatched[i]] = gt_ix
        gt_unmatched[gt_ix] = False

        if not gt_unmatched.any():
            break
        if thresh > max_val:
            break

    return assignments
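
If the dartsort package at this revision is on the path, greedy_match can be exercised on toy data; everything below (coordinate values, jitter scale, counts) is made up for illustration.

import numpy as np
from dartsort.util.hybrid_util import greedy_match

rng = np.random.default_rng(0)
gt = np.sort(rng.uniform(0, 100, size=20))[:, None]               # ground-truth positions, already rescaled
test = gt + rng.normal(scale=0.05, size=gt.shape)                  # jittered detections of the same events
test = np.concatenate([test, rng.uniform(0, 100, size=(5, 1))])    # plus a few spurious detections

# assignments[k] is the matched ground-truth index for test event k, or -1 if unmatched
assignments = greedy_match(gt, test, max_val=1.0)
print((assignments >= 0).sum(), "of", len(test), "test events matched")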


def sorting_from_times_labels(times, labels, recording=None, sampling_frequency=None, determine_channels=True, template_config=unshifted_raw_template_config, n_jobs=0):
    channels = np.zeros_like(labels)
    if sampling_frequency is None:
@@ -273,6 +332,14 @@ def sorting_from_times_labels(times, labels, recording=None, sampling_frequency=None, determine_channels=True, template_config=unshifted_raw_template_config, n_jobs=0):
return sorting, td


-def sorting_from_spikeinterface(sorting, recording=None, determine_channels=True, template_config=unshifted_raw_template_config, n_jobs=0):
+def sorting_from_spikeinterface(
+    sorting,
+    recording=None,
+    determine_channels=True,
+    template_config=unshifted_raw_template_config,
+    n_jobs=0,
+):
     sv = sorting.to_spike_vector()
-    return sorting_from_times_labels(sv['sample_index'], sv['unit_index'], sampling_frequency=sorting.sampling_frequency, recording=recording, determine_channels=determine_channels, template_config=template_config, n_jobs=n_jobs)
+    return sorting_from_times_labels(
+        sv['sample_index'], sv['unit_index'], sampling_frequency=sorting.sampling_frequency, recording=recording, determine_channels=determine_channels, template_config=template_config, n_jobs=n_jobs
+    )
10 changes: 7 additions & 3 deletions src/dartsort/util/spiketorch.py
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from scipy.fftpack import next_fast_len
 from torch.fft import irfft, rfft
+import warnings


def fast_nanmedian(x, axis=-1):
@@ -315,9 +316,12 @@ def nancov(x, weights=None, correction=1, nan_free=False, return_nobs=False, force_posdef
     cov = xtx / denom

     if force_posdef:
-        vals, vecs = torch.linalg.eigh(cov)
-        good = vals > 0
-        cov = (vecs[:, good] * vals[good]) @ vecs[:, good].T
+        try:
+            vals, vecs = torch.linalg.eigh(cov)
+            good = vals > 0
+            cov = (vecs[:, good] * vals[good]) @ vecs[:, good].T
+        except Exception as e:
+            warnings.warn(f"Error in nancov eigh: {e}")

     if return_nobs:
         return cov, nobs
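
Aside: the force_posdef branch keeps only the strictly positive eigenpairs, i.e. it projects a symmetric matrix toward the positive semidefinite cone. A standalone torch sketch of that operation (not the repository's nancov; the example matrix is arbitrary):

import torch

def project_psd(a):
    # Symmetrize, then rebuild the matrix from its strictly positive eigenpairs only.
    a = 0.5 * (a + a.T)
    vals, vecs = torch.linalg.eigh(a)
    good = vals > 0
    return (vecs[:, good] * vals[good]) @ vecs[:, good].T

m = torch.tensor([[1.0, 2.0], [2.0, 1.0]])    # eigenvalues 3 and -1: not positive semidefinite
print(torch.linalg.eigvalsh(project_psd(m)))  # the negative eigenvalue is gone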