
Commit

Merges everything in from #117
neukym committed Jan 13, 2024
1 parent 484d07c commit 594a631
Showing 8 changed files with 3,504 additions and 68 deletions.
98 changes: 59 additions & 39 deletions demos/demo_ippm.ipynb

Large diffs are not rendered by default.

76 changes: 73 additions & 3 deletions kymata/ippm/data_tools.py
@@ -6,13 +6,16 @@
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import euclidean_distances
from copy import deepcopy
import requests
import seaborn as sns
from matplotlib.lines import Line2D

import math
from kymata.entities.expression import HexelExpressionSet, DIM_FUNCTION, DIM_LATENCY


class IPPMHexel(object):
"""
Container to hold data about a hexel spike.
@@ -230,4 +233,71 @@ def get_latency(func_hexels: IPPMHexel, mini: bool):
causality_violations += 1
total_arrows += 1

    return causality_violations / total_arrows if total_arrows != 0 else 0

def convert_to_power10(hexels: Dict[str, IPPMHexel]) -> Dict[str, IPPMHexel]:
    # Convert each pairing's magnitude from log10 space back to a raw value (10 ** x).
    for func, hexel in hexels.items():
        hexels[func].right_best_pairings = list(map(lambda x: (x[0], math.pow(10, x[1])), hexels[func].right_best_pairings))
        hexels[func].left_best_pairings = list(map(lambda x: (x[0], math.pow(10, x[1])), hexels[func].left_best_pairings))
    return hexels


def remove_excess_funcs(to_retain: List[str], hexels: Dict[str, IPPMHexel]) -> Dict[str, IPPMHexel]:
    funcs = list(hexels.keys())
    for func in funcs:
        if func not in to_retain:
            # delete
            hexels.pop(func)
    return hexels

def plot_k_dist_1D(pairings: List[Tuple[float, float]], k: int = 4, normalise: bool = False):
    alpha = 3.55e-15
    X = pd.DataFrame(columns=['Latency'])
    for latency, spike in pairings:
        if spike <= alpha:
            X.loc[len(X)] = [latency]

    if normalise:
        X = normalize(X)

    distance_M = euclidean_distances(X)  # rows are points, columns are other points, same order, with values as distances
    k_dists = []
    for r in range(len(distance_M)):
        sorted_dists = sorted(distance_M[r], reverse=True)  # descending order
        k_dists.append(sorted_dists[k])  # store k-dist
    sorted_k_dists = sorted(k_dists, reverse=True)
    plt.plot(list(range(0, len(sorted_k_dists))), sorted_k_dists)
    plt.show()
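As a hedged aside, the curve produced above is the usual way to eyeball an eps radius for the DBSCAN denoiser: look for the elbow in the sorted distances. A minimal call might look like this (the hexel dictionary and the function name are placeholders, not part of this commit):

# Hypothetical: inspect the k-distance curve for one function's right-hemisphere pairings
# before choosing DBSCAN's eps.
plot_k_dist_1D(hexels['example function'].right_best_pairings, k=4, normalise=False)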

def copy_hemisphere(
        hexels_to: Dict[str, IPPMHexel],
        hexels_from: Dict[str, IPPMHexel],
        hemi_to: str,
        hemi_from: str,
        func: str = None):
    if func:
        # copy only one function
        if hemi_to == 'rightHemisphere' and hemi_from == 'rightHemisphere':
            hexels_to[func].right_best_pairings = hexels_from[func].right_best_pairings
        elif hemi_to == 'rightHemisphere' and hemi_from == 'leftHemisphere':
            hexels_to[func].right_best_pairings = hexels_from[func].left_best_pairings
        elif hemi_to == 'leftHemisphere' and hemi_from == 'rightHemisphere':
            hexels_to[func].left_best_pairings = hexels_from[func].right_best_pairings
        else:
            hexels_to[func].left_best_pairings = hexels_from[func].left_best_pairings
        return

    for func, hexel in hexels_from.items():
        if hemi_to == 'rightHemisphere' and hemi_from == 'rightHemisphere':
            hexels_to[func].right_best_pairings = hexels_from[func].right_best_pairings
        elif hemi_to == 'rightHemisphere' and hemi_from == 'leftHemisphere':
            hexels_to[func].right_best_pairings = hexels_from[func].left_best_pairings
        elif hemi_to == 'leftHemisphere' and hemi_from == 'rightHemisphere':
            hexels_to[func].left_best_pairings = hexels_from[func].right_best_pairings
        else:
            hexels_to[func].left_best_pairings = hexels_from[func].left_best_pairings

def plot_denoised_vs_noisy(hexels: Dict[str, IPPMHexel], clusterer, title: str):
    denoised_hexels = clusterer.cluster(hexels, 'rightHemisphere')
    copy_hemisphere(denoised_hexels, hexels, 'leftHemisphere', 'rightHemisphere')
    stem_plot(denoised_hexels, title)
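Taken together, the new helpers form a small pre-processing pipeline ahead of denoising and plotting. A minimal sketch of how they might be chained, assuming a hexels dictionary has already been loaded and using placeholder function names (none of this snippet is part of the commit itself):

from copy import deepcopy

# Hypothetical pre-processing ahead of denoising/plotting.
funcs_of_interest = ['function A', 'function B']
hexels = remove_excess_funcs(funcs_of_interest, hexels)  # keep only the functions of interest
hexels = convert_to_power10(hexels)                      # log10 magnitudes -> raw p-values
side_by_side = deepcopy(hexels)
copy_hemisphere(hexels_to=side_by_side, hexels_from=hexels,
                hemi_to='leftHemisphere', hemi_from='rightHemisphere')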
136 changes: 117 additions & 19 deletions kymata/ippm/denoiser.py
@@ -6,7 +6,7 @@
import pandas as pd
from sklearn.cluster import DBSCAN as DBSCAN_, MeanShift as MeanShift_
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import StandardScaler, normalize

from .data_tools import IPPMHexel

@@ -22,7 +22,10 @@ def __init__(self, **kwargs):
"""
self._clusterer = None

    def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False) -> Dict[str, IPPMHexel]:
    def cluster(
            self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False, cluster_latency: bool = False,
            posterior_pooling: bool = False
    ) -> Dict[str, IPPMHexel]:
        """
        For each function in hemi, it will attempt to construct a dataframe that holds significant spikes (i.e., above alpha).
        Next, it clusters using self._clusterer. Finally, it locates the minimum (most significant) point for each cluster and saves
@@ -52,11 +55,30 @@ def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False
                continue

            # if we are renormalising each feature, then scale otherwise no
            fitted = self._clusterer.fit(StandardScaler().fit_transform(df)) if scaling else self._clusterer.fit(df)
            if cluster_latency:
                # cluster only the latency dimension.
                latency_only = self._get_latency_dim(df)
                fitted = (self._clusterer.fit(latency_only)
                          if not normalise else
                          self._clusterer.fit(normalize(latency_only)))
            else:
                fitted = self._clusterer.fit(normalize(df)) if normalise else self._clusterer.fit(df)
            df['Label'] = fitted.labels_
            cluster_mins = self._get_cluster_mins(df)
            hexels = self._update_pairings(hexels, func, cluster_mins, hemi)
        return hexels if not posterior_pooling else self._posterior_pooling(hexels, hemi)

    def _get_latency_dim(self, df: pd.DataFrame) -> np.ndarray:
        return np.reshape(df['Latency'], (-1, 1))

    def _posterior_pooling(self, hexels: Dict[str, IPPMHexel], hemi: str) -> Dict[str, IPPMHexel]:
        # Keep only the single most significant pairing per function for the given hemisphere.
        for func in hexels.keys():
            if len(hexels[func].left_best_pairings) != 0 and hemi == 'leftHemisphere':
                hexels[func].left_best_pairings = [min(hexels[func].left_best_pairings, key=lambda x: x[1])]
            elif len(hexels[func].right_best_pairings) != 0 and hemi == 'rightHemisphere':
                hexels[func].right_best_pairings = [min(hexels[func].right_best_pairings, key=lambda x: x[1])]
        return hexels


def _hexels_to_df(self, hexels: Dict[str, IPPMHexel], hemi: str) -> pd.DataFrame:
"""
@@ -224,7 +246,10 @@ def __init__(self, **kwargs):
print('Bin size needs to be an integer.')
raise ValueError

def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False) -> List[Tuple[float, float]]:
def cluster(
self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False,
cluster_latency: bool = False, posterior_pooling: bool = False
) -> Dict[str, IPPMHexel]:
"""
Custom clustering method since it differs from other unsupervised techniques.
@@ -278,7 +303,7 @@ def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False

hexels = super()._update_pairings(hexels, func, ret, hemi)

return hexels
return hexels if not posterior_pooling else super()._posterior_pooling(hexels, hemi)

def _cluster_bin(self, df: pd.DataFrame, r_idx: int, latency: int) -> Tuple[float, int, int, int]:
"""
@@ -319,10 +344,65 @@ def _cluster_bin(self, df: pd.DataFrame, r_idx: int, latency: int) -> Tuple[floa

return bin_min, lat_min, num_seen, r_idx

class AdaptiveMaxPooler(DenoisingStrategy):
    """
    Denoises by sweeping fixed-width base bins across the latency axis, merging adjacent bins
    while each contains at least `threshold` spikes, and keeping only the most significant
    (minimum-magnitude) spike from each merged run.
    """

    def __init__(self, base_bin_sz: int = 10, threshold: int = 5):
        self._threshold = threshold
        self._base_bin_sz = base_bin_sz

    def cluster(
            self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False,
            cluster_latency: bool = False, posterior_pooling: bool = False
    ) -> Dict[str, IPPMHexel]:
        hexels = deepcopy(hexels)
        get_default_vals = lambda _: (np.inf, None)
        for func, df in super()._hexels_to_df(hexels, hemi):
            if len(df) == 0:
                hexels = super()._update_pairings(hexels, func, [], hemi)
                continue

            df = df.sort_values(by='Latency')

            df_ptr = 0   # index into df
            end_ptr = 1  # guaranteed to have > 1 data point. delineates end of bin
            start_ptr = 0
            total_bins = 1000 / self._base_bin_sz
            prev_bin_min, prev_bin_lat_min = get_default_vals('a')
            prev_signi = False
            ret = []
            while df_ptr < len(df) and start_ptr < total_bins:
                end_ms = end_ptr * self._base_bin_sz
                num_in_bin = 0
                cur_bin_min, cur_bin_lat_min = get_default_vals('a')
                while df_ptr < len(df) and df.iloc[df_ptr, 0] < end_ms:
                    if df.iloc[df_ptr, 1] < cur_bin_min:
                        cur_bin_min, cur_bin_lat_min = df.iloc[df_ptr, 1], df.iloc[df_ptr, 0]
                    num_in_bin += 1
                    df_ptr += 1
                if num_in_bin >= self._threshold:
                    # bin is dense enough: extend the current run and track its running minimum
                    end_ptr += 1
                    prev_signi = True
                    if cur_bin_min < prev_bin_min:
                        prev_bin_min, prev_bin_lat_min = cur_bin_min, cur_bin_lat_min
                else:
                    if prev_signi:
                        # start_ptr to end_ptr is significant
                        ret.append((prev_bin_lat_min, prev_bin_min))
                        prev_bin_min, prev_bin_lat_min = get_default_vals('a')
                        prev_signi = False
                    start_ptr = end_ptr
                    end_ptr += 1
            if prev_signi:
                # last bin was significant and we expanded
                ret.append((prev_bin_lat_min, prev_bin_min))
                prev_bin_min, prev_bin_lat_min = get_default_vals('a')
                prev_signi = False
            hexels = super()._update_pairings(hexels, func, ret, hemi)
        return hexels if not posterior_pooling else super()._posterior_pooling(hexels, hemi)


class GMM(DenoisingStrategy):
    """
    This strategy uses the GaussianMixtureModel algorithm. Intuitively, it attempts to fit a multimodal Gaussian distribution to the data using the EM algorithm.
    This strategy uses the GaussianMixtureModel algorithm. It attempts to fit a multimodal Gaussian distribution to the data using the EM algorithm.
    The primary disadvantage of this model is that the number of Gaussians has to be prespecified. This implementation does a grid search from 1 to max_gaussians
    to find the optimal number of Gaussians. Moreover, it does not work well with anomalies.
    """
@@ -345,12 +425,13 @@ def __init__(self, **kwargs):
        set this if you want your results to be reproducible.
        """
        # we are instantiating multiple models, so save hyperparameters instead of clusterer object.
        self._max_gaussians = 6 if not 'max_gaussians' in kwargs.keys() else kwargs['max_gaussians']
        self._max_gaussians = 5 if not 'max_gaussians' in kwargs.keys() else kwargs['max_gaussians']
        self._covariance_type = 'full' if not 'covariance_type' in kwargs.keys() else kwargs['covariance_type']
        self._max_iter = 100 if not 'max_iter' in kwargs.keys() else kwargs['max_iter']
        self._n_init = 3 if not 'n_init' in kwargs.keys() else kwargs['n_init']
        self._init_params = 'k-means++' if not 'init_params' in kwargs.keys() else kwargs['init_params']
        self._max_iter = 1000 if not 'max_iter' in kwargs.keys() else kwargs['max_iter']
        self._n_init = 8 if not 'n_init' in kwargs.keys() else kwargs['n_init']
        self._init_params = 'kmeans' if not 'init_params' in kwargs.keys() else kwargs['init_params']
        self._random_state = None if not 'random_state' in kwargs.keys() else kwargs['random_state']
        self._is_aic = False if not 'is_aic' in kwargs.keys() else kwargs['is_aic']  # default is BIC, which suits explanatory models better since it assumes reality lies within the hypothesis space.

        invalid = False
        if type(self._max_gaussians) != int:
Expand All @@ -376,7 +457,10 @@ def __init__(self, **kwargs):
raise ValueError


def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False) -> Dict[str, IPPMHexel]:
def cluster(
self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False,
cluster_latency: bool = False, posterior_pooling: bool = False
) -> Dict[str, IPPMHexel]:
"""
Overriding the superclass cluster function because we want to perform a grid-search over the number of clusters to locate the optimal one.
It works similarly to the superclass.cluster method but it performs it multiple times. It stops if the number of data points < number of clusters as
@@ -396,6 +480,10 @@ def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False
super()._check_hemi(hemi)
hexels = deepcopy(hexels)
for func, df in super()._hexels_to_df(hexels, hemi):
if len(df) == 0:
hexels = super()._update_pairings(hexels, func, [], hemi)
continue

if len(df) == 1:
# no point clustering, just return the single data point.
ret = []
@@ -405,7 +493,7 @@ def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False
continue

best_labels = None
best_score = np.inf # use aic, bic or silhouette score for model selection. if silhouette, switch this to -inf since we wanna maximise it.
best_score = np.inf
for n in range(1, self._max_gaussians):
if n > len(df):
# the number of gaussians has to be less than the number of datapoints.
@@ -416,19 +504,29 @@ def cluster(self, hexels: Dict[str, IPPMHexel], hemi: str, scaling: bool = False
                                      n_init=self._n_init,
                                      init_params=self._init_params,
                                      random_state=self._random_state)
                scaler = StandardScaler()
                gmm.fit(scaler.fit_transform(df)) if scaling else gmm.fit(df)
                # select the clustering input: optionally restrict to the latency dimension and/or normalise it
                temp = None
                if normalise and cluster_latency:
                    temp = normalize(np.reshape(df['Latency'], (-1, 1)))
                elif not normalise and cluster_latency:
                    temp = np.reshape(df['Latency'], (-1, 1))
                elif normalise and not cluster_latency:
                    temp = normalize(df)
                else:
                    temp = df

                score = gmm.bic(df) # gmm.aic(df) for AIC score.
                gmm.fit(temp)
                score = gmm.aic(temp) if self._is_aic else gmm.bic(temp)
                labels = gmm.predict(temp)

                if score < best_score:
                    # this condition depends on the choice of AIC/BIC/silhouette. if using silhouette, reverse the inequality.
                    best_labels = gmm.predict(scaler.transform(df)) if scaling else gmm.predict(df)
                    # this condition depends on the choice of AIC/BIC
                    best_labels = labels
                    best_score = score

            df['Label'] = best_labels
            cluster_mins = super()._get_cluster_mins(df)
            hexels = super()._update_pairings(hexels, func, cluster_mins, hemi)
        return hexels
        return hexels if not posterior_pooling else super()._posterior_pooling(hexels, hemi)
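For orientation, a minimal, hedged sketch of how the reworked GMM denoiser might be invoked with the new keyword arguments (the hexels dictionary is assumed to already exist; the hyperparameter values are purely illustrative):

# Hypothetical invocation of the GMM denoiser defined above.
gmm_denoiser = GMM(max_gaussians=5, random_state=0, is_aic=False)  # BIC-based model selection
denoised = gmm_denoiser.cluster(
    hexels, 'rightHemisphere',
    normalise=False,         # cluster raw values rather than normalised ones
    cluster_latency=True,    # cluster on the latency dimension only
    posterior_pooling=True,  # keep just the single most significant pairing per function
)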


class DBSCAN(DenoisingStrategy):
@@ -522,7 +620,7 @@ class MeanShift(DenoisingStrategy):
"""
def __init__(self, **kwargs):
cluster_all = False if not 'cluster_all' in kwargs.keys() else kwargs['cluster_all']
bandwidth = None if not 'bandwidth' in kwargs.keys() else kwargs['bandwidth']
bandwidth = 30 if not 'bandwidth' in kwargs.keys() else kwargs['bandwidth']
seeds = None if not 'seeds' in kwargs.keys() else kwargs['seeds']
min_bin_freq = 2 if not 'min_bin_freq' in kwargs.keys() else kwargs['min_bin_freq']
n_jobs = -1 if not 'n_jobs' in kwargs.keys() else kwargs['n_jobs']
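Finally, a hedged end-to-end sketch tying the data_tools helpers to one of the denoisers; the hexels dictionary is assumed to be loaded already, and the MeanShift keyword values are illustrative rather than recommended:

# Hypothetical side-by-side comparison: right hemisphere denoised with MeanShift,
# left hemisphere filled with the noisy originals by plot_denoised_vs_noisy.
ms_denoiser = MeanShift(bandwidth=30, min_bin_freq=2)
plot_denoised_vs_noisy(hexels, ms_denoiser, 'MeanShift denoising vs raw spikes')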
