Debug false positives idea
cwindolf committed Sep 30, 2024
1 parent fed4d21 commit cbf53c0
Showing 4 changed files with 79 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/dartsort/peel/matching.py
@@ -582,7 +582,7 @@ def match_chunk(

# initialize convolution
compressed_template_data.convolve(
residual, padding=self.obj_pad_len, out=padded_conv
residual.T, padding=self.obj_pad_len, out=padded_conv
)

# main loop
10 changes: 8 additions & 2 deletions src/dartsort/templates/template_util.py
@@ -229,7 +229,13 @@ def svd_compress_templates(
singular_values: n_units, rank
spatial_components: n_units, rank, n_channels
"""
templates = template_data.templates
if hasattr(template_data, "templates"):
templates = template_data.templates
counts = template_data.spike_counts_by_channel
else:
templates = template_data
counts = None

vis_mask = templates.ptp(axis=1, keepdims=True) > min_channel_amplitude
vis_templates = templates * vis_mask
dtype = templates.dtype
@@ -262,7 +268,7 @@ def svd_compress_templates(
singular_values[i, :k] = s[:rank]
spatial_components[i, :k, mask] = Vh[:rank].T

return LowRankTemplates(temporal_components, singular_values, spatial_components, template_data.spike_counts_by_channel)
return LowRankTemplates(temporal_components, singular_values, spatial_components, counts)


def temporally_upsample_templates(
69 changes: 68 additions & 1 deletion src/dartsort/util/noise_util.py
@@ -1,8 +1,12 @@
from scipy.fftpack import next_fast_len
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from dartsort.util import spiketorch
from dartsort.detect import detect_and_deduplicate
from scipy.fftpack import next_fast_len
from tqdm.auto import trange


class FactorizedNoise(torch.nn.Module):
@@ -150,3 +154,66 @@ def estimate(cls, snippets):
kernel_fft = (xt_fft * xt_fft.conj()).mean(0).sqrt_()

return cls(spatial_std, vt_spatial, kernel_fft, block_size, t)

def unit_false_positives(
self,
low_rank_templates,
min_threshold=5.0,
radius=10,
generator=None,
size=100,
t=4096,
unit_batch_size=32,
):
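"""Estimate template-matching false positives on simulated noise.

Simulates `size` noise-only chunks of length `t` from this noise model,
evaluates the template-matching objective against `low_rank_templates`,
and records every objective peak above `min_threshold` as a false positive.
Returns the total number of simulated samples together with a DataFrame
holding the unit index and objective score of each detection.
"""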
singular = torch.asarray(
low_rank_templates.singular_values,
device=self.spatial_std.device,
dtype=self.spatial_std.dtype,
)
spatial = torch.asarray(
low_rank_templates.spatial_components,
device=self.spatial_std.device,
dtype=self.spatial_std.dtype,
)
spatial_singular = singular.unsqueeze(-1) * spatial
temporal = torch.asarray(
low_rank_templates.temporal_components,
device=self.spatial_std.device,
dtype=self.spatial_std.dtype,
)
negnormsq = spatial_singular.square().sum((1, 2)).neg_().unsqueeze(1)
nu, nt = temporal.shape[:2]
obj = None # var for reusing buffers
units = []
scores = []
for j in trange(size, desc="False positives"):
sample = self.simulate(t=t + nt - 1, generator=generator)[0].T
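# correlate the simulated noise with each unit's low-rank template (units x time)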
obj = spiketorch.convolve_lowrank(
sample,
spatial_singular,
temporal,
out=obj,
)
assert obj.shape == (nu, t)
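# matching objective per unit and time: 2 * (template . data) - ||template||^2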
obj = torch.add(negnormsq, obj, alpha=2.0, out=obj)

# find peaks...
peak_times, peak_units, peak_energies = detect_and_deduplicate(
obj.T,
min_threshold,
peak_sign="pos",
dedup_temporal_radius=radius,
return_energies=True,
)
units.append(peak_units.numpy(force=True))
scores.append(peak_energies.numpy(force=True))

total_samples = size * t
df = pd.DataFrame(
dict(
units=np.concatenate(units),
scores=np.concatenate(scores),
)
)

return total_samples, df
6 changes: 2 additions & 4 deletions src/dartsort/util/spiketorch.py
@@ -244,8 +244,7 @@ def convolve_lowrank(
):
"""Depthwise convolution of traces with templates"""
n_templates, spike_length_samples, rank = temporal_components.shape

out_len = traces.shape[0] + 2 * padding - spike_length_samples + 1
out_len = traces.shape[1] + 2 * padding - spike_length_samples + 1
if out is None:
out = torch.empty(
(n_templates, out_len),
@@ -262,7 +261,6 @@ def convolve_lowrank(
# convolve with temporal components -- units x time
temporal = temporal_components[:, :, q]

# temporalf = temporalf[:, :, q]
# conv1d with groups! only convolve each unit with its own temporal filter
conv = F.conv1d(
rec_spatial[None],
@@ -273,7 +271,7 @@ def convolve_lowrank(

# o-a turns out not to be helpful, sadly
# conv = depthwise_oaconv1d(
# rec_spatial, temporal, padding=padding, f2=temporalf
# rec_spatial, temporal, padding=padding, f2=temporalf[:, :, q]
# )

if q:

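A minimal sketch of how the new unit_false_positives method might be exercised, assuming `noise` is an already-estimated instance of the noise model that gained this method (e.g. built via the estimate(snippets) classmethod shown above) and `templates` is a raw (n_units, spike_length_samples, n_channels) array, which svd_compress_templates now accepts directly. The argument values and variable names below are illustrative, not part of the commit.

import torch
from dartsort.templates import template_util

# Compress the raw template array; spike counts fall back to None in this case.
# Other arguments of svd_compress_templates are left at their defaults here.
low_rank = template_util.svd_compress_templates(templates)

# Simulate 100 noise-only chunks of 4096 samples and collect matching false positives.
total_samples, fp = noise.unit_false_positives(
    low_rank,
    min_threshold=5.0,
    size=100,
    t=4096,
    generator=torch.Generator().manual_seed(0),
)

# False-positive rate per simulated sample, and detection counts per unit.
rate = len(fp) / total_samples
per_unit = fp.groupby("units").size()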