Skip to content

Commit

Permalink
feed amps vector directly to hybrid recording gen
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-langfield committed Aug 21, 2024
1 parent 11faf67 commit c9cd322
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/dartsort/util/hybrid_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ def get_drifty_hybrid_recording(
firing_rates=None,
peak_channels=None,
amplitude_scale_std=0.1,
amplitude_factor=None
):
"""
:param: recording
:param: templates object
:param: motion estimate object
:param: firing_rates
:param: peak_channels
:param: amplitude_factor
:param: amplitude_scale_std -- std of gamma distributed amplitude variation if
amplitude_factor is None
:param: amplitude_factor array of length n_spikes with amplitude factors
"""
num_units = templates.num_units
rg = np.random.default_rng(seed=seed)
Expand All @@ -50,11 +53,12 @@ def get_drifty_hybrid_recording(
n_spikes = sorting.count_total_num_spikes()

# Default amplitude scalings for spikes drawn from gamma
if amplitude_scale_std:
shape = 1. / (amplitude_scale_std ** 1.5)
amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)
else:
amplitude_factor = np.ones(n_spikes)
if not amplitude_factor:
if amplitude_scale_std:
shape = 1. / (amplitude_scale_std ** 1.5)
amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes)
else:
amplitude_factor = np.ones(n_spikes)

depths = recording.get_probe().contact_positions[:, 1][peak_channels]
t_start = recording.sample_index_to_time(0)
Expand Down

0 comments on commit c9cd322

Please sign in to comment.