diff --git a/src/dartsort/transform/transform_base.py b/src/dartsort/transform/transform_base.py
index a5f577ef..355e36de 100644
--- a/src/dartsort/transform/transform_base.py
+++ b/src/dartsort/transform/transform_base.py
@@ -36,20 +36,6 @@ def precompute(self):
         pass
 
 
-class RegularChannelsTransformerMixin:
-
-    def needs_precompute(self):
-        return hasattr(self, "regular_channel_index")
-
-    def precompute(self):
-        # create regular channel index...
-        # assume that we have a
-        self.radius
-        # parameter
-        ...
-
-
-
 class BaseWaveformDenoiser(BaseWaveformModule):
     is_denoiser = True
 
@@ -91,8 +77,6 @@ def spike_datasets(self):
         return (dataset,)
 
 
-
-
 class BaseWaveformAutoencoder(BaseWaveformDenoiser, BaseWaveformFeaturizer):
     pass
 
diff --git a/src/dartsort/transform/vae_localize.py b/src/dartsort/transform/vae_localize.py
index 7b5da16a..9fb7c777 100644
--- a/src/dartsort/transform/vae_localize.py
+++ b/src/dartsort/transform/vae_localize.py
@@ -1,12 +1,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from dartsort.localize.localize_torch import localize_amplitude_vectors
-from dartsort.util.spiketorch import ptp
+from dartsort.util.spiketorch import get_relative_index, ptp, reindex
+from dartsort.util.waveform_util import make_regular_channel_index
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm.auto import trange
 
-
 from .transform_base import BaseWaveformFeaturizer
 
 
@@ -28,6 +27,9 @@ def __init__(
         hidden_dims=(256, 128),
         name=None,
         name_prefix="",
+        epochs=10,
+        learning_rate=1e-3,
+        batch_size=32,
     ):
         assert amplitude_kind in ("peak", "ptp")
         super().__init__(
@@ -40,7 +42,10 @@ def __init__(
         self.amplitude_kind = amplitude_kind
         self.radius = radius
         self.localization_model = localization_model
-        self.latent_dim = 3 #x,y,z
+        self.latent_dim = 3  # x,y,z
+        self.epochs = epochs
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
         self.encoder = nn.Sequential(
             nn.Linear(self.input_dim, hidden_dims[0]),
             nn.BatchNorm1d(hidden_dims[0]),
@@ -48,10 +53,19 @@ def __init__(
             nn.ReLU(),
             nn.Linear(hidden_dims[0], hidden_dims[1]),
             nn.BatchNorm1d(hidden_dims[1]),
             nn.ReLU(),
-            nn.Linear(hidden_dims[1], self.latent_dim * 2) # Output mu and log_var
+            nn.Linear(hidden_dims[1], self.latent_dim * 2),  # Output mu and log_var
         )
         self.register_buffer("padded_geom", F.pad(self.geom, (0, 0, 0, 1)))
+        self.register_buffer(
+            "model_channel_index",
+            make_regular_channel_index(self.geom, radius, to_torch=True),
+        )
+        self.register_buffer(
+            "relative_index",
+            get_relative_index(self.channel_index, self.model_channel_index),
+        )
+
     def reparameterize(self, mu, var):
         std = var.sqrt()
         eps = torch.randn_like(std)
@@ -59,7 +73,9 @@ def reparameterize(self, mu, var):
 
     def local_distances(self, z, channels):
         """Return distances from each z to its local geom centered at channels."""
-        local_geom = self.padded_geom[self.channel_index[channels]] - self.geom[channels].unsqueeze(1)
+        local_geom = self.padded_geom[self.channel_index[channels]] - self.geom[
+            channels
+        ].unsqueeze(1)
         dx = z[:, 0, None] - local_geom[:, :, 0]
         dz = z[:, 2, None] + local_geom[:, :, 1]
         y = F.softplus(z[:, 1]).unsqueeze(1)
@@ -67,7 +83,7 @@ def local_distances(self, z, channels):
         return dists
 
     def get_alphas(self, obs_amps, dists, return_pred=False):
-        pred_amps_alpha1 = 1. / dists
+        pred_amps_alpha1 = 1.0 / dists
         # least squares with no intercept
         alphas = (obs_amps * pred_amps_alpha1).sum(1) / pred_amps_alpha1.square().sum(1)
         if return_pred:
@@ -87,36 +103,51 @@ def forward(self, x, mask, obs_amps, channels):
         z = self.reparameterize(mu, var)
         alphas, pred_amps = self.decode(z, channels, obs_amps)
         return pred_amps, mu, var
-    
+
     def loss_function(self, recon_x, x, mask, mu, var):
         mask = mask.to(x)
         recon_x_masked = recon_x * mask
         x_masked = x * mask
-        BCE = F.mse_loss(recon_x, x, reduction='sum')
+        BCE = F.mse_loss(recon_x_masked, x_masked, reduction="sum")
         KLD = -0.5 * (1 + torch.log(var) - mu.pow(2) - var).sum()
         return BCE + KLD
 
-    # @torch.enable_grad()
-    def fit(self, waveforms, amps, channels, epochs=10, learning_rate=1e-3, batch_size=32):
-        # Example waveform data and IDs
-        with torch.enable_grad():
-            # Create a dataset including the IDs
-            dataset = TensorDataset(waveforms, amps, channels)
-            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-            optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
-            self.train()  # Set the model to training mode
-            for epoch in trange(epochs, desc="Epochs"):
+    def _fit(self, waveforms, channels):
+        # apply channel reindexing before any fitting...
+        waveforms = reindex(channels, waveforms, self.relative_index, pad_value=0.0)
+        # should be no nans there thanks to padding with 0s
+
+        if self.amplitude_kind == "ptp":
+            amps = ptp(waveforms)
+        elif self.amplitude_kind == "peak":
+            amps = waveforms.abs().max(dim=1).values
+
+        dataset = TensorDataset(waveforms, amps, channels)
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+
+        self.train()
+        with trange(self.epochs, desc="Epochs") as pbar:
+            for epoch in pbar:
                 total_loss = 0
                 for waveform_batch, amps_batch, channels_batch in dataloader:
                     optimizer.zero_grad()
-                    channels_mask = self.channel_index[channels_batch] < len(self.geom)
-                    reconstructed_amps, mu, var = self.forward(waveform_batch, channels_mask, amps_batch, channels_batch)
-                    loss = self.loss_function(reconstructed_amps, amps_batch, channels_mask, mu, var)
+                    channels_mask = self.model_channel_index[channels_batch] < len(self.geom)
+                    reconstructed_amps, mu, var = self.forward(
+                        waveform_batch, channels_mask, amps_batch, channels_batch
+                    )
+                    loss = self.loss_function(
+                        reconstructed_amps, amps_batch, channels_mask, mu, var
+                    )
                     loss.backward()
                     optimizer.step()
                     total_loss += loss.item()
-                print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader)}")
+                pbar.set_description(f"epoch={epoch}: loss={total_loss / len(dataloader)}")
+
+    def fit(self, waveforms, max_channels):
+        with torch.enable_grad():
+            self._fit(waveforms, max_channels)
 
     def transform(self, waveforms, max_channels):
         """
@@ -125,8 +156,8 @@ def transform(self, waveforms, max_channels):
 
         waveform[n] lives on channels self.channel_index[max_channels[n]]
         """
-        waveforms = torch.nan_to_num(waveforms)
-        mask = self.channel_index[max_channels] < len(self.geom)
+        waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0)
+        mask = self.model_channel_index[max_channels] < len(self.geom)
         x_flat = waveforms.view(len(waveforms), -1)
         x_flat_mask = torch.cat((x_flat, mask), dim=1)
         mu, log_var = self.encoder(x_flat_mask).chunk(2, dim=-1)
@@ -134,6 +165,3 @@ def transform(self, waveforms, max_channels):
         y = F.softplus(y)
         mx, mz = self.geom[max_channels].T
         return x + mx, y, z + mz
-
-# %%
-
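Note on the vae_localize changes above: `make_regular_channel_index` is what lets the encoder see every spike on the same local channel layout, so one MLP can serve all max channels. A minimal sketch of the idea follows; it is not dartsort's implementation, and it assumes the helper returns, for each channel, the channels within `radius` of it, padded with `n_chans` so all rows share one width:

    import torch

    def toy_regular_channel_index(geom, radius):
        # geom: (n_chans, 2) channel positions
        n_chans = len(geom)
        neighbors = torch.cdist(geom, geom) <= radius
        width = int(neighbors.sum(1).max())
        # n_chans marks "no channel here," matching dartsort's convention
        index = torch.full((n_chans, width), n_chans)
        for c in range(n_chans):
            (nbrs,) = neighbors[c].nonzero(as_tuple=True)
            index[c, : nbrs.numel()] = nbrs
        return index

    geom = torch.tensor([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0], [0.0, 60.0]])
    print(toy_regular_channel_index(geom, radius=25.0))
    # tensor([[0, 1, 4],
    #         [0, 1, 2],
    #         [1, 2, 3],
    #         [2, 3, 4]])

Every row has the same width, so `get_relative_index`/`reindex` (added to spiketorch below) can gather each spike's stored channels into this fixed layout, with the pad slot absorbing missing neighbors.
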
diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py
index 1bef22e5..e50fa4e0 100644
--- a/src/dartsort/util/spiketorch.py
+++ b/src/dartsort/util/spiketorch.py
@@ -301,12 +301,17 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0):
     if overlap is None:
         f1 = torch.fft.rfft(input, n=s1)
         f2 = torch.fft.rfft(torch.flip(weight, (-1,)), n=s1)
-        f1.mul_(f2[:, None:, ])
+        f1.mul_(
+            f2[
+                :,
+                None:,
+            ]
+        )
         res = torch.fft.irfft(f1, n=s1)
         valid_len = s1 - s2 + 1
         valid_start = s2 - 1
         assert valid_start >= padding
-        res = res[:, valid_start-padding: valid_start+valid_len + padding]
+        res = res[:, valid_start - padding : valid_start + valid_len + padding]
         return res
 
     nstep1, pad1, nstep2, pad2 = steps_and_pad(
@@ -344,6 +349,59 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0):
     valid_len = s1 - s2 + 1
     valid_start = s2 - 1
     assert valid_start >= padding
-    oa = oa[:, valid_start - padding:valid_start + valid_len + padding]
+    oa = oa[:, valid_start - padding : valid_start + valid_len + padding]
 
     return oa
+
+
+# -- channel reindexing
+
+
+def get_relative_index(source_channel_index, target_channel_index):
+    """Pre-compute a channel reindexing helper structure.
+
+    Inputs have shapes:
+    source_channel_index.shape == (n_chans, n_source_chans)
+    target_channel_index.shape == (n_chans, n_target_chans)
+
+    This returns an array (relative_index) of shape (n_chans, n_target_chans)
+    which knows how to translate between the source and target indices:
+
+    relative_index[c, j] = index of target_channel_index[c, j] in source_channel_index[c]
+                           if present, else n_source_chans (i.e., an invalid index)
+                           (or, n_source_chans if target_channel_index[c, j] is n_chans)
+
+    See below:
+    reindex(max_channels, source_waveforms, relative_index)
+    """
+    n_chans, n_source_chans = source_channel_index.shape
+    n_chans_, n_target_chans = target_channel_index.shape
+    assert n_chans == n_chans_
+    relative_index = torch.full_like(target_channel_index, n_source_chans)
+    for c in range(n_chans):
+        row = source_channel_index[c]
+        for j in range(n_target_chans):
+            targ = target_channel_index[c, j]
+            if targ == n_chans:
+                continue
+            mask = row == targ
+            if not mask.any():
+                continue
+            (ixs,) = mask.nonzero(as_tuple=True)
+            assert ixs.numel() == 1
+            relative_index[c, j] = ixs[0]
+    return relative_index
+
+
+def reindex(
+    max_channels,
+    source_waveforms,
+    relative_index,
+    already_padded=False,
+    pad_value=torch.nan,
+):
+    """Gather (n, t, c) source waveforms onto their target channel neighborhoods."""
+    rel_ix = relative_index[max_channels]
+    if not already_padded:
+        source_waveforms = F.pad(source_waveforms, (0, 1), value=pad_value)
+    return torch.take_along_dim(source_waveforms, rel_ix[:, None, :], dim=2)
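Note (not part of the patch): a quick end-to-end check of the two new helpers on toy shapes. The `(n_spikes, n_times, n_source_chans)` waveform layout matches how vae_localize calls `reindex` (hence the `[:, None, :]` unsqueeze over the time dimension above); the index contents here are made up:

    import torch

    from dartsort.util.spiketorch import get_relative_index, reindex

    # 4 == n_chans is the "no channel" pad value in both channel indexes
    source_channel_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 4]])
    target_channel_index = torch.tensor([[0, 1, 4], [0, 1, 2], [1, 2, 3], [2, 3, 4]])

    rel = get_relative_index(source_channel_index, target_channel_index)
    # rel[c, j] indexes into source_channel_index[c]; entries equal to
    # n_source_chans (=2) point at the pad column appended by reindex
    print(rel)
    # tensor([[0, 1, 2],
    #         [2, 0, 1],
    #         [2, 0, 1],
    #         [2, 0, 2]])

    waveforms = torch.randn(10, 121, 2)   # (n_spikes, n_times, n_source_chans)
    max_channels = torch.randint(0, 4, (10,))
    out = reindex(max_channels, waveforms, rel, pad_value=0.0)
    print(out.shape)  # torch.Size([10, 121, 3])

Because absent channels read from the padded column, they come out as `pad_value` directly, which is what lets `_fit` and `transform` drop the old `torch.nan_to_num` pass.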