Skip to content

Commit

Permalink
Port over interpolation code; sketch out embedded + shifted noise
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Oct 15, 2024
1 parent 3010bda commit 19da675
Show file tree
Hide file tree
Showing 2 changed files with 538 additions and 15 deletions.
229 changes: 229 additions & 0 deletions src/dartsort/util/interpolation_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""Library for flavors of kernel interpolation and data interp utilities"""

import numpy as np
import torch
import torch.nn.functional as F
from dartsort.util.data_util import yield_masked_chunks
from dartsort.util.drift_util import (get_spike_pitch_shifts,
static_channel_neighborhoods)

# Names of the interpolation methods accepted by the functions in this module.
interp_kinds = (
    # one-hot weights on the closest channel
    "nearest",
    # raw (unnormalized) RBF kernel weights
    "rbf",
    # RBF weights normalized over source channels (softmax of the log kernel)
    "normalized",
    # RBF weights premultiplied by inverse source-to-source kernel matrices
    "kriging",
    # kriging weights rescaled to sum to one over source channels
    "kriging_normalized",
)


def interpolate_by_chunk(
    mask,
    dataset,
    geom,
    channel_index,
    channels,
    shifts,
    registered_geom,
    target_channels,
    sigma=10.0,
    interpolation_method="normalized",
    device=None,
    store_on_device=False,
    show_progress=True,
    shift_dim=1,
):
    """Interpolate data living in an HDF5 file

    If dataset is a h5py.Dataset and mask is a boolean array indicating
    positions of data to load, this iterates over the HDF5 chunks to
    quickly scan through the data, applying interpolation to all the
    features.

    Each spike's source channels are looked up through channel_index, their
    positions are displaced by the spike's shift along one geometry axis,
    and the features are kernel-interpolated onto the positions of the
    spike's target channels in registered_geom.

    Arguments
    ---------
    mask : boolean np.ndarray
        Load and interpolate these entries. Shape should be
        (n_spikes_full,), and let's say it has n_spikes nonzero entries.
    dataset : h5py.Dataset
        Chunked dataset, shape (n_spikes_full, feature_dim, n_source_channels)
        Can only be chunked on the first axis
    geom : array or tensor
        Source channel positions, (n_channels, spatial_dim). A row of nans
        is appended so that index n_channels looks up a missing position.
    channel_index : int array or tensor
        Row c lists the source channels of a spike whose channel is c.
        Its second dimension must equal n_source_channels.
    channels : int array or tensor
        Shape (n_spikes,)
    shifts : array or tensor
        Shape (n_spikes,) or (n_spikes, n_source_channels)
    registered_geom : array or tensor
        Target channel positions, padded with a nan row like geom.
    target_channels : int array or tensor
        (n_spikes, n_target_channels)
    sigma : float
        Kernel bandwidth
    interpolation_method : str
        One of interp_kinds. Note that no precomputed source kernel
        inverses are passed to kernel_interpolate from here.
    device : torch device
        Where to compute. Defaults to cuda when available, else cpu.
    store_on_device : bool
        Allocate the output tensor on gpu?
    show_progress : bool
    shift_dim : int
        Geometry coordinate that the shifts are added to.
        NOTE(review): the default of 1 assumes the second geometry column
        is the axis along which shifts were estimated -- confirm against
        the caller's convention.

    Returns
    -------
    out : torch.Tensor
        (n_spikes, feature_dim, n_target_chans)
    """
    # devices, dtypes, shapes
    assert interpolation_method in interp_kinds
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    # translate the dataset's numpy dtype into the matching torch dtype
    dtype = torch.from_numpy(np.empty((), dtype=dataset.dtype)).dtype
    n_spikes = int(mask.sum())
    assert channels.shape == (n_spikes,)
    n_target_chans = target_channels.shape[1]
    assert target_channels.shape == (n_spikes, n_target_chans)
    feature_dim = dataset.shape[1]
    assert channel_index.shape[1] == dataset.shape[2]

    # allocate output
    storage_device = device if store_on_device else "cpu"
    out_shape = n_spikes, feature_dim, n_target_chans
    out = torch.empty(out_shape, dtype=dtype, device=storage_device)

    # build data needed for interpolation
    source_geom = pad_geom(geom, dtype=dtype, device=device)
    target_geom = pad_geom(registered_geom, dtype=dtype, device=device)
    shifts = torch.as_tensor(shifts, dtype=dtype, device=device)
    assert shifts.ndim in (1, 2)
    target_channels = torch.as_tensor(target_channels, device=device)
    channel_index = torch.as_tensor(channel_index, device=device)
    channels = torch.as_tensor(channels, device=device)

    for ixs, chunk_features in yield_masked_chunks(
        mask, dataset, show_progress=show_progress, desc_prefix="Interpolating"
    ):
        # the kernel is built on `device`, so the features must live there
        # too; as_tensor accepts both numpy arrays and tensors
        chunk_features = torch.as_tensor(
            chunk_features, dtype=dtype, device=device
        )

        # where are the spikes?
        source_channels = channel_index[channels[ixs]]
        # tensor indexing copies, so the in-place shift below cannot
        # modify source_geom itself
        source_pos = source_geom[source_channels]
        source_shifts = shifts[ixs]
        if source_shifts.ndim == 1:
            # one shift per spike, broadcast over its source channels;
            # 2d shifts are already per-channel
            source_shifts = source_shifts.unsqueeze(1)
        # (n, n_source_channels) += (n, 1) or (n, n_source_channels).
        # Adding the shifts to the full (n, channels, spatial_dim) tensor
        # would misalign the spike axis with the channel axis.
        source_pos[..., shift_dim] += source_shifts

        # where are they going?
        target_pos = target_geom[target_channels[ixs]]

        # interpolate, store
        chunk_res = kernel_interpolate(
            chunk_features,
            source_pos,
            target_pos,
            sigma=sigma,
            allow_destroy=True,
            interpolation_method=interpolation_method,
        )
        out[ixs] = chunk_res.to(out)

    return out


def pad_geom(geom, dtype=torch.float, device=None):
    """Convert geom to a tensor and append one trailing row of nans

    Arguments
    ---------
    geom : array or tensor
        Channel positions, with channels along the second-to-last axis
    dtype : torch dtype
    device : torch device

    Returns
    -------
    padded : torch.Tensor
        Like geom, with one extra all-nan row at the end of the
        second-to-last axis, so that indexing one past the last channel
        yields a nan position.
    """
    geom_tensor = torch.as_tensor(geom, dtype=dtype, device=device)
    # F.pad widths run from the last axis backwards: nothing on the
    # coordinate axis, one entry after the end of the channel axis
    pad_widths = (0, 0, 0, 1)
    return F.pad(geom_tensor, pad_widths, value=float("nan"))


def kernel_interpolate(
    features,
    source_pos,
    target_pos,
    source_kernel_invs=None,
    sigma=10.0,
    allow_destroy=False,
    interpolation_method="normalized",
    out=None,
):
    """Kernel interpolation of multi-channel features or waveforms

    Builds a (n_spikes, n_source_channels, n_target_channels) weight
    matrix from the channel positions and batch-multiplies the features
    by it.

    Arguments
    ---------
    features : torch.Tensor
        n_spikes, feature_dim, n_source_channels
        These can be masked, indicated by nans here and in the same
        places of source_pos
    source_pos : torch.Tensor
        n_spikes, n_source_channels, spatial_dim
    target_pos : torch.Tensor
        n_spikes, n_target_channels, spatial_dim
        These can also be masked, indicate with nans and you will
        get nans in those positions
    source_kernel_invs : optional torch.Tensor
        Precomputed inverses of source-to-source kernel matrices,
        if you have them, for use in kriging. When omitted and a kriging
        method is requested, pseudoinverses of the source-to-source RBF
        kernels are computed here (one per spike, which can be slow).
    sigma : float
        Spatial bandwidth of RBF kernels
    allow_destroy : bool
        We need to overwrite nans in the features with 0s. If you
        allow me, I'll do that in-place.
    interpolation_method : str
        One of interp_kinds
    out : torch.Tensor
        Storage for target

    Returns
    -------
    features : torch.Tensor
        n_spikes, feature_dim, n_target_channels

    Raises
    ------
    ValueError
        If interpolation_method is not handled (only reachable when
        assertions are disabled).
    """
    assert interpolation_method in interp_kinds

    # -- build kernel
    if interpolation_method == "nearest":
        d = torch.cdist(source_pos, target_pos)
        # masked channels have nan distances; make them infinitely far so
        # that argmin never prefers them over a real channel
        d.masked_fill_(torch.isnan(d), torch.inf)
        # for each target channel, put all the weight on its closest source
        nearest = d.argmin(dim=1, keepdim=True)
        kernel = torch.zeros_like(d).scatter_(1, nearest, 1.0)
    else:
        kernel = log_rbf(source_pos, target_pos, sigma)
        if interpolation_method == "normalized":
            kernel = F.softmax(kernel, dim=1)
            # columns which are entirely -inf softmax to nan; zero them
            kernel.nan_to_num_()
        elif interpolation_method.startswith("kriging"):
            kernel = kernel.exp_()
            if source_kernel_invs is None:
                # masked source channels give all-zero rows and columns,
                # so the kernel can be singular: use the pseudoinverse
                source_kernel = log_rbf(source_pos, sigma=sigma).exp_()
                source_kernel_invs = torch.linalg.pinv(source_kernel)
            kernel = source_kernel_invs @ kernel
            if interpolation_method == "kriging_normalized":
                kernel = kernel / kernel.sum(1, keepdim=True)
        elif interpolation_method == "rbf":
            kernel = kernel.exp_()
        else:
            # explicit raise rather than `assert False`, which vanishes
            # under `python -O` and would let a log-kernel through
            raise ValueError(
                f"Unknown interpolation_method {interpolation_method!r}"
            )

    # -- apply kernel
    features = torch.nan_to_num(features, out=features if allow_destroy else None)
    features = torch.bmm(features, kernel, out=out)

    # nan-ify nonexistent chans
    needs_nan = torch.isnan(target_pos).all(2).unsqueeze(1)
    # the (n_spikes, 1, n_target_channels) mask broadcasts over feature_dim
    features.masked_fill_(needs_nan, torch.nan)

    return features


def log_rbf(source_pos, target_pos=None, sigma=None):
    """Logarithm of the RBF kernel between two sets of positions

    Computes -dist(source, target)^2 / (2 sigma^2). Missing positions in
    source_pos or target_pos, marked with nans, produce -inf entries so
    that exponentiating the result gives them zero weight.

    Arguments
    ---------
    source_pos : torch.tensor
        n source locations
    target_pos : torch.tensor
        m target locations. If None, source_pos is used for both sides.
    sigma : float
        Kernel bandwidth

    Returns
    -------
    kernel : torch.tensor
        n by m
    """
    targets = source_pos if target_pos is None else target_pos
    log_kernel = torch.cdist(source_pos, targets)
    log_scale = -1.0 / (2 * sigma**2)
    # square the distances and scale them, reusing the distance buffer
    log_kernel.square_().mul_(log_scale)
    # nan distances come from missing positions: send them to -inf
    return log_kernel.nan_to_num_(nan=-torch.inf)
Loading

0 comments on commit 19da675

Please sign in to comment.