Decollider residual stuff

cwindolf committed Oct 10, 2024
1 parent 3effa7f commit eb22f0b
Showing 3 changed files with 168 additions and 21 deletions.
131 changes: 111 additions & 20 deletions src/dartsort/transform/decollider.py
@@ -30,13 +30,16 @@ def __init__(
         learning_rate=1e-3,
         epochs=25,
         channelwise_dropout_p=0.2,
+        inference_z_samples=10,
         n_data_workers=0,
         with_conv_fullheight=False,
         sample_weighting=None,
+        detach_amortizer=True,
+        eyz_net_residual=True,
+        e_exz_y_net_residual=True,
     ):
-        assert inference_kind in ("raw", "amortized")
         assert exz_estimator in ("n2n", "2n2", "n3n", "3n3")
-        assert inference_kind in ("raw", "exz", "amortized")
+        assert inference_kind in ("raw", "exz", "exz_fromz", "amortized", "exy_fake")
         assert sample_weighting in (None, "kmeans")
         super().__init__(
             geom=geom, channel_index=channel_index, name=name, name_prefix=name_prefix
@@ -55,6 +58,10 @@ def __init__(
         self.n_data_workers = n_data_workers
         self.with_conv_fullheight = with_conv_fullheight
         self.sample_weighting = sample_weighting
+        self.inference_z_samples = inference_z_samples
+        self.detach_amortizer = detach_amortizer
+        self.eyz_net_residual = eyz_net_residual
+        self.e_exz_y_net_residual = e_exz_y_net_residual
 
         self.model_channel_index_np = regularize_channel_index(
             geom=self.geom, channel_index=channel_index
@@ -76,7 +83,7 @@ def __init__(
     def needs_fit(self):
         return self._needs_fit
 
-    def get_mlp(self):
+    def get_mlp(self, residual=False):
         return nn_util.get_waveform_mlp(
             self.spike_length_samples,
             self.model_channel_index.shape[1],
@@ -88,6 +95,7 @@ def get_mlp(self):
             return_initial_shape=True,
             initial_conv_fullheight=self.with_conv_fullheight,
             final_conv_fullheight=self.with_conv_fullheight,
+            residual=residual,
         )
 
     def initialize_nets(self, spike_length_samples):
@@ -97,11 +105,11 @@ def initialize_nets(self, spike_length_samples):
         )
 
         if self.exz_estimator in ("n2n", "n3n"):
-            self.eyz = self.get_mlp()
+            self.eyz = self.get_mlp(residual=self.eyz_net_residual)
         if self.exz_estimator in ("n3n", "2n2", "3n3"):
             self.emz = self.get_mlp()
         if self.inference_kind == "amortized":
-            self.inf_net = self.get_mlp()
+            self.inf_net = self.get_mlp(residual=self.e_exz_y_net_residual)
         self.to(self.relative_index.device)
 
     def fit(self, waveforms, max_channels):
@@ -112,14 +120,50 @@ def fit(self, waveforms, max_channels):
 
     def forward(self, waveforms, max_channels):
         """Called only at inference time."""
+        # TODO: batch all of this.
         waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0)
         masks = self.get_masks(max_channels).to(waveforms)
         net_input = waveforms, masks.unsqueeze(1)
 
         if self.inference_kind == "amortized":
             pred = self.inf_net(net_input)
         elif self.inference_kind == "raw":
-            pred = self.eyz(net_input)
+            if hasattr(self, "emz"):
+                emz = self.emz(net_input)
+                pred = waveforms - emz
+            elif hasattr(self, "eyz"):
+                pred = self.eyz(net_input)
+            else:
+                assert False
+        elif self.inference_kind == "exz_fromz":
+            pred = torch.zeros_like(waveforms)
+            for j in range(self.inference_z_samples):
+                m = get_noise(
+                    self.recording,
+                    max_channels.numpy(force=True),
+                    self.model_channel_index_np,
+                    spike_length_samples=self.spike_length_samples,
+                    rg=None,
+                )
+                m = m.to(waveforms)
+                z = waveforms + m
+                net_input = z, masks.unsqueeze(1)
+                if self.exz_estimator == "n2n":
+                    eyz = self.eyz(net_input)
+                    pred += 2 * eyz - z
+                elif self.exz_estimator == "2n2":
+                    emz = self.emz(net_input)
+                    pred += z - 2 * emz
+                elif self.exz_estimator == "n3n":
+                    eyz = self.eyz(net_input)
+                    emz = self.emz(net_input)
+                    pred += eyz - emz
+                elif self.exz_estimator == "3n3":
+                    emz = self.emz(net_input)
+                    pred += z - emz
+                else:
+                    assert False
+            pred /= self.inference_z_samples
         elif self.inference_kind == "exz":
             if self.exz_estimator == "n2n":
                 eyz = self.eyz(net_input)
@@ -136,12 +180,6 @@ def forward(self, waveforms, max_channels):
                 pred = waveforms - emz
             else:
                 assert False
-        elif self.inference_kind == "exy_fake":
-            if self.exz_estimator in ("2n2", "3n3"):
-                emz = self.emz(net_input)
-                pred = waveforms - emz
-            else:
-                assert False
         else:
             assert False
 
@@ -156,6 +194,7 @@ def train_forward(self, y, m, mask):
         z = y + m
 
         # predictions given z
+        # TODO: variance given z and put it in the loss
         exz = eyz = emz = e_exz_y = None
         net_input = z, mask.unsqueeze(1)
         if self.exz_estimator == "n2n":
@@ -188,7 +227,11 @@ def loss(self, mask, waveforms, m, exz, eyz=None, emz=None, e_exz_y=None):
         if emz is not None:
             loss_dict["emz"] = F.mse_loss(mask * emz, mask * m)
         if e_exz_y is not None:
-            loss_dict["e_exz_y"] = F.mse_loss(mask * exz, mask * e_exz_y)
+            to_amortize = exz
+            if self.detach_amortizer:
+                # should amortize-ability affect the learning of eyz, emz?
+                to_amortize = to_amortize.detach()
+            loss_dict["e_exz_y"] = F.mse_loss(mask * to_amortize, mask * e_exz_y)
         return loss_dict
 
     def _fit(self, waveforms, channels):
@@ -322,11 +365,11 @@ def __getitem__(self, index):
         return noise
 
 
-def kmeanspp_density_estimate(x, n_components=1024, n_iter=10, sigma=10.0, learn_sigma=True, rg=0):
+def kmeanspp_density_estimate(x, n_components=256, n_iter=10, sigma=10.0, learn_sigma=True, sigma_per_comp=True, rg=0, eps=1e-6, drop_prop=1e-3, with_proportions=False, scale_by_dim=False):
     rg = np.random.default_rng(0)
 
     # kmeanspp
-    n = len(x)
+    n, dim = x.shape
     centroid_ixs = []
     dists = torch.full(
         (n,), torch.inf, dtype=x.dtype, device=x.device
@@ -351,9 +394,24 @@ def kmeanspp_density_estimate(x, n_components=1024, n_iter=10, sigma=10.0, learn
     if n_iter:
         centroids = x[centroid_ixs]
         dists = torch.cdist(x, centroids).square_()
+        if with_proportions:
+            proportions = torch.ones(len(centroids)).to(dists) / len(centroids)
         for i in range(n_iter):
             # update responsibilities, n x k
-            e = F.softmax(-0.5 * dists / (sigma ** 2), dim=1)
+            if with_proportions:
+                e = F.softmax(-0.5 * dists + proportions.log(), dim=1)
+            else:
+                e = F.softmax(-0.5 * dists, dim=1)
+
+            # delete too-small centroids
+            if drop_prop is not None:
+                props = e.mean(0)
+                keep = props >= drop_prop
+                e = e[:, keep]
+
+                # update proportions while still spike-normalized (resps)
+                if with_proportions:
+                    proportions = props[keep]
 
             # normalize per centroid
             e = e.div_(e.sum(0))
@@ -363,14 +421,47 @@ def kmeanspp_density_estimate(x, n_components=1024, n_iter=10, sigma=10.0, learn
             dists = torch.cdist(x, centroids).square_()
             assignments = torch.argmin(dists, 1)
             if learn_sigma:
-                sigma = torch.take_along_dim(dists, assignments[:, None], dim=1).mean().sqrt()
+                if sigma_per_comp:
+                    w = (e + eps) / (1 + dim * eps)
+                    sigma = ((w * dists).sum(0) / dim).sqrt()
+                    print(f"{i=} {sigma.min()=} {sigma.max()=}")
+                else:
+                    sigma = torch.take_along_dim(dists, assignments[:, None], dim=1).mean().div(dim).sqrt()
             if e.shape[1] == 1:
                 break
 
     # estimate densities
     dists = torch.cdist(x, centroids).square_()
-    e = F.softmax(-0.5 * dists / (sigma ** 2), dim=1)
-    component_proportion = e.mean(0)
-    density = e @ component_proportion
+    if with_proportions:
+        e = F.softmax(-0.5 * dists + proportions.log(), dim=1)
+    else:
+        e = F.softmax(-0.5 * dists, dim=1)
+    proportions = e.mean(0)
+    keep = proportions > drop_prop
+    proportions = proportions[keep]
+    if sigma_per_comp:
+        sigma = sigma[keep]
+    dists = dists[:, keep]
+    if with_proportions:
+        e = F.softmax(-0.5 * dists + proportions.log(), dim=1)
+    else:
+        e = F.softmax(-0.5 * dists, dim=1)
+    proportions = e.mean(0)
+    w = e + eps
+    w = w / w.sum(0)
+    component_sigmasq = (w * dists).sum(0)
+    if scale_by_dim:
+        pass
+    else:
+        component_sigmasq /= dim
+    component_dens = proportions / component_sigmasq
+    component_dens = component_dens / component_dens.sum()
+    density = e @ component_dens
+    print(f"{proportions.shape=} {component_sigmasq.shape=} {e.shape=} {component_dens.shape=} {dists.shape=}")
+    print(f"{sigma.min()=} {sigma.max()=}")
+    print(f"{component_dens.min()=} {component_dens.max()=}")
+    print(f"{proportions.min()=} {proportions.max()=}")
+    print(f"{component_sigmasq.min()=} {component_sigmasq.max()=}")
+    print(f"{density.min()=} {density.max()=}")
 
     return density
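
Aside, a sketch that is not part of this commit: the new exz_fromz branch draws inference_z_samples fresh noise samples m, forms z = y + m, and averages single-sample estimates of E[x|z]. The four combinations follow the usual noisier2noise-style identities. If y = x + n is the recorded spike and m is independent noise distributed like n, then E[n|z] = E[m|z], so E[x|z] = 2 E[y|z] - z = z - 2 E[m|z] = E[y|z] - E[m|z], while z - E[m|z] only strips the re-added noise and so targets E[y|z]. A toy check in a scalar Gaussian model (all variances made up), where the conditional expectations are exactly linear:

import numpy as np

sx2, s2 = 4.0, 1.0         # signal and noise variances, arbitrary
vz = sx2 + 2 * s2          # var(z) with z = x + n + m
z = 1.7                    # an arbitrary observed value

eyz = (sx2 + s2) / vz * z  # E[y | z]
emz = s2 / vz * z          # E[m | z]
exz = sx2 / vz * z         # E[x | z], the quantity being estimated

assert np.isclose(2 * eyz - z, exz)  # "n2n" combination
assert np.isclose(z - 2 * emz, exz)  # "2n2" combination
assert np.isclose(eyz - emz, exz)    # "n3n" combination
assert np.isclose(z - emz, eyz)      # "3n3" strips only the fresh noise
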
30 changes: 29 additions & 1 deletion src/dartsort/util/nn_util.py
@@ -31,6 +31,7 @@ def get_waveform_mlp(
     initial_conv_fullheight=False,
     final_conv_fullheight=False,
     return_initial_shape=False,
+    residual=False,
 ):
     input_dim = n_input_channels * (spike_length_samples + input_includes_mask)
 
@@ -59,7 +60,34 @@ def get_waveform_mlp(
         ))
         layers.append(nn.ReLU())
         layers.append(nn.Conv1d(spike_length_samples, spike_length_samples, kernel_size=1))
-    return nn.Sequential(*layers)
+
+    net = nn.Sequential(*layers)
+    if residual:
+        net = WaveformOnlyResidualForm(net)
+
+    return net
+
+
+class ResidualForm(nn.Module):
+
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, input):
+        output = self.module(input)
+        return input + output
+
+
+class WaveformOnlyResidualForm(nn.Module):
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, inputs):
+        waveforms, masks = inputs
+        output = self.module(inputs)
+        return waveforms + output
 
 
 class ChannelwiseDropout(nn.Module):
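
Aside, an illustrative sketch with made-up shapes, not part of this commit: WaveformOnlyResidualForm is the variant get_waveform_mlp actually uses when residual=True. The wrapped module sees the full (waveforms, masks) tuple, but the skip connection adds its output back to the waveforms alone, so the network only has to learn a correction term.

import torch
import torch.nn as nn

from dartsort.util.nn_util import WaveformOnlyResidualForm

class TupleToy(nn.Module):
    # stand-in for the waveform MLP: consumes a (waveforms, masks) tuple
    def forward(self, inputs):
        waveforms, masks = inputs
        return 0.1 * waveforms * masks

net = WaveformOnlyResidualForm(TupleToy())
waveforms = torch.randn(8, 121, 40)  # (batch, time, channels), made up
masks = torch.ones(8, 1, 40)
out = net((waveforms, masks))
assert torch.allclose(out, waveforms + 0.1 * waveforms * masks)
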
28 changes: 28 additions & 0 deletions src/dartsort/util/residual_util.py
@@ -0,0 +1,28 @@
+from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
+
+
+class PeeledResidualRecording(BasePreprocessor):
+
+    def __init__(self, peeler):
+        # this does not handle segments at all!
+        super().__init__(peeler.recording, dtype=peeler.dtype)
+        assert peeler.recording.get_num_segments() == 1
+        self.add_recording_segment(PeeledResidualSegment(peeler))
+        self._kwargs = dict(peeler=peeler)
+
+
+class PeeledResidualSegment(BasePreprocessorSegment):
+    def __init__(self, peeler):
+        self.peeler = peeler
+
+    def get_traces(self, start_frame, end_frame, channel_indices):
+        # right now the peeler dives into segment 0 in process_chunk
+        # i guess that's something to think about...
+        stuff = self.peeler.process_chunk(
+            chunk_start_samples=start_frame,
+            chunk_end_samples=end_frame,
+            return_residual=True,
+            skip_features=True,
+        )
+        return stuff['residual'].numpy(force=True)[:, channel_indices]
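
Aside, hypothetical usage, not part of this commit (the peeler itself is fit and constructed elsewhere in dartsort): wrapping a fitted peeler exposes the peeled residual to any spikeinterface consumer as an ordinary lazy recording.

from dartsort.util.residual_util import PeeledResidualRecording

residual_rec = PeeledResidualRecording(peeler)  # peeler: a fitted dartsort peeler
chunk = residual_rec.get_traces(start_frame=0, end_frame=30_000)
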
