Commit 7d59c74

Merge branch 'main' of github.com:cwindolf/spike-psvae

cwindolf committed Aug 27, 2024
2 parents 68678e7 + 62972a2
Showing 6 changed files with 47 additions and 11 deletions.
src/dartsort/cluster/initial.py (6 additions, 0 deletions)

@@ -210,6 +210,7 @@ def cluster_chunks(
     clustering_config,
     sorting=None,
     motion_est=None,
+    amplitudes_dataset_name='denoised_ptp_amplitudes',
 ):
     """Divide the recording into chunks, and cluster each chunk
@@ -239,6 +240,7 @@ def cluster_chunks(
             chunk_time_range_s=chunk_range,
             motion_est=motion_est,
             recording=recording,
+            amplitudes_dataset_name=amplitudes_dataset_name,
         )
         for chunk_range in chunk_time_ranges_s
     ]
@@ -253,6 +255,7 @@ def ensemble_chunks(
     sorting=None,
     computation_config=None,
     motion_est=None,
+    **kwargs,
 ):
     """Initial clustering combined across chunks of time
@@ -283,6 +286,7 @@
         clustering_config,
         sorting=sorting,
         motion_est=motion_est,
+        **kwargs,
     )
     if len(chunk_sortings) == 1:
         return chunk_sortings[0]
@@ -320,6 +324,7 @@ def initial_clustering(
     clustering_config=None,
     computation_config=None,
     motion_est=None,
+    **kwargs,
 ):
     if sorting is None:
         sorting = DARTsortSorting.from_peeling_hdf5(peeling_hdf5_filename)
@@ -333,6 +338,7 @@
         sorting=sorting,
         computation_config=computation_config,
         motion_est=motion_est,
+        **kwargs,
     )


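Note on this file's changes: cluster_chunks gains an amplitudes_dataset_name parameter, and ensemble_chunks and initial_clustering gain **kwargs passthroughs so that option can be set from the top-level entry point. A hedged usage sketch follows; only the amplitudes_dataset_name keyword comes from this commit, the other arguments and the HDF5 path are assumptions:

    from dartsort.cluster.initial import initial_clustering

    sorting = initial_clustering(
        recording,                               # assumed positional argument
        peeling_hdf5_filename="subtraction.h5",  # assumed example path
        motion_est=motion_est,
        # forwarded through **kwargs down to cluster_chunks:
        amplitudes_dataset_name="denoised_ptp_amplitudes",
    )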
src/dartsort/cluster/merge.py (4 additions, 0 deletions)

@@ -632,12 +632,16 @@ def combine_templates(template_data_a, template_data_b):
     spike_counts = np.concatenate(
         (template_data_a.spike_counts, template_data_b.spike_counts)
     )
+    spike_counts_by_channel = np.concatenate(
+        (template_data_a.spike_counts_by_channel, template_data_b.spike_counts_by_channel)
+    )
     template_data = TemplateData(
         templates=templates,
         unit_ids=unit_ids,
         spike_counts=spike_counts,
         registered_geom=rgeom,
         registered_template_depths_um=locs,
+        spike_counts_by_channel=spike_counts_by_channel,
     )

     cross_mask = np.logical_and(np.isin(unit_ids, ids_a)[:, None], np.isin(unit_ids, ids_b)[None])
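The addition keeps TemplateData.spike_counts_by_channel populated after a merge by stacking the two template sets' per-channel counts, mirroring the per-unit spike_counts concatenation just above it. A toy numpy sketch of the shape contract (the (n_units, n_channels) shape is an assumption based on the field's name):

    import numpy as np

    counts_a = np.ones((3, 384), dtype=int)  # (n_units_a, n_channels)
    counts_b = np.ones((2, 384), dtype=int)  # (n_units_b, n_channels)

    # np.concatenate stacks along axis 0 by default: units are appended,
    # and the channel dimension must match
    combined = np.concatenate((counts_a, counts_b))
    assert combined.shape == (5, 384)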
src/dartsort/templates/get_templates.py (3 additions, 1 deletion)

@@ -701,8 +701,10 @@ def _template_job(unit_ids):
             p.reducer(waveforms[in_unit], axis=0).numpy(force=True)
         )
         counts.append(in_unit.size)
-    snrs_by_chan = [ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)]
+    snrs_by_chan = np.array([ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)])
     counts_by_chan = np.array(counts)
+    if counts_by_chan.ndim == 1:
+        counts_by_chan = np.broadcast_to(counts_by_chan[:, None], snrs_by_chan.shape)
     raw_templates = np.array(raw_templates)

     if p.denoising_tsvd is None:
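The new branch turns one scalar count per unit into a per-channel array matching snrs_by_chan, so consumers such as the spike_counts_by_channel field in merge.py above see a consistent 2D shape. A minimal illustration of that broadcast:

    import numpy as np

    counts_by_chan = np.array([100, 250])  # one spike count per unit, ndim == 1
    snrs_by_chan = np.zeros((2, 384))      # (n_units, n_channels)

    # replicate each unit's count across its channels (as a zero-copy view)
    counts_by_chan = np.broadcast_to(counts_by_chan[:, None], snrs_by_chan.shape)
    assert counts_by_chan.shape == (2, 384)
    assert counts_by_chan[0, 0] == counts_by_chan[0, -1] == 100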
src/dartsort/util/comparison.py (3 additions, 1 deletion)

@@ -24,6 +24,7 @@ class DARTsortGroundTruthComparison:
     match_mode: str = "hungarian"
     compute_labels: bool = True
     verbose: bool = True
+    device: Optional[str] = None

     compute_distances: bool = True
     compute_unsorted_recall: bool = True
@@ -120,7 +121,8 @@ def _calculate_template_distances(self):
             gt_td,
             tested_td,
             sym_function=np.maximum,
-            n_jobs=max(self.gt_analysis.n_jobs, self.tested_analysis.n_jobs),
+            n_jobs=self.n_jobs,
+            device=self.device,
         )
         self._template_distances = dists
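The comparison dataclass now carries a device field that is forwarded to the template distance computation, and the job count comes from the comparison's own n_jobs rather than the max over the two analyses. A hedged construction sketch; only match_mode, verbose, device, and the gt/tested analysis attributes appear in this diff, everything else is assumed:

    comparison = DARTsortGroundTruthComparison(
        gt_analysis=gt_analysis,          # placeholder analysis objects,
        tested_analysis=tested_analysis,  # built elsewhere in dartsort
        match_mode="hungarian",
        verbose=True,
        device="cuda:0",  # new: compute device for template distances
    )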
src/dartsort/util/drift_util.py (4 additions, 3 deletions)

@@ -131,16 +131,17 @@ def registered_geometry(
     return registered_geom


-def registered_channels(channels, geom, n_pitches_shift, registered_geom):
+def registered_channels(channels, geom, n_pitches_shift, registered_geom, distance_upper_bound=None):
     """What registered channels do `channels` land on after shifting by `n_pitches_shift`?"""
     pitch = get_pitch(geom)
     shifted_positions = geom.copy()[channels]
     shifted_positions[:, 1] += n_pitches_shift * pitch

     registered_kdtree = KDTree(registered_geom)
-    min_distance = pdist(registered_geom).min() / 2
+    if distance_upper_bound is None:
+        distance_upper_bound = pdist(registered_geom).min() / 2
     distances, registered_channels = registered_kdtree.query(
-        shifted_positions, distance_upper_bound=min_distance
+        shifted_positions, distance_upper_bound=distance_upper_bound
     )
     # make sure there were no unmatched points
     assert np.all(registered_channels < len(registered_geom))
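For context, a standalone sketch of the KDTree matching pattern used above: scipy.spatial.KDTree.query returns index len(data) (and an infinite distance) for query points with no neighbor within distance_upper_bound, which is exactly what the assert guards against. Toy geometry:

    import numpy as np
    from scipy.spatial import KDTree

    registered_geom = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0]])
    tree = KDTree(registered_geom)

    queries = np.array([[0.0, 19.0], [0.0, 500.0]])  # second point is far away
    dists, idx = tree.query(queries, distance_upper_bound=10.0)
    print(idx)    # [1 3]: 3 == len(registered_geom) flags an unmatched point
    print(dists)  # [ 1. inf]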
src/dartsort/util/hybrid_util.py (27 additions, 6 deletions)

@@ -26,14 +26,17 @@ def get_drifty_hybrid_recording(
     firing_rates=None,
     peak_channels=None,
     amplitude_scale_std=0.1,
+    amplitude_factor=None
 ):
     """
     :param: recording
     :param: templates object
     :param: motion estimate object
     :param: firing_rates
     :param: peak_channels
-    :param: amplitude_factor
+    :param: amplitude_scale_std -- std of gamma distributed amplitude variation if
+        amplitude_factor is None
+    :param: amplitude_factor array of length n_spikes with amplitude factors
     """
     num_units = templates.num_units
     rg = np.random.default_rng(seed=seed)
@@ -50,11 +53,12 @@
     n_spikes = sorting.count_total_num_spikes()

     # Default amplitude scalings for spikes drawn from gamma
-    if amplitude_scale_std:
-        shape = 1. / (amplitude_scale_std ** 1.5)
-        amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)
-    else:
-        amplitude_factor = np.ones(n_spikes)
+    if amplitude_factor is None:
+        if amplitude_scale_std:
+            shape = 1. / (amplitude_scale_std ** 1.5)
+            amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)
+        else:
+            amplitude_factor = np.ones(n_spikes)

     depths = recording.get_probe().contact_positions[:, 1][peak_channels]
     t_start = recording.sample_index_to_time(0)
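The refactor makes the gamma draw a fallback that only runs when the caller did not pass amplitude_factor explicitly. With this parameterization the factors are centered near 1, since a gamma with the given shape and scale = 1/(shape-1) has mean shape/(shape-1). A quick check using the same formulas:

    import numpy as np

    rg = np.random.default_rng(seed=0)
    amplitude_scale_std = 0.1
    shape = 1. / (amplitude_scale_std ** 1.5)  # about 31.6 for std 0.1
    factors = rg.gamma(shape, scale=1. / (shape - 1), size=100_000)
    print(factors.mean())  # about shape / (shape - 1), i.e. just above 1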
@@ -201,6 +205,23 @@

     return spike_samples

+def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs):
+    """
+    Returns a spike train with variable firing rate using refractory_poisson_spike_train().
+    :param rates: list of firing rates in Hz
+    :param bins: bin starting samples (same shape as rates)
+    :param binsize_samples: number of samples per bin
+    :param **kwargs: kwargs to feed to refractory_poisson_spike_train()
+    """
+    sp_tr = np.concatenate(
+        [
+            refractory_poisson_spike_train(r, binsize_samples, **kwargs) + bins[i] if r > 0.1 else []
+            for i, r in enumerate(rates)
+        ]
+    )
+    return sp_tr
+

 def precompute_displaced_registered_templates(
     template_data: TemplateData,
