From 11faf6782da84fe756aec45147cb5258b862b20d Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Tue, 20 Aug 2024 11:07:34 -0700 Subject: [PATCH 1/3] piecewise spiketrain --- src/dartsort/util/hybrid_util.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index 7fae8356..c3b21a0b 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -201,6 +201,23 @@ def refractory_poisson_spike_train( return spike_samples +def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs): + """ + Returns a spike train with variable firing rate using refractory_poisson_spike_train(). + + :param rates: list of firing rates in Hz + :param bins: bin starting samples (same shape as rates) + :param binsize_samples: number of samples per bin + :param **kwargs: kwargs to feed to refractory_poisson_spike_train() + """ + sp_tr = np.concatenate( + [ + refractory_poisson_spike_train(r, binsize_samples, **kwargs) + bins[i] if r > 0.1 else [] + for i, r in enumerate(rates) + ] + ) + return sp_tr + def precompute_displaced_registered_templates( template_data: TemplateData, From c9cd3226fed3ad700a67e5353132eb66b31b5e08 Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Tue, 20 Aug 2024 19:56:28 -0700 Subject: [PATCH 2/3] feed amps vector directly to hybrid recording gen --- src/dartsort/util/hybrid_util.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index c3b21a0b..8974cea6 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -26,6 +26,7 @@ def get_drifty_hybrid_recording( firing_rates=None, peak_channels=None, amplitude_scale_std=0.1, + amplitude_factor=None ): """ :param: recording @@ -33,7 +34,9 @@ def get_drifty_hybrid_recording( :param: motion estimate object :param: firing_rates :param: peak_channels - :param: amplitude_factor + :param: amplitude_scale_std -- std of gamma distributed amplitude variation if + amplitude_factor is None + :param: amplitude_factor array of length n_spikes with amplitude factors """ num_units = templates.num_units rg = np.random.default_rng(seed=seed) @@ -50,11 +53,12 @@ def get_drifty_hybrid_recording( n_spikes = sorting.count_total_num_spikes() # Default amplitude scalings for spikes drawn from gamma - if amplitude_scale_std: - shape = 1. / (amplitude_scale_std ** 1.5) - amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes) - else: - amplitude_factor = np.ones(n_spikes) + if not amplitude_factor: + if amplitude_scale_std: + shape = 1. / (amplitude_scale_std ** 1.5) + amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes) + else: + amplitude_factor = np.ones(n_spikes) depths = recording.get_probe().contact_positions[:, 1][peak_channels] t_start = recording.sample_index_to_time(0) From ef0d5ce0b77d668b11495e332c2bb6dbfed14da3 Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Tue, 20 Aug 2024 20:22:21 -0700 Subject: [PATCH 3/3] none check --- src/dartsort/util/hybrid_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index 8974cea6..55cd2980 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -53,7 +53,7 @@ def get_drifty_hybrid_recording( n_spikes = sorting.count_total_num_spikes() # Default amplitude scalings for spikes drawn from gamma - if not amplitude_factor: + if amplitude_factor is None: if amplitude_scale_std: shape = 1. / (amplitude_scale_std ** 1.5) amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)