Skip to content

Commit

Permalink
Merge branch 'main' of github.com:cwindolf/dartsort
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Aug 21, 2024
2 parents 73cbed0 + ef0d5ce commit 831b0ad
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/dartsort/util/hybrid_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ def get_drifty_hybrid_recording(
firing_rates=None,
peak_channels=None,
amplitude_scale_std=0.1,
amplitude_factor=None
):
"""
:param: recording
:param: templates object
:param: motion estimate object
:param: firing_rates
:param: peak_channels
:param: amplitude_factor
:param: amplitude_scale_std -- std of gamma distributed amplitude variation if
amplitude_factor is None
:param: amplitude_factor array of length n_spikes with amplitude factors
"""
num_units = templates.num_units
rg = np.random.default_rng(seed=seed)
Expand All @@ -50,11 +53,12 @@ def get_drifty_hybrid_recording(
n_spikes = sorting.count_total_num_spikes()

# Default amplitude scalings for spikes drawn from gamma
if amplitude_scale_std:
shape = 1. / (amplitude_scale_std ** 1.5)
amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)
else:
amplitude_factor = np.ones(n_spikes)
if amplitude_factor is None:
if amplitude_scale_std:
shape = 1. / (amplitude_scale_std ** 1.5)
amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)
else:
amplitude_factor = np.ones(n_spikes)

depths = recording.get_probe().contact_positions[:, 1][peak_channels]
t_start = recording.sample_index_to_time(0)
Expand Down Expand Up @@ -201,6 +205,23 @@ def refractory_poisson_spike_train(

return spike_samples

def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs):
"""
Returns a spike train with variable firing rate using refractory_poisson_spike_train().
:param rates: list of firing rates in Hz
:param bins: bin starting samples (same shape as rates)
:param binsize_samples: number of samples per bin
:param **kwargs: kwargs to feed to refractory_poisson_spike_train()
"""
sp_tr = np.concatenate(
[
refractory_poisson_spike_train(r, binsize_samples, **kwargs) + bins[i] if r > 0.1 else []
for i, r in enumerate(rates)
]
)
return sp_tr


def precompute_displaced_registered_templates(
template_data: TemplateData,
Expand Down

0 comments on commit 831b0ad

Please sign in to comment.