Multi chan vae?
cwindolf committed Sep 19, 2024
1 parent 6602361 commit b565241
Showing 3 changed files with 117 additions and 46 deletions.
16 changes: 0 additions & 16 deletions src/dartsort/transform/transform_base.py
@@ -36,20 +36,6 @@ def precompute(self):
pass


class RegularChannelsTransformerMixin:

def needs_precompute(self):
return hasattr(self, "regular_channel_index")

def precompute(self):
# create regular channel index...
# assume that we have a
self.radius
# parameter
...



class BaseWaveformDenoiser(BaseWaveformModule):
is_denoiser = True

@@ -91,8 +77,6 @@ def spike_datasets(self):
return (dataset,)




class BaseWaveformAutoencoder(BaseWaveformDenoiser, BaseWaveformFeaturizer):
pass

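The mixin deleted above sketched a precompute step that would build a "regular" channel index from a radius parameter; this commit moves that job into VAELocalization.__init__ in the next file, which registers make_regular_channel_index(self.geom, radius, to_torch=True) as a buffer. A minimal sketch of what such a radius-based index could compute, assuming "regular" simply means each channel's radius neighborhood padded to a common width (the real helper in dartsort.util.waveform_util may differ in ordering and padding details):

import numpy as np
import torch

def radius_channel_index_sketch(geom, radius, to_torch=False):
    # geom is an (n_chans, 2) array of channel positions; rows are padded
    # with n_chans, an out-of-range index, so every row has the same width
    n_chans = len(geom)
    dists = np.linalg.norm(geom[:, None, :] - geom[None, :, :], axis=2)
    neighbs = [np.flatnonzero(dists[c] <= radius) for c in range(n_chans)]
    width = max(map(len, neighbs))
    index = np.full((n_chans, width), n_chans)
    for c, nb in enumerate(neighbs):
        index[c, : len(nb)] = nb
    return torch.from_numpy(index) if to_torch else index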
83 changes: 56 additions & 27 deletions src/dartsort/transform/vae_localize.py
@@ -2,11 +2,11 @@
import torch.nn as nn
import torch.nn.functional as F
from dartsort.localize.localize_torch import localize_amplitude_vectors
from dartsort.util.spiketorch import ptp
from dartsort.util.spiketorch import get_relative_index, ptp, reindex
from dartsort.util.waveform_util import make_regular_channel_index
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import trange


from .transform_base import BaseWaveformFeaturizer


@@ -28,6 +28,9 @@ def __init__(
hidden_dims=(256, 128),
name=None,
name_prefix="",
epochs=10,
learning_rate=1e-3,
batch_size=32,
):
assert amplitude_kind in ("peak", "ptp")
super().__init__(
@@ -40,34 +43,48 @@
self.amplitude_kind = amplitude_kind
self.radius = radius
self.localization_model = localization_model
self.latent_dim = 3 #x,y,z
self.latent_dim = 3 # x,y,z
self.epochs = epochs
self.learning_rate = learning_rate
self.batch_size = batch_size
self.encoder = nn.Sequential(
nn.Linear(self.input_dim, hidden_dims[0]),
nn.BatchNorm1d(hidden_dims[0]),
nn.ReLU(),
nn.Linear(hidden_dims[0], hidden_dims[1]),
nn.BatchNorm1d(hidden_dims[1]),
nn.ReLU(),
nn.Linear(hidden_dims[1], self.latent_dim * 2) # Output mu and log_var
nn.Linear(hidden_dims[1], self.latent_dim * 2), # Output mu and log_var
)
self.register_buffer("padded_geom", F.pad(self.geom, (0, 0, 0, 1)))

self.register_buffer(
"model_channel_index",
make_regular_channel_index(self.geom, radius, to_torch=True),
)
self.register_buffer(
"relative_index",
get_relative_index(self.channel_index, self.model_channel_index),
)

def reparameterize(self, mu, var):
std = var.sqrt()
eps = torch.randn_like(std)
return mu + eps * std

def local_distances(self, z, channels):
"""Return distances from each z to its local geom centered at channels."""
local_geom = self.padded_geom[self.channel_index[channels]] - self.geom[channels].unsqueeze(1)
local_geom = self.padded_geom[self.channel_index[channels]] - self.geom[
channels
].unsqueeze(1)
dx = z[:, 0, None] - local_geom[:, :, 0]
dz = z[:, 2, None] + local_geom[:, :, 1]
y = F.softplus(z[:, 1]).unsqueeze(1)
dists = torch.sqrt(dx**2 + dz**2 + y**2)
return dists

def get_alphas(self, obs_amps, dists, return_pred=False):
pred_amps_alpha1 = 1. / dists
pred_amps_alpha1 = 1.0 / dists
# least squares with no intercept
alphas = (obs_amps * pred_amps_alpha1).sum(1) / pred_amps_alpha1.square().sum(1)
if return_pred:
@@ -87,36 +104,51 @@ def forward(self, x, mask, obs_amps, channels):
z = self.reparameterize(mu, var)
alphas, pred_amps = self.decode(z, channels, obs_amps)
return pred_amps, mu, var

def loss_function(self, recon_x, x, mask, mu, var):
mask = mask.to(x)
recon_x_masked = recon_x * mask
x_masked = x * mask
BCE = F.mse_loss(recon_x, x, reduction='sum')
BCE = F.mse_loss(recon_x_masked, x_masked, reduction="sum")
KLD = -0.5 * (1 + torch.log(var) - mu.pow(2) - var).sum()
return BCE + KLD

# @torch.enable_grad()
def fit(self, waveforms, amps, channels, epochs=10, learning_rate=1e-3, batch_size=32):
# Example waveform data and IDs
with torch.enable_grad():
# Create a dataset including the IDs
dataset = TensorDataset(waveforms, amps, channels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
self.train() # Set the model to training mode
for epoch in trange(epochs, desc="Epochs"):
def _fit(self, waveforms, channels):
# apply channel reindexing before any fitting...
waveforms = reindex(channels, waveforms, self.relative_index, pad_value=0.0)
# should be no nans there thanks to padding with 0s

if self.amplitude_kind == "ptp":
amps = ptp(waveforms)
elif self.amplitude_kind == "peak":
amps = waveforms.abs().max(dim=1).values

dataset = TensorDataset(waveforms, amps, channels)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

self.train()
with trange(self.epochs, desc="Epochs") as pbar:
for epoch in pbar:
total_loss = 0
for waveform_batch, amps_batch, channels_batch in dataloader:
optimizer.zero_grad()
channels_mask = self.channel_index[channels_batch] < len(self.geom)
reconstructed_amps, mu, var = self.forward(waveform_batch, channels_mask, amps_batch, channels_batch)
loss = self.loss_function(reconstructed_amps, amps_batch, channels_mask, mu, var)
channels_mask = self.model_channel_index[channels_batch] < len(self.geom)
reconstructed_amps, mu, var = self.forward(
waveform_batch, channels_mask, amps_batch, channels_batch
)
loss = self.loss_function(
reconstructed_amps, amps_batch, channels_mask, mu, var
)
loss.backward()
optimizer.step()
total_loss += loss.item()

print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader)}")
pbar.set_description(f"epoch={epoch}: loss={total_loss / len(dataloader)}")

def fit(self, waveforms, max_channels):
with torch.enable_grad():
self._fit(waveforms, max_channels)

def transform(self, waveforms, max_channels):
"""
@@ -125,15 +157,12 @@ def transform(self, waveforms, max_channels):
waveform[n] lives on channels self.channel_index[max_channels[n]]
"""
waveforms = torch.nan_to_num(waveforms)
mask = self.channel_index[max_channels] < len(self.geom)
waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0)
mask = self.model_channel_index[max_channels] < len(self.geom)
x_flat = waveforms.view(len(waveforms), -1)
x_flat_mask = torch.cat((x_flat, mask), dim=1)
mu, log_var = self.encoder(x_flat_mask).chunk(2, dim=-1)
x, y, z = mu.T
y = F.softplus(y)
mx, mz = self.geom[max_channels].T
return x + mx, y, z + mz

# %%

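A note on the decoder in the code above: it is not a neural network. local_distances maps the latent (x, y, z) to spike-to-channel distances, and get_alphas recovers each spike's brightness in closed form. With unit-scale predictions p_j = 1/d_j, minimizing sum_j (a_j - alpha * p_j)^2 over alpha alone (least squares with no intercept) gives alpha = sum_j a_j p_j / sum_j p_j^2, which is exactly the alphas line in get_alphas. A toy numeric check of that closed form (made-up distances, not repo data):

import torch

dists = torch.tensor([[10.0, 20.0, 40.0]])  # fake spike-to-channel distances
obs_amps = 300.0 / dists                    # noiseless 1/r amplitudes, true alpha = 300

pred1 = 1.0 / dists
alphas = (obs_amps * pred1).sum(1) / pred1.square().sum(1)
print(alphas)  # tensor([300.]) -- recovers the true scale exactly

Similarly, loss_function is the usual VAE objective with a Gaussian posterior: a masked sum-of-squares reconstruction term plus the closed-form KL divergence to a standard normal, KL(N(mu, var) || N(0, I)) = -0.5 * sum(1 + log var - mu^2 - var). A quick check of that identity against torch.distributions, again purely illustrative:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.tensor([0.3, -1.2])
var = torch.tensor([0.5, 2.0])
closed = -0.5 * (1 + var.log() - mu.pow(2) - var).sum()
lib = kl_divergence(Normal(mu, var.sqrt()), Normal(0.0, 1.0)).sum()
print(torch.allclose(closed, lib))  # True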
64 changes: 61 additions & 3 deletions src/dartsort/util/spiketorch.py
@@ -301,12 +301,17 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0):
if overlap is None:
f1 = torch.fft.rfft(input, n=s1)
f2 = torch.fft.rfft(torch.flip(weight, (-1,)), n=s1)
f1.mul_(f2[:, None:, ])
f1.mul_(
f2[
:,
None:,
]
)
res = torch.fft.irfft(f1, n=s1)
valid_len = s1 - s2 + 1
valid_start = s2 - 1
assert valid_start >= padding
res = res[:, valid_start-padding: valid_start+valid_len + padding]
res = res[:, valid_start - padding : valid_start + valid_len + padding]
return res

nstep1, pad1, nstep2, pad2 = steps_and_pad(
@@ -344,6 +349,59 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0):
valid_len = s1 - s2 + 1
valid_start = s2 - 1
assert valid_start >= padding
oa = oa[:, valid_start - padding:valid_start + valid_len + padding]
oa = oa[:, valid_start - padding : valid_start + valid_len + padding]

return oa


# -- channel reindexing


def get_relative_index(source_channel_index, target_channel_index):
"""Pre-compute a channel reindexing helper structure.
Inputs have shapes:
source_channel_index.shape == (n_chans, n_source_chans)
target_channel_index.shape == (n_chans, n_target_chans)
This returns an array (relative_index) of shape (n_chans, n_target_chans)
which knows how to translate between the source and target indices:
relative_index[c, j] = index of target_channel_index[c, j] in source_channel_index[c]
if present, else n_source_chans (i.e., an out-of-range padding index)
(n_source_chans is also used when target_channel_index[c, j] == n_chans,
i.e. when the target slot is itself padding)
See below:
reindex(max_channels, source_waveforms, relative_index)
"""
n_chans, n_source_chans = source_channel_index.shape
n_chans_, n_target_chans = target_channel_index.shape
assert n_chans == n_chans_
relative_index = torch.full_like(target_channel_index, n_source_chans)
for c in range(n_chans):
row = source_channel_index[c]
for j in range(n_target_chans):
targ = target_channel_index[c, j]
if targ == n_chans:
continue
mask = row == targ
if not mask.any():
continue
(ixs,) = mask.nonzero(as_tuple=True)
assert ixs.numel() == 1
relative_index[c, j] = ixs[0]
return relative_index


def reindex(
max_channels,
source_waveforms,
relative_index,
already_padded=False,
pad_value=torch.nan,
):
""""""
rel_ix = relative_index[max_channels][:, None, :]  # add a time axis so take_along_dim can broadcast over dim 1
if not already_padded:
source_waveforms = F.pad(source_waveforms, (0, 1), value=pad_value)
return torch.take_along_dim(source_waveforms, rel_ix, dim=2)

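A toy walk-through of the two helpers above (hand-picked indices, purely illustrative), assuming get_relative_index and reindex are in scope, e.g. from dartsort.util.spiketorch:

import torch

n_chans = 3  # the index value 3 == n_chans marks a padded slot
source_channel_index = torch.tensor([[0, 1], [0, 2], [1, 2]])
target_channel_index = torch.tensor([[1, 3], [2, 0], [2, 1]])

rel = get_relative_index(source_channel_index, target_channel_index)
# rel == [[1, 2], [1, 0], [1, 0]]: target channel 1 for max channel 0 sits
# at position 1 of source row [0, 1], while the padded slot (3 == n_chans)
# maps to 2 == n_source_chans, which reindex fills with pad_value

waveforms = torch.arange(12.0).view(2, 3, 2)  # (spikes, time, source chans)
max_channels = torch.tensor([0, 2])
out = reindex(max_channels, waveforms, rel, pad_value=0.0)
print(out.shape)  # torch.Size([2, 3, 2]) -- waveforms now on the target layout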