Skip to content

Commit

Permalink
Latest NN stuff + clustering vis debugging + GT metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Sep 25, 2024
1 parent 2aaaace commit f22e1ba
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 64 deletions.
10 changes: 9 additions & 1 deletion src/dartsort/peel/grab.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from dartsort.util import spiketorch

from .peel_base import BasePeeler
from .peel_base import BasePeeler, SpikeDataset


class GrabAndFeaturize(BasePeeler):
Expand Down Expand Up @@ -38,6 +38,13 @@ def __init__(
self.register_buffer("times_samples", times_samples)
self.register_buffer("channels", channels)

def out_datasets(self):
    """Output datasets: the base peeler's, plus one recording each grabbed spike's index."""
    out = super().out_datasets()
    # scalar per spike: the position of the spike within the requested times/channels
    out.append(SpikeDataset(name="indices", shape_per_spike=(), dtype=int))
    return out

def process_chunk(self, chunk_start_samples, return_residual=False):
"""Override process_chunk to skip empties."""
chunk_end_samples = min(
Expand Down Expand Up @@ -88,6 +95,7 @@ def peel_chunk(

return dict(
n_spikes=in_chunk.numel(),
indices=in_chunk,
times_samples=self.times_samples[in_chunk],
channels=channels,
collisioncleaned_waveforms=waveforms,
Expand Down
2 changes: 2 additions & 0 deletions src/dartsort/peel/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
relative_peak_radius_samples=5,
dedup_temporal_radius_samples=7,
n_chunks_fit=40,
n_waveforms_fit=20_000,
max_waveforms_fit=50_000,
fit_subsampling_random_state=0,
dtype=torch.float,
Expand All @@ -34,6 +35,7 @@ def __init__(
chunk_length_samples=chunk_length_samples,
chunk_margin_samples=spike_length_samples,
n_chunks_fit=n_chunks_fit,
n_waveforms_fit=n_waveforms_fit,
max_waveforms_fit=max_waveforms_fit,
fit_subsampling_random_state=fit_subsampling_random_state,
dtype=dtype,
Expand Down
7 changes: 5 additions & 2 deletions src/dartsort/transform/all_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from .vae_localize import VAELocalization
from .single_channel_denoiser import SingleChannelWaveformDenoiser
from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer, TemporalPCA
from .transform_base import Waveform
from .transform_base import Waveform, Passthrough
from .decollider import Decollider

all_transformers = [
Waveform,
Expand All @@ -16,10 +17,12 @@
TemporalPCAFeaturizer,
Localization,
PointSourceLocalization,
VAELocalization,
VAELocalization,
AmplitudeFeatures,
TemporalPCA,
Voltage,
Decollider,
Passthrough,
]

transformers_by_class_name = {cls.__name__: cls for cls in all_transformers}
54 changes: 36 additions & 18 deletions src/dartsort/transform/decollider.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def __init__(
noisier3noise=False,
inference_kind="raw",
seed=0,
batch_size=32,
learning_rate=1e-3,
epochs=25,
):
assert inference_kind in ("raw", "amortized")

Expand All @@ -34,7 +37,10 @@ def __init__(
self.hidden_dims = hidden_dims
self.n_channels = len(geom)
self.recording = recording
self.rg = np.random.get_default_rng(seed)
self.batch_size = batch_size
self.learning_rate = learning_rate
self.epochs = epochs
self.rg = np.random.default_rng(seed)

super().__init__(
geom=geom, channel_index=channel_index, name=name, name_prefix=name_prefix
Expand All @@ -50,6 +56,15 @@ def __init__(
"relative_index",
get_relative_index(self.channel_index, self.model_channel_index),
)
# suburban lawns -- janitor
self.register_buffer(
"irrelative_index",
get_relative_index(self.model_channel_index, self.channel_index),
)
self._needs_fit = True

def needs_fit(self):
    """Return True until fit() has run; fit() clears the flag when training completes."""
    return self._needs_fit

def initialize_nets(self, spike_length_samples):
self.spike_length_samples = spike_length_samples
Expand Down Expand Up @@ -83,27 +98,30 @@ def fit(self, waveforms, max_channels):
waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0)
with torch.enable_grad():
self._fit(waveforms, max_channels)
self._needs_fit = False

def transform(self, waveforms, max_channels):
def forward(self, waveforms, max_channels):
    """Denoise waveforms at inference time.

    The pasted diff interleaved pre- and post-commit lines (duplicate `pred`
    assignments without `.view(...)`); this is the post-commit version only.

    Flattens each waveform, concatenates the channel-validity mask, runs the
    appropriate network, and re-indexes the prediction back to the original
    channel layout via `irrelative_index`.
    """
    n = len(waveforms)
    # move waveforms onto the model's channel index (padding absent channels with 0)
    waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0)
    masks = self.get_masks(max_channels).to(waveforms)

    net_input = torch.cat((waveforms.view(n, self.wf_dim), masks), dim=1)
    if self.inference_kind == "amortized":
        pred = self.inf_net(net_input).view(waveforms.shape)
    elif self.inference_kind == "raw":
        pred = self.eyz(net_input).view(waveforms.shape)
    else:
        # unreachable: inference_kind is validated in __init__
        assert False

    # map predictions back to the extraction channel index
    pred = reindex(max_channels, pred, self.irrelative_index)

    return pred

def get_masks(self, max_channels):
    """Boolean mask per spike: True where the model channel index maps to a real channel."""
    neighborhoods = self.model_channel_index[max_channels]
    return neighborhoods < self.n_channels

def forward(self, y, m, mask):
def train_forward(self, y, m, mask):
n = len(y)
z = y + m
z_flat = z.view(n, self.wf_dim)
Expand All @@ -117,18 +135,18 @@ def forward(self, y, m, mask):

# predictions given z
if self.noisier3noise:
eyz = self.eyz(z_masked)
emz = self.emz(z_masked)
eyz = self.eyz(z_masked).view(y.shape)
emz = self.emz(z_masked).view(y.shape)
exz = eyz - emz
else:
eyz = self.eyz(z_masked)
eyz = self.eyz(z_masked).view(y.shape)
exz = 2 * eyz - z

# predictions given y, if relevant
if self.inference_kind == "amortized":
y_flat = y.view(n, self.wf_dim)
y_masked = torch.cat((y_flat, mask), dim=1)
exy = self.inf_net(y_masked)
exy = self.inf_net(y_masked).view(y.shape)

return exz, eyz, emz, exy

Expand Down Expand Up @@ -157,8 +175,9 @@ def get_noise(self, channels):

return torch.from_numpy(noise_waveforms)

def loss(mask, waveforms, m, exz, eyz, emz=None, exy=None):
def loss(self, mask, waveforms, m, exz, eyz, emz=None, exy=None):
loss_dict = {}
mask = mask.unsqueeze(1)
loss_dict["eyz"] = F.mse_loss(mask * eyz, mask * waveforms)
if emz is not None:
loss_dict["emz"] = F.mse_loss(mask * emz, mask * m)
Expand All @@ -167,31 +186,30 @@ def loss(mask, waveforms, m, exz, eyz, emz=None, exy=None):
return loss_dict

def _fit(self, waveforms, channels):
    """Train the denoising networks on (waveform, channel) pairs.

    The pasted diff interleaved pre- and post-commit lines (duplicate
    `initialize_net(s)` calls, two `epoch_losses = {}` placements, old
    `self.forward(waveforms, m,)` call alongside the new
    `self.train_forward(...)`, duplicate loss-accumulation loop headers);
    this is the post-commit version only.

    Runs `self.epochs` passes of Adam over shuffled minibatches, accumulating
    per-term losses and reporting per-batch averages in the progress bar.
    """
    self.initialize_nets(waveforms.shape[1])
    dataset = TensorDataset(waveforms, channels)
    dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    self.to(waveforms.device)

    with trange(self.epochs, desc="Epochs", unit="epoch") as pbar:
        for _epoch in pbar:
            # reset per-epoch accumulators each pass
            epoch_losses = {}
            for waveform_batch, channels_batch in dataloader:
                optimizer.zero_grad()

                # get a batch of noise samples and channel-validity masks
                m = self.get_noise(channels_batch).to(waveform_batch)
                mask = self.get_masks(channels_batch).to(waveform_batch)
                exz, eyz, emz, exy = self.train_forward(waveform_batch, m, mask)
                loss_dict = self.loss(mask, waveform_batch, m, exz, eyz, emz, exy)
                loss = sum(loss_dict.values())
                loss.backward()
                optimizer.step()

                for k, v in loss_dict.items():
                    epoch_losses[k] = v.item() + epoch_losses.get(k, 0.0)

            # report the mean loss per batch for each term
            epoch_losses = {k: v / len(dataloader) for k, v in epoch_losses.items()}
            loss_str = ", ".join(f"{k}: {v:.3f}" for k, v in epoch_losses.items())
            pbar.set_description(f"Epochs [{loss_str}]")
39 changes: 39 additions & 0 deletions src/dartsort/transform/transform_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,45 @@ class BaseWaveformAutoencoder(BaseWaveformDenoiser, BaseWaveformFeaturizer):
pass


class Passthrough(BaseWaveformDenoiser, BaseWaveformFeaturizer):
    """Wrap a featurizing pipeline: emit its features while leaving waveforms untouched.

    Denoiser-side, this is the identity; featurizer-side, it delegates
    fitting, precomputation, and feature extraction to ``pipeline``.
    """

    def __init__(self, pipeline):
        # idiom fix: `if not len(feat)` -> `if not featurizers`
        featurizers = [t for t in pipeline if t.is_featurizer]
        if not featurizers:
            raise ValueError("Passthrough with no featurizer?")
        name = f"passthrough_{featurizers[0].name}"
        super().__init__(name=name)
        self.pipeline = pipeline

    def needs_precompute(self):
        return self.pipeline.needs_precompute()

    def precompute(self):
        return self.pipeline.precompute()

    def needs_fit(self):
        return self.pipeline.needs_fit()

    def fit(self, waveforms, max_channels):
        self.pipeline.fit(waveforms, max_channels)

    def forward(self, waveforms, max_channels=None):
        # the pipeline's denoised waveforms are deliberately discarded:
        # the input waveforms pass through unchanged
        _, pipeline_features = self.pipeline(waveforms, max_channels)
        return waveforms, pipeline_features

    @property
    def spike_datasets(self):
        # expose only the featurizers' datasets, in pipeline order
        datasets = []
        for t in self.pipeline.transformers:
            if t.is_featurizer:
                datasets.extend(t.spike_datasets)
        return datasets

    def transform(self, waveforms, max_channels=None):
        _, pipeline_features = self.pipeline(waveforms, max_channels)
        return pipeline_features


class IdentityWaveformDenoiser(BaseWaveformDenoiser):
def forward(self, waveforms, max_channels=None):
return waveforms
Expand Down
Loading

0 comments on commit f22e1ba

Please sign in to comment.