Skip to content

Commit

Permalink
Initial mixure of PPCAs impl
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 21, 2024
1 parent a7cbe82 commit 3a8cd1e
Show file tree
Hide file tree
Showing 4 changed files with 414 additions and 146 deletions.
202 changes: 74 additions & 128 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .cluster_util import agglomerate
from .kmeans import kmeans
from .modes import smoothed_dipscore_at
from .ppcalib import ppca_em
from .stable_features import (
SpikeFeatures,
StableSpikeDataset,
Expand Down Expand Up @@ -45,11 +46,11 @@ def __init__(
cov_kind="zero",
use_proportions: bool = True,
proportions_sample_size: int = 2**16,
prior_type="niw",
channels_strategy="snr",
channels_strategy_snr_min=5.0,
scale_mean: float = 0.1,
prior_pseudocount: float = 10.0,
ppca_rank: int = 0,
ppca_inner_em_iter: int = 1,
random_seed: int = 0,
n_threads: int = 4,
min_count: int = 50,
Expand Down Expand Up @@ -124,11 +125,11 @@ def __init__(
noise=noise,
mean_kind=mean_kind,
cov_kind=cov_kind,
prior_type=prior_type,
channels_strategy=channels_strategy,
channels_strategy_snr_min=channels_strategy_snr_min,
prior_pseudocount=prior_pseudocount,
scale_mean=scale_mean,
ppca_rank=ppca_rank,
ppca_inner_em_iter=ppca_inner_em_iter,
)

# clustering with noise unit to hopefully grab false positives
Expand Down Expand Up @@ -811,7 +812,7 @@ def noise_log_likelihoods(self, show_progress=False):
def kmeans_split_unit(self, unit_id, debug=False):
# get spike data and use interpolation to fill it out to the
# unit's channel set
result = dict(parent_id=parent_id)
result = dict(parent_id=unit_id)
unit = self.units[unit_id]
if not unit.channels.numel():
return result
Expand Down Expand Up @@ -1351,8 +1352,9 @@ def __init__(
channels_strategy="snr",
channels_strategy_snr_min=50.0,
prior_pseudocount=10,
ppca_inner_em_iter=1,
ppca_rank=0,
scale_mean: float = 0.1,
mean=None,
**annotations,
):
super().__init__()
Expand All @@ -1369,8 +1371,8 @@ def __init__(
self.scale_alpha = float(prior_pseudocount)
self.scale_beta = float(prior_pseudocount) / scale_mean
self.annotations = annotations
if mean is not None:
self.register_buffer("mean", mean)
self.ppca_rank = ppca_rank
self.ppca_inner_em_iter = ppca_inner_em_iter

@classmethod
def from_features(
Expand All @@ -1383,8 +1385,10 @@ def from_features(
cov_kind="zero",
prior_type="niw",
channels_strategy="snr",
ppca_rank=0,
channels_strategy_snr_min=50.0,
prior_pseudocount=10,
ppca_inner_em_iter=1,
scale_mean: float = 0.1,
):
self = cls(
Expand All @@ -1398,29 +1402,13 @@ def from_features(
channels_strategy=channels_strategy,
channels_strategy_snr_min=channels_strategy_snr_min,
scale_mean=scale_mean,
ppca_rank=ppca_rank,
ppca_inner_em_iter=ppca_inner_em_iter,
)
self.fit(features, weights, neighborhoods=neighborhoods)
self = self.to(features.features.device)
return self

def avg_with(self, *others):
new = self.__class__(
mean_kind=self.mean_kind,
cov_kind=self.cov_kind,
rank=self.rank,
n_channels=self.n_channels,
noise=self.noise,
)
new.register_buffer(
"mean", (self.mean + sum(o.mean for o in others)) / (1 + len(others))
)
new.register_buffer(
"channels",
torch.cat([self.channels, *[o.channels for o in others]]).unique(),
)
assert self.cov_kind == "zero"
return new

def n_params(self, channels=None, on_channels=True):
p = channels.new_zeros(len(channels))
if self.mean_kind == "full":
Expand All @@ -1429,109 +1417,63 @@ def n_params(self, channels=None, on_channels=True):
assert self.cov_kind == "zero"
return p

def fit(self, features: SpikeFeatures, weights: torch.Tensor, neighborhoods=None):
def fit(
self,
features: SpikeFeatures,
weights: torch.Tensor,
neighborhoods=None,
show_progress=False,
):
if features is None:
self.pick_channels(None, None)
return
n = len(features)
r = self.noise.rank
new_zeros = features.features.new_zeros

achans = occupied_chans(features, self.n_channels, neighborhoods=neighborhoods)
target_padded = features.features.new_zeros(n, r, achans.numel() + 1)
if weights is None:
weights = features.features.new_ones(n)
afeats, aweights = zero_pad_to_chans(
features,
achans,
self.n_channels,
weights=weights,
target_padded=target_padded,
)
assert torch.isfinite(afeats).all()
assert torch.isfinite(aweights).all()
aweights_sum = aweights.sum(0)

# assigns self.mean
self.fit_mean(achans, afeats, aweights, aweights_sum)

# assigns self.cov, self.logdet
self.fit_cov(
features,
achans,
aweights,
aweights_sum,
neighborhoods=neighborhoods,
target_padded=target_padded,
)
del features # overwritten

self.pick_channels(achans, aweights_sum)

def fit_mean(self, achans, afeats, aweights, aweights_sum) -> SpikeFeatures:
if self.mean_kind == "zero":
return
je_suis = achans.numel()

active_mean = active_W = None
if hasattr(self, "mean"):
active_mean = self.mean[:, achans]
if hasattr(self, "W"):
active_W = self.W[:, achans]

if je_suis:
res = ppca_em(
sp=features,
noise=self.noise,
neighborhoods=neighborhoods,
active_channels=achans,
active_mean=active_mean,
active_W=active_W,
weights=weights,
cache_prefix="extract",
M=self.ppca_rank,
n_iter=self.ppca_inner_em_iter,
mean_prior_pseudocount=self.prior_pseudocount,
show_progress=show_progress,
W_initialization="zeros",
)

assert self.mean_kind == "full"
aweights_norm = aweights / aweights_sum
am = torch.linalg.vecdot(aweights_norm.unsqueeze(1), afeats, dim=0)

if self.prior_type == "niw":
assert self.noise.mean_kind == "zero"
count_full = self.prior_pseudocount + aweights_sum
# w0 = self.prior_pseudocount / count_full
w1 = aweights_sum / count_full
am = am * w1 # + self.noise.mean_full * w0
elif self.prior_type == "none":
pass
if hasattr(self, "mean"):
mean_full = self.mean
mean_full.fill_(0.0)
else:
assert False
mean_full = new_zeros((self.noise.rank, self.noise.n_channels))

mean_full = am.new_zeros((self.noise.rank, self.noise.n_channels))
mean_full[:, achans] = am
self.register_buffer("mean", mean_full)

def fit_cov(
self,
features,
achans,
aweights,
aweights_sum,
neighborhoods=None,
target_padded=None,
):
if self.cov_kind == "zero":
self.logdet = self.noise.logdet
return
if hasattr(self, "W"):
W_full = self.mean
W_full.fill_(0.0)
elif "W" in res:
W_full = new_zeros((self.noise.rank, self.noise.n_channels, self.ppca_rank))

if self.cov_kind == "scaled_template":
# todo: is there some issue with centering and weights
# zeros get filled in, don't want to subtract mean and leave nonzero
spw, nu = noise_whiten(
features,
self.noise,
neighborhoods,
mean_full=self.mean,
with_whitened_means=True,
in_place=True,
)
del features # overwritten
wfeats, _ = zero_pad_to_chans(
spw, achans, self.noise.n_channels, target_padded=target_padded
)
spw = replace(spw, features=nu)
wnu, _ = zero_pad_to_chans(spw, achans, self.noise.n_channels)
del spw # overwritten
self.template_std = template_scale_map(
wfeats,
wnu,
aweights,
alpha=self.scale_alpha,
beta=self.scale_beta,
allow_destroy=True,
)
return

assert False
if je_suis:
mean_full[:, achans] = res["mu"]
self.register_buffer("mean", mean_full)
if "W" in res:
W_full[:, achans] = res["W"]
self.register_buffer("mean", W_full)
self.pick_channels(achans, res["nobs"])

def pick_channels(self, active_chans, aweights_sum):
if self.channels_strategy == "all":
Expand All @@ -1557,16 +1499,20 @@ def marginal_covariance(self, channels, cache_key=None, device=None):
ncov = self.noise.marginal_covariance(
channels, cache_key=cache_key, device=device
)
if self.cov_kind == "zero":
zero_signal = (
self.cov_kind == "zero" or self.cov_kind == "ppca" and not self.ppca_rank
)
if zero_signal:
return ncov
if self.cov_kind == "scaled_template":
root = self.template_std * self.mean[:, channels].reshape(-1, 1)
if self.cov_kind == "ppca" and self.ppca_rank:
root = self.W[:, channels].reshape(-1, self.ppca_rank)
root = operators.LowRankRootLinearOperator(root)
cov = more_operators.LowRankRootSumLinearOperator(
root,
ncov,
)
return cov
# cov = more_operators.LowRankRootSumLinearOperator(
# root,
# ncov,
# )
# i believe this calls .add_low_rank()
return ncov + root

assert False

Expand Down
Loading

0 comments on commit 3a8cd1e

Please sign in to comment.