From 18f85e314b2e9c60ccbed0658bb77b0685db3037 Mon Sep 17 00:00:00 2001 From: Anirudh Lakra <55798132+anirudh1666@users.noreply.github.com> Date: Sat, 2 Mar 2024 16:41:57 +0000 Subject: [PATCH] Ippm optimisations (#211) * Fixing lint errors * Fixing lint errors * Fixing lint errors * Optimisations + Comments for IPPM * Fixing lint errors * Removing changes in plain.py I think your branch might be behind main, as `main` should pass the linting checks without this change in gridsearch/plain.py. --------- Co-authored-by: Anirudh Lakra Co-authored-by: neukym --- kymata/ippm/builder.py | 152 +++++++++---- kymata/ippm/data_tools.py | 116 +++++++++- kymata/ippm/denoiser.py | 412 ++++++++++++++++++++++++++++-------- kymata/ippm/plotter.py | 20 +- tests/test_ippm_builder.py | 4 +- tests/test_ippm_denoiser.py | 46 ++-- tests/test_plotting.py | 1 + 7 files changed, 579 insertions(+), 172 deletions(-) diff --git a/kymata/ippm/builder.py b/kymata/ippm/builder.py index 9dd58ede..6f0d21e9 100644 --- a/kymata/ippm/builder.py +++ b/kymata/ippm/builder.py @@ -6,7 +6,6 @@ from .data_tools import IPPMHexel - # convenient tuple/class to hold information about nodes. Node = namedtuple('Node', 'magnitude position in_edges') @@ -14,7 +13,7 @@ class IPPMBuilder(object): """ A graphing class used to construct a dictionary that contains the nodes and all relevant - information to construct a nx.DiGraph. + information to construct a dict containing node names as keys and Node objects (see namedtuple) as values. """ def build_graph(self, @@ -25,41 +24,112 @@ def build_graph(self, """ Builds a dictionary of nodes and information about the node. The information is built out of namedtuple class Node, which contains magnitude, position, color, and - the incoming edges. + the incoming edges + + Analysis + -------- + + sorting takes nlogn where n = # of hexels = 10000. assumption: quicksort == sort + we do it for every f, so f * nlogn + next we loop through every f and touch every pairing and parent. In the worst case, + the number of pairings == # of hexels, and # of parent == # of fs - 1. Hence, this + part has O(f * (n + f-1)) = O(f * n + f^2). + + Total complexity: O(f * nlogn + f * n + f^2) + What dominates: f * nlogn ~ f * n. if logn > 1, then f * nlogn > f * n. + We have n = 10,000. So logn > 1, hence f * nlogn > f * n. + As long as n >= 10, logn >= 1. Since this is big-O, we take worst + case but in reality, the # of pairings is typically from 0-10. + So, f * n >= f * nlogn but big-O is worst-case, so we stick with nlogn. + + f * nlogn ~ f^2. if nlogn > f, f * nlogn > f^2. + assuming n = 10,000, nlogn = 40000 > f = 12. + f = 12 + Hence, f * nlogn is actually the dominant term. + Can we reduce it? It would involve updating the algorithm to avoid + sorting prior to building the graph. The primary we reason to sort + is so that we can exploit the naming of functions being ordered and + immediately add an edge from the last pairing of a parent to the first + pairing of a child. Without it being sorted, we would have to loop + through the parent pairings to locate the last node. Hence, + the complexity becomes: O(f * (n - (f - 1) * n)) = O(f * n - f^2 * n). + Now f^2 * n > nlogn, so it would actually make the algorithm slower. + Moreover, we took an unrealistic worst case assumption of n = 10,000. + In reality, it would be between 0-10, so nlogn would be approximately less + than or equal to f. So, in practice, it would not dominate. 
Especially + as the dataset quality improves, the number of spikes would go down and + the number of functions would increase. + + Therefore, we cannot reduce the complexity further. + Final complexity: O(f * nlogn). + + + Space Complexity: Let n be the maximum pairs of spikes out of all functions. + We copy hexels, so O(n * f) + We copy func hier, so O(f * (f-1)) = O(f^2) + Our dict will contain a key for every pairing. Hence, it will be of size O(f * n). + Total space: O(n * f + f^2 + f * n) = O(f * n). + + We could feasibly trade-off time for space complexity but I think it is good as it is. + + Algorithm + --------- + + It iteratively selects the top-level function, which is defined as a function that does not + have any children, i.e., it does not have any arrows going out of it. Hence, it starts with the final function and proceeds + in a top-down fashion towards the input node. Upon selecting a top-level function, it creates a spike for every pairing in + the best_pairings. Since the spikes are already ordered, we get a nice labelling where func_name-0 corresponds to the earliest + spike and func_name-{len(pairings)}-1 is the final spike. Next, we go through the parents (incoming edges) and add an edge + from the final function (i.e., inc_edge_func_name-{len(pairings)}-1) to the first current function (func_name-0). We repeat this + for all functions. Last thing to note is that the input node has to be defined because it is treated differently to the rest. The + input function has only 1 spike and a default size of 10 at latency == 0. + + We do it top-down to make the ordering of nodes clean. Otherwise, we can get a messy jumble of nodes with the input in the middle and + the final output randomly assigned. Params ------ - hexels : dictionary containing function names and Hexel objects with data in it. - - function_hier : dictionary of the format (function_name : [children_functions]) - - inputs : list of input functions. function_hier contains the input functions, so we need this to distinguish between inputs and functoins. + - function_hier : dictionary of the format (function_name : [parent_functions]) + - inputs : list of input functions. function_hier contains the input functions, so we need this to distinguish between inputs and functions. - hemi : leftHemisphere or rightHemisphere Returns ------- A dictionary of nodes with unique names where the keys are node objects with all - relevant information for plotting a nx.DiGraph. + relevant information for plotting a graph. """ - functions = list(function_hier.keys()) - # filter out functions that are unneccessary - filtered = {func : hexels[func] for func in functions if func in hexels.keys()} + + hexels = deepcopy(hexels) # do it in-place to avoid modifying hexels. + # sort it so that the null edges go in the right order (from left to right along time axis) - sorted = self._sort_by_latency(filtered, hemi) + sorted = self._sort_by_latency(hexels, hemi, list(function_hier.keys())) - hier = deepcopy(function_hier) # we will modify function_hier so copy - n_partitions = 1 / len(function_hier.keys()) # partition x -axis - part_idx = 0 - graph = {} # format: node_name : [magnitude, color, position, in_edges] + hier = deepcopy(function_hier) # we will modify function_hier so copy + n_partitions = 1 / len(function_hier.keys()) # partition y-axis + part_idx = 0 # pointer into the current partition out of [0, n_partitions). We go top-down cus of this. + # it ensures we order our nodes according to their level, with input at the bottom. 
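A rough usage sketch of the builder described above (illustrative only, not part of this patch; the toy pairings, hierarchy and argument order below are assumptions based on this changeset's docstrings and tests):

    from kymata.ippm.builder import IPPMBuilder
    from kymata.ippm.data_tools import IPPMHexel

    hexels = {'f1': IPPMHexel('f1'), 'f2': IPPMHexel('f2')}
    hexels['f1'].right_best_pairings = [(50, 1e-40), (120, 1e-33)]  # (latency in ms, p-value)
    hexels['f2'].right_best_pairings = [(180, 1e-25)]

    # Each function maps to its parents; 'input' appears both here and in `inputs`.
    function_hier = {'input': [], 'f1': ['input'], 'f2': ['f1']}

    graph = IPPMBuilder().build_graph(hexels, function_hier, ['input'], 'rightHemisphere')

    # Expected node names: 'input', 'f1-0', 'f1-1', 'f2-0', with
    # graph['f1-0'].in_edges == ['input'] and
    # graph['f2-0'].in_edges == ['f1-1']  (last parent spike -> first child spike).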
+ graph = {} # format: node_name : [magnitude, color, position, in_edges] while len(hier.keys()) > 0: - top_level = self._get_top_level_functions(hier) + top_level = self._get_top_level_functions(hier) # get function that is a child, i.e., it doesn't have any arrows going out of it for f in top_level: + # We do the following: + # if f == input_node: + # default_settings + # else: + # create_spike_for_every_pairing + # for parent in parents: + # add spike from final parent node to the first f node. + # While doing this, if we encounter empty pairings for f or any parent, we skip them. + if f in inputs: # input node default size is 10. hier.pop(f) graph[f] = Node(100, (0, 1 - n_partitions * part_idx), []) else: - children = hier[f] - hier.pop(f) + parents = hier[f] + hier.pop(f) # make sure to pop so we don't get the same top level fs every loop. best_pairings = ( sorted[f].left_best_pairings if hemi == 'leftHemisphere' else @@ -77,44 +147,46 @@ def build_graph(self, (latency, 1 - n_partitions * part_idx), [f + '-' + str(idx - 1)] if idx != 0 else []) - part_idx += 1 + part_idx += 1 # increment partition index - for child in children: - # add edges coming from children to f. - if child in inputs: - graph[f + '-0'].in_edges.append(child) + for parent in parents: + # add edges coming from parents to f. + if parent in inputs: + # if the parent is a input node, there is only one node. + graph[f + '-0'].in_edges.append(parent) else: - children_pairings = ( - sorted[child].left_best_pairings if hemi == 'leftHemisphere' else - sorted[child].right_best_pairings + parent_pairings = ( + sorted[parent].left_best_pairings if hemi == 'leftHemisphere' else + sorted[parent].right_best_pairings ) - if len(children_pairings) == 0: + if len(parent_pairings) == 0: # ignore this function continue - # add an edge from the final spike of a function. - graph[f + '-0'].in_edges.append(child + '-' + str(len(children_pairings) - 1)) + # add an edge from the final spike of parent to first spike of current function. + graph[f + '-0'].in_edges.append(parent + '-' + str(len(parent_pairings) - 1)) return graph - def _get_top_level_functions(self, edges: Dict[str, List[str]]) -> set: + def _get_top_level_functions(self, hier: Dict[str, List[str]]) -> set: """ Returns the top-level function. A top-level function is at the highest level of the function hierarchy. It can be found as the function that does not appear - in any of the lists of children. + in any of the lists of parents. I.e., it is a function that that does not feed into + any other functions; it only has functions feeding into it. Params ------ - edges : dictionary that contains the function hierarchy (including inputs) + hier : dictionary that contains the function hierarchy (including inputs) Returns ------- a set containing the top-level functions. """ - funcs_leftover = list(edges.keys()) - children_funcs = [f for children in edges.values() for f in children] - return set(funcs_leftover).difference(set(children_funcs)) + funcs_leftover = list(hier.keys()) + parent_funcs = [f for parents in hier.values() for f in parents] + return set(funcs_leftover).difference(set(parent_funcs)) - def _sort_by_latency(self, hexels: Dict[str, IPPMHexel], hemi: str): + def _sort_by_latency(self, hexels: Dict[str, IPPMHexel], hemi: str, functions: List[str]) -> Dict[str, IPPMHexel]: """ Sort pairings by latency in increasing order inplace. @@ -126,11 +198,15 @@ def _sort_by_latency(self, hexels: Dict[str, IPPMHexel], hemi: str): ------- sorted hexels. 
""" - for key in hexels.keys(): + for function in functions: + if function not in hexels.keys(): + # function was not detected in the hexels + continue + if hemi == 'leftHemisphere': - hexels[key].left_best_pairings.sort(key=lambda x: x[0]) + hexels[function].left_best_pairings.sort(key=lambda x: x[0]) else: - hexels[key].right_best_pairings.sort(key=lambda x: x[0]) + hexels[function].right_best_pairings.sort(key=lambda x: x[0]) return hexels diff --git a/kymata/ippm/data_tools.py b/kymata/ippm/data_tools.py index 76f30467..cf474b46 100644 --- a/kymata/ippm/data_tools.py +++ b/kymata/ippm/data_tools.py @@ -121,7 +121,7 @@ def build_hexel_dict_from_api_response(dict_: Dict) -> Dict[str, IPPMHexel]: ------- Dict of the format [function name, Hexel(func_name, id, left_pairings, right_pairings)] """ - hexels = {} + hexels = {} for hemi in ['leftHemisphere', 'rightHemisphere']: for (_, latency, pval, func) in dict_[hemi]: # we have id, latency (ms), pvalue (log_10), function name. @@ -215,7 +215,27 @@ def stem_plot( plt.show() -def causality_violation_score(hexels: Dict[str, IPPMHexel], hierarchy: Dict[str, List[str]], hemi: str): +def causality_violation_score(denoised_hexels: Dict[str, IPPMHexel], hierarchy: Dict[str, List[str]], hemi: str): + """ + Assumption: hexels are denoised. Otherwise, it doesn't really make sense to check the min/max latency of noisy hexels. + + A score calculated on denoised hexels that calculates the proportion of arrows in IPPM that are going backward in time. + It assumes that the function hierarchy is correct, which may not always be correct, so you must use it with caution. + + Algorithm + ---------- + violations = 0 + total_arrows = 0 + for each func_name, parents_list in hierarchy: + child_lat = min(hexels[func]) + for parent in parents_list: + parent_lat = max(hexels[parent]) + if child_lat < parent_lat: + violations++ + total_arrows++ + return violations / total_arrows if total_arrows > 0 else 0 + """ + assert(hemi == 'rightHemisphere' or hemi == 'leftHemisphere') def get_latency(func_hexels: IPPMHexel, mini: bool): @@ -227,17 +247,18 @@ def get_latency(func_hexels: IPPMHexel, mini: bool): causality_violations = 0 total_arrows = 0 for func, inc_edges in hierarchy.items(): + # essentially: if max(parent_spikes_latency) > min(child_spikes_latency), there will be a backwards arrow in time. # arrows go from latest inc_edge spike to the earliest func spike - child_latency = get_latency(hexels[func], mini=True) + child_latency = get_latency(denoised_hexels[func], mini=True) for inc_edge in inc_edges: - parent_latency = get_latency(hexels[inc_edge], mini=False) + parent_latency = get_latency(denoised_hexels[inc_edge], mini=False) if child_latency < parent_latency: causality_violations += 1 total_arrows += 1 return causality_violations / total_arrows if total_arrows != 0 else 0 -def function_recall(hexels: Dict[str, IPPMHexel], funcs: List[str], ippm_dict: Dict[str, Node], hemi: str) -> Tuple[float]: +def function_recall(noisy_hexels: Dict[str, IPPMHexel], funcs: List[str], ippm_dict: Dict[str, Node], hemi: str) -> Tuple[float]: """ This is the second scoring metric: function recall. It illustrates what proportion out of functions in the noisy hexels are detected as part of IPPM. E.g., 9 functions but only 8 found => 8/9 = function recall. Use this along with causality violation to evaluate IPPMs and analyse their strengths and weaknesses. 
@@ -256,20 +277,22 @@ def function_recall(hexels: Dict[str, IPPMHexel], funcs: List[str], ippm_dict: D """ assert(hemi == 'rightHemisphere' or hemi == 'leftHemisphere') - # Step 1: Identify how many significant functions out of funcs there are. + # Step 1: Calculate significance level alpha = 1 - NormalDist(mu=0, sigma=1).cdf(5) bonferroni_corrected_alpha = 1-(pow((1-alpha),(1/(2*201*200000)))) funcs_present_in_data = 0 detected_funcs = 0 for func in funcs: - pairings = hexels[func].right_best_pairings if hemi == 'rightHemisphere' else hexels[func].left_best_pairings + pairings = noisy_hexels[func].right_best_pairings if hemi == 'rightHemisphere' else noisy_hexels[func].left_best_pairings for latency, spike in pairings: + # Step 2: Find a pairing that is significant if spike <= bonferroni_corrected_alpha: funcs_present_in_data += 1 - # Step 2: Found a function, look in ippm_dict.keys() for the function. + # Step 3: Found a function, look in ippm_dict.keys() for the function. for node_name in ippm_dict.keys(): if func in node_name: + # Step 4: If found, then increment detected_funcs. Also increment funcs_pressent detected_funcs += 1 break break @@ -281,6 +304,17 @@ def function_recall(hexels: Dict[str, IPPMHexel], funcs: List[str], ippm_dict: D def convert_to_power10(hexels: Dict[str, IPPMHexel]) -> Dict[str, IPPMHexel]: + """ + Utility function to take data from the .nkg format and convert it to power of 10, so it can be used for IPPMs. + + Parameters + ------------ + hexels: dict function_name as key and hexel object as value. Hexels contain pairings for left/right. + + Returns + -------- + same dict but the pairings are all raised to power x. E.g., pairings = [(lat1, x), ..., (latn, xn)] -> [(lat1, 10^x), ..., (latn, 10^xn)] + """ for func, hexel in hexels.items(): hexels[func].right_best_pairings = list(map(lambda x: (x[0], math.pow(10, x[1])), hexels[func].right_best_pairings)) hexels[func].left_best_pairings = list(map(lambda x: (x[0], math.pow(10, x[1])), hexels[func].left_best_pairings)) @@ -288,7 +322,21 @@ def convert_to_power10(hexels: Dict[str, IPPMHexel]) -> Dict[str, IPPMHexel]: def remove_excess_funcs(to_retain: List[str], hexels: Dict[str, IPPMHexel]) -> Dict[str, IPPMHexel]: - funcs = list(hexels.keys()) + """ + Utility function to distill the hexels down to a subset of functions. Use this to visualise a subset of functions for time-series. + E.g., you want the time-series for one function, so just pass it wrapped in a list as to_retain + + Parameters + ---------- + to_retain: list of functions we want to retain in the hexels dict + hexels: hexels: dict function_name as key and hexel object as value. Hexels contain pairings for left/right. + + Returns + ------- + hexels but all functions that aren't in to_retain are filtered. + """ + + funcs = list(hexels.keys()) # need this because we can't remove from the dict while also iterating over it. for func in funcs: if func not in to_retain: # delete @@ -296,6 +344,26 @@ def remove_excess_funcs(to_retain: List[str], hexels: Dict[str, IPPMHexel]) -> D return hexels def plot_k_dist_1D(pairings: List[Tuple[float, float]], k: int=4, normalise: bool=False): + """ + This could be optimised further but since we aren't using it, we can leave it as it is. + + A utility function to plot the k-dist graph for a set of pairings. Essentially, the k dist graph plots the distance + to the kth neighbour for each point. 
By inspecting the gradient of the graph, we can gain some intuition behind the density of + points within the dataset, which can feed into selecting the optimal DBSCAN hyperparameters. + + For more details refer to section 4.2 in https://www.dbs.ifi.lmu.de/Publikationen/Papers/KDD-96.final.frame.pdf + + Parameters + ---------- + pairings: list of pairings extracted from a hexel. It contains the pairings for one function and one hemisphere + k: the k we use to find the kth neighbour. Paper above advises to use k=4. + normalise: whether to normalise before plotting the k-dist. It is important because the k-dist then equally weights both dimensions. + + Returns + ------- + Nothing but plots a graph. + """ + alpha = 3.55e-15 X = pd.DataFrame(columns=['Latency']) for latency, spike in pairings: @@ -320,6 +388,22 @@ def copy_hemisphere( hemi_to: str, hemi_from: str, func: str = None): + """ + Utility function to copy a hemisphere onto another one. The primary use-case is to plot the denoised hemisphere against the + noisy hemisphere using the same hexel object. I.e., copy right hemisphere to left; denoise on right; plot right vs left. + + Parameters + ---------- + hexels_to: Hexels we are writing into. Could be (de)noisy hexels. + hexels_from: Hexels we are copying from + hemi_to: the hemisphere we index into when we write into hexels_to. E.g., hexels_to[hemi_to] = hexels_from[hemi_from] + hemi_from: the hemisphere we index into when we copy the hexels from hexels_from. + func: if func != None, we only copy one function. Otherwise, we copy all. + + Returns + ------- + Nothing, everything is done in-place. I.e., hexels_to is now updated. + """ if func: # copy only one function if hemi_to == 'rightHemisphere' and hemi_from == 'rightHemisphere': @@ -343,6 +427,20 @@ def copy_hemisphere( hexels_to[func].left_best_pairings = hexels_from[func].left_best_pairings def plot_denoised_vs_noisy(hexels: Dict[str, IPPMHexel], clusterer, title: str): + """ + Utility function to plot the noisy and denoised versions. It runs the supplied clusterer and then copies the denoised hexels, which + are fed into a stem plot. + + Parameters + ---------- + hexels: hexels we want to denoise then plot + clusterer: A child class of DenoisingStrategy that implements .cluster + title: title of plot + + Returns + ------- + Nothing but plots a graph. + """ denoised_hexels = clusterer.cluster(hexels, 'rightHemisphere') copy_hemisphere(denoised_hexels, hexels, 'leftHemisphere', 'rightHemisphere') stem_plot(denoised_hexels, title) \ No newline at end of file diff --git a/kymata/ippm/denoiser.py b/kymata/ippm/denoiser.py index ad3e703b..3404d3a1 100644 --- a/kymata/ippm/denoiser.py +++ b/kymata/ippm/denoiser.py @@ -8,6 +8,8 @@ from sklearn.mixture import GaussianMixture from sklearn.preprocessing import normalize +#import multiprocessing + from .data_tools import IPPMHexel @@ -28,18 +30,41 @@ def cluster( ) -> Dict[str, IPPMHexel]: """ For each function in hemi, it will attempt to construct a dataframe that holds significant spikes (i.e., abova alpha). - Next, it clusters using self._clusterer. Finally, it locates the minimum (most significant) point for each cluster and saves - it. + Next, we do any optional preprocessing and cluster using self._clusterer. Finally, it locates the minimum + (most significant) point for each cluster and saves it. Optionally, the user can choose to constrain the number of spikes to 1. + Preprocessing includes scaling the data to have unit length or cluster only on latency (Density based clustering). 
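A minimal sketch of the intended call pattern for a denoising strategy such as the ones below (illustrative only, not part of this patch; the hyperparameter values are placeholders and `noisy_hexels` is assumed to be a Dict[str, IPPMHexel]):

    from kymata.ippm import denoiser

    clusterer = denoiser.DBSCAN(eps=10, min_samples=2)
    denoised = clusterer.cluster(
        noisy_hexels,
        'rightHemisphere',
        normalise=False,          # scale each (latency, magnitude) row to unit length before clustering
        cluster_latency=True,     # cluster on the latency dimension only
        posterior_pooling=False,  # True would keep a single spike per function
    )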
+ + TODO: could we instead average over points in a cluster? Minimum makes most sense because it is the best match and corresponds to an + actual match. This can be overridden if using a custom clustering strategy but as it is, it works well sklearn clustering techniques. As a result, additional algorithms from sklearn can be easily incorporated. + Essentially the algorithm operates by tagging each cluster as being generated by one original spike. Our algorithm attempts + to delineate the clusters and assumes the most significant spike is the original generating spike. + Params ------ hexels : Dict[str, Hexel] Contains the left hemisphere and right hemisphere pairings. We want to denoise one of them. hemi : str from ['rightHemisphere', 'leftHemisphere'] indicates the hemisphere to denoise. + normalise: bool + maps vectors to unit vectors. I.e., divides by the total length rather than standardise (map to N(0, 1)). + normalisation can aid algorithms using distance metrics, e.g., euclidean distance. Intuitively, having unnormalised + dimensions gives a heftier weight to the larger one, in this case, latency. By normalising, each dimension contributes the + same amount to the overall algorithm. Tests showed that normalising does not improve performance. + cluster_latency: bool + indicates whether to cluster only on the latency dimension. It is akin to having unnormalised features because + the latency dimension has a far higher weight than the magnitude dimension, latency is from -200 to 800 while magnitude + is 10^-x which is incredibly small. Clustering solely on latency can actually boost performance in some cases. However, + it is the same as a density-based algorithm because we map the problem to 1D and cluster on the density of points in + time. Hence, we keep it by default to False, so not all algorithms default to a density-based one. + posterior_pooling: bool + indicates whether we want to pool at the end of denoising to constrain the number of nodes per function to be 1. + it returns graphs of increased transparency at the price of enforcing a constraint, which may be false. The + assumption is that each function only appears once. To keep the results as close to the Truth as possible, it is set + by default to 1 but can be set to True to generate nice graphs. Returns ------- @@ -48,39 +73,126 @@ def cluster( """ self._check_hemi(hemi) # guardrail to check for hemisphere input. hexels = deepcopy(hexels) # dont do it in-place. - for func, df in self._hexels_to_df(hexels, hemi): + + alpha = self._estimate_alpha() + for func, df in self._hexels_to_df(hexels, hemi, alpha): + # Step 1: For each time series, we cluster on a dataframe with only significant spikes. if len(df) == 0: # there are no significant spikes. hexels = self._update_pairings(hexels, func, [], hemi) continue - # if we are renormalising each feature, then scale otherwise no + # Step 2: Perform any preprocessing and run the clusterer. + # if we are renormalising each feature, then normalise otherwise no if cluster_latency: - # cluster only the latency dimension. + # cluster only the latency dimension. normalising a 1d dataset should make no difference but kept it incase latency_only = self._get_latency_dim(df) fitted = (self._clusterer.fit(latency_only) if not normalise else self._clusterer.fit(normalize(latency_only))) else: fitted = self._clusterer.fit(normalize(df)) if normalise else self._clusterer.fit(df) + + # Step 3: Identify clusters and extract the most significant spike. 
df['Label'] = fitted.labels_ cluster_mins = self._get_cluster_mins(df) + + # Step 4: Overwrite the noisy pairings and optionally max pool hexels = self._update_pairings(hexels, func, cluster_mins, hemi) + if posterior_pooling: + hexels = self._posterior_pooling(hexels, hemi, func) + return hexels + + """ + This commented region contains the code for the multiprocessing version of the denoiser. + Currently, it is commented out because it does not offer an improvement in performance due to the + overhead cost of managing the processes outweighing the boost in speed. As the dataset size increases, + the value of multiprocessing code will increase, so eventually, this code might be useful. Therefore, it is + kept as a reference point. + + Also, Python is quite poor at concurrency. Python is compiled to bytecode, which is run by an interpreter. However, there is only + one shared interpreter across processes; hence, only one process runs at a time. So, Python gives off the illusion of concurrency when + it is actually just context switching rapidly between jobs. Incorporate the overhead of managing the threads/processes and concurrent + Python code can actually get slower. + + + def cluster_multiproc(self, hexels, hemi, normalise = False, cluster_latency = False, posterior_pooling = False): + self._check_hemi(hemi) + hexels = deepcopy(hexels) + alpha = self._estimate_alpha() + with multiprocessing.Pool() as pool: + args = [] + for func, df in self._hexels_to_df(hexels, hemi, alpha): + # aggregate independent dfs for each time series we want to cluster on. + if len(df) == 0: + # don't save this as no point. + hexels = self._update_pairings(hexels, func, [], hemi) + else: + args.append((df, func, normalise, cluster_latency, posterior_pooling)) + + # cluster each one in parallel + # we need to create the dataframes independently because multiprocessing copies by reference (so if you pass hexels, it copies it f times.) + result = pool.starmap(self._cluster_worker, args) + + # update hexels with new minimums. + for func, cluster_mins in result: + hexels = self._update_pairings(hexels, func, cluster_mins, hemi) return hexels if not posterior_pooling else self._posterior_pooling(hexels, hemi) + def _cluster_worker(self, df, func, normalise, cluster_latency, posterior_pooling): + if cluster_latency: + latency_only = self._get_latency_dim(df) + fitted = (self._clusterer.fit(latency_only) + if not normalise else + self._clusterer.fit(normalize(latency_only))) + else: + fitted = self._clusterer.fit(normalize(df)) if normalise else self._clusterer.fit(df) + df['Label'] = fitted.labels_ + cluster_mins = self._get_cluster_mins(df) + return (func, cluster_mins) + """ + def _get_latency_dim(self, df: pd.DataFrame) -> np.ndarray: + """ + Utility function to map latency from a 1D array to a 2D array of form: [[latency_elem_1, ..., latency_elem_n]] + Used when clustering solely on latency. + + Params + ------ + df: dataframe + dataframe contains Latency column, which we want to turn into a 2D np array. + + Returns + ------- + 2D np array containing the latency column from df. 
+ """ return np.reshape(df['Latency'], (-1, 1)) - def _posterior_pooling(self, hexels: Dict[str, IPPMHexel], hemi: str) -> Dict[str, IPPMHexel]: - for func in hexels.keys(): - if len(hexels[func].left_best_pairings) != 0 and hemi == 'leftHemisphere': - hexels[func].left_best_pairings = [min(hexels[func].left_best_pairings, key=lambda x: x[1])] - elif len(hexels[func].right_best_pairings) != 0 and hemi == 'rightHemisphere': + def _posterior_pooling(self, hexels: Dict[str, IPPMHexel], hemi: str, func: str) -> Dict[str, IPPMHexel]: + """ + Optional max pooling over the entire latency done at the end of clustering. It enforces the constraint that there is only one + spike per function. + + Params + ------ + hexels: dict of func_name, hexel + each hexel contains a list of cluster minimums. We wish to take the min over this list + hemi: either 'rightHemisphere' or 'leftHemisphere'. we dont check because the parent function checks. + which hemi we are clustering over. + func: the function we want to max pool over. + + Returns + ------- + the same dictionary but now one of the hemispheres has its list reduced to a single spike. + """ + + if len(hexels[func].left_best_pairings) != 0 and hemi == 'leftHemisphere': + hexels[func].left_best_pairings = [min(hexels[func].left_best_pairings, key=lambda x: x[1])] + elif len(hexels[func].right_best_pairings) != 0 and hemi == 'rightHemisphere': hexels[func].right_best_pairings = [min(hexels[func].right_best_pairings, key=lambda x: x[1])] return hexels - - def _hexels_to_df(self, hexels: Dict[str, IPPMHexel], hemi: str) -> pd.DataFrame: + def _hexels_to_df(self, hexels: Dict[str, IPPMHexel], hemi: str, alpha: float) -> pd.DataFrame: """ A generator used to build a dataframe of significant points only. For each call, it returns the dataframe for the next function in hexels.keys(). @@ -95,7 +207,6 @@ def _hexels_to_df(self, hexels: Dict[str, IPPMHexel], hemi: str) -> pd.DataFrame ------- A dataframe for a function that contains only significant spikes, i.e., those above alpha. """ - alpha = self._estimate_alpha() for func in hexels.keys(): df = pd.DataFrame(columns=['Latency', 'Mag']) df = self._filter_spikes(hexels[func].right_best_pairings if hemi == 'rightHemisphere' else @@ -198,33 +309,18 @@ def _get_cluster_mins(self, df: pd.DataFrame) -> List[Tuple[float, float]]: ------- For each class, the most significant spike and the associated latency. """ - ret = [] - class_mins = {} - for _, row in df.iterrows(): - label = row['Label'] - if label == -1: - # label = -1 indicates anomaly, so exclude. - continue - - if label not in class_mins.keys(): - # first time seeing this class. - class_mins[label] = [float(row['Mag']), float(row['Latency'])] - elif class_mins[label][0] > row['Mag']: - # found a more significant spike, so overwrite. - class_mins[label] = [float(row['Mag']), float(row['Latency'])] - - for _, items in class_mins.items(): - # (latency, magnitude) - ret.append((items[1], items[0])) - - return ret + mins = df.loc[df.groupby('Label')['Mag'].idxmin()] + mins = mins[mins['Label'] != -1] # filter out anomalies + return list(zip(mins['Latency'], mins['Mag'])) class MaxPooler(DenoisingStrategy): """ - Naive max pooling technique. It operates by first sorting the latencies into bins, identifying significant bins, and taking the most significant spikes in a significant bin. - A bin is considered significant if the number of spikes for a particular function in the bin exceeds the threshold (self._threshold). 
Moreover the bin size can be controlled by the - bin_sz hyperparameter. Hence, to improve robustness, the threshold should be increased or bin size reduced. A criteria that is too stringent may lead to no significant spikes, so it should be balanced. - Finally, it is possible to run max pooler as an anomaly detection system prior to running an unsupervised algorithm, albeit at a higher computational cost. + Naive max pooling technique. It operates by first sorting the latencies into bins, identifying significant bins, and taking the + most significant spike in a significant bin. A bin is considered significant if the number of spikes for a particular function + in the bin exceeds the threshold (self._threshold). Moreover, the bin size can be controlled by the bin size (self._bin_sz) hyperparameter. + To improve robustness, the threshold should be increased or bin size reduced. A criteria that is too stringent may lead to no + significant spikes leading to low function recall, so it should be balanced. Finally, it is possible to run max pooler as an anomaly + detection system prior to running an unsupervised algorithm, albeit at a higher computational cost. """ def __init__(self, **kwargs): @@ -239,12 +335,15 @@ def __init__(self, **kwargs): self._threshold = 15 if 'threshold' not in kwargs.keys() else kwargs['threshold'] self._bin_sz = 25 if 'bin_sz' not in kwargs.keys() else kwargs['bin_sz'] - if not isinstance(self._threshold, int): + + + if not isinstance(self._threshold, int) or isinstance(self._threshold, bool): # edge case with isinstance: bool is subtype of int, so technically a user can pass in a boolean value for an integer # True = 1, False = 0. print('Threshold needs to be an integer.') raise ValueError - if not isinstance(self._bin_sz, int): + + if not isinstance(self._bin_sz, int) or isinstance(self._bin_sz, bool): print('Bin size needs to be an integer.') raise ValueError @@ -254,9 +353,20 @@ def cluster( ) -> Dict[str, IPPMHexel]: """ Custom clustering method since it differs from other unsupervised techniques. + + Yikes: forgot to incorporate cluster_bin + logn_f + 1 > # of bins is greater. + + Time Complexity: O(f * (nlogn + n)) where n = number of pairings = # of hexels = 10000. Final: O(f * nlogn) + Time Complexity (if you don't sort): O(f * (# of bins * n)) where # of bins = 1000/binsz = 1000/25 = 40. Final: O(f * bin_num * n) + Hence, if bin_num > logn, then we should go for sorted. bin_num = 40 assuming 1000 latency range 25 bin size. + log10000 = logn = 4. Therefore, we should go with sorted. + + Possible optimisation: break it up into parallel and do each bin at once? Ans: Python bad Algorithm --------- + for each func: sort by latency, so we can partition into bins in ascending order. If it is unordered, then there is no guarentee that adjacent data points belong to the same bin. ret = [] for current_bin in partitioned_latency_axis: @@ -283,12 +393,13 @@ def cluster( """ super()._check_hemi(hemi) hexels = deepcopy(hexels) - for func, df in super()._hexels_to_df(hexels, hemi): + alpha = self._estimate_alpha() + for func, df in super()._hexels_to_df(hexels, hemi, alpha): if len(df) == 0: hexels = super()._update_pairings(hexels, func, [], hemi) continue - df = df.sort_values(by='Latency') # arrange latencies into bins + df = df.sort_values(by='Latency') # arrange latencies into bins. It uses QuickSort. Complexity: O(nlogn) worst case: O(n^2) r_idx = 0 # this points to the current_data_point. It is incremented in the inner loop. 
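The fixed-bin pooling described above can also be expressed in a few lines of pandas, which may make the intent clearer (a sketch only, not the patch's implementation; the bin edges and the `>=` threshold comparison follow the prose description and may differ in detail from the code):

    import pandas as pd

    def max_pool_bins(df: pd.DataFrame, bin_sz: int = 25, threshold: int = 15) -> list:
        """df has columns ['Latency', 'Mag'] containing only significant spikes."""
        binned = df.copy()
        binned['Bin'] = (binned['Latency'] - (-200)) // bin_sz        # latency axis assumed to span [-200, 800)
        counts = binned.groupby('Bin')['Mag'].transform('size')       # number of spikes per bin
        significant = binned[counts >= threshold]                     # keep only bins with enough spikes
        mins = significant.loc[significant.groupby('Bin')['Mag'].idxmin()]  # most significant spike per bin
        return list(zip(mins['Latency'], mins['Mag']))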
ret = [] for latency in range(-200, 800, self._bin_sz): @@ -304,13 +415,17 @@ def cluster( ret.append((lat_min, bin_min)) hexels = super()._update_pairings(hexels, func, ret, hemi) + if posterior_pooling: + hexels = self._posterior_pooling(hexels, hemi, func) - return hexels if not posterior_pooling else super()._posterior_pooling(hexels, hemi) + return hexels def _cluster_bin(self, df: pd.DataFrame, r_idx: int, latency: int) -> Tuple[float, int, int, int]: """ We dont need to check r_idx and latency since cluster function provides them rather than the user. + Time complexity: O(n_b) where n_b = # of spikes in bin. Don't think we can reduce this further. + Params ------ df : pd.DataFrame @@ -347,68 +462,173 @@ def _cluster_bin(self, df: pd.DataFrame, r_idx: int, latency: int) -> Tuple[floa return bin_min, lat_min, num_seen, r_idx class AdaptiveMaxPooler(DenoisingStrategy): - def __init__(self, base_bin_sz: int=10, threshold: int=5): - self._threshold = threshold - self._base_bin_sz = base_bin_sz + def __init__(self, **kwargs): + """ + Params + ------ + threshold : int + # of spikes required in a bin before it is considered significant + bin_sz : int + the size, in ms, of a bin. + """ + self._threshold = 5 if 'threshold' not in kwargs.keys() else kwargs['threshold'] + self._base_bin_sz = 10 if 'base_bin_sz' not in kwargs.keys() else kwargs['base_bin_sz'] + + if not isinstance(self._threshold, int) or isinstance(self._threshold, bool): + print('Threshold needs to be an integer.') + raise ValueError + if not isinstance(self._base_bin_sz, int) or isinstance(self._base_bin_sz, bool): + print('Base Bin size needs to be an integer.') + raise ValueError def cluster( self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False, cluster_latency: bool = False, posterior_pooling: bool = False ) -> Dict[str, IPPMHexel]: + """ + Time complexity is O(f(nlogn + n)). f = # of funcs, n = # of hexels/spikes = 10000. Final: O(f * nlogn) + + If we want to do it in one pass of the dataframe, i.e., in time n, we need to have it sorted. Hence, the nlogn is irreducible since + sort uses QuickSort under the hood. Also, the f is irreducible unless we don't want to go through each function. + + Adaptive Max Pooler (AMP) is an improvement over the MaxPooler (MP) algorithm. Whereas MP has a fixed bin size, + hence a fixed cluster size, AMP starts off with a minute bin size and iteratively expands it until encountering the end of a + cluster. Hence, AMP has a variable bin size. Specifically, it works as follows: + + Algorithm + --------- + Note: I think the code + comments is probably better to understand algo than this. + + Loop through the time-series for each function: + + sorted_df = sort(time_series_func) + + bins = segment(sorted_df) # break into bins of size 5 ms. + + prev_sig = False + prev_bin_min = inf + prev_lat_min = None + denoised_pairings = [] + for bin in bins: + cur_bin_min = inf + cur_lat_min = None + num_spikes = 0 + + for lat, mag in bin: + num_spikes++ + if mag < cur_bin_min: + cur_bin_min = mag + cur_lat_min = lat + + if num_spikes > self._threshold: + prev_sig = True + if cur_bin_min < prev_bin_min: + # it is minimum over prev bin and cur bin + prev_bin_min, prev_lat_min = cur_bin_min, cur_lat_min + + else: + # not significant. + if prev_sig: + # previous bin was significant. This is the case when we reach end of cluster/bin. 
+ denoised_pairings.append((prev_bin_min, prev_lat_min)) + prev_bin_min = inf + prev_lat_min = None + prev_sig = False + + if prev_sig: + # previous bin was significant. This is the case when we reach end of dataframe and the final bin is significant. + denoised_pairings.append((prev_bin_min, prev_lat_min)) + prev_bin_min = inf + prev_lat_min = None + prev_sig = False + + Parameters + ---------- + See DenoisingStrategy + + Returns + ------- + See DenoisingStrategy + """ + def get_default_vals(): return (np.inf, None) hexels = deepcopy(hexels) - for func, df in super()._hexels_to_df(hexels, hemi): + alpha = self._estimate_alpha() + for func, df in super()._hexels_to_df(hexels, hemi, alpha): + # Step 1: Create a dataframe for each function with only significant spikes if len(df) == 0: + # dont cluster on an empty dataframe hexels = super()._update_pairings(hexels, func, [], hemi) continue + # Step 2: Sort the pairings so that we can loop through them in one go. df = df.sort_values(by='Latency') - df_ptr = 0 # index into df - end_ptr = 1 # guarenteed to have > 1 data point. delineates end of bin - start_ptr = 0 - total_bins = 1000 / self._base_bin_sz - prev_bin_min, prev_bin_lat_min = get_default_vals() - prev_signi = False - ret = [] + + # Step 3: Loop through and expand significant bins until there aren't anymore significant ones. Take the min over the expanded bin. + df_ptr = 0 # Pointer to where we are in the dataframe. + end_ptr = 1 # Delineates the end of a bin + start_ptr = 0 # Delineates the start of a bin + total_bins = 1000 / self._base_bin_sz # # of bins. Assumption: We keep latency in [-200, 800]. + prev_bin_min, prev_bin_lat_min = get_default_vals() # bin_min = np.inf, lat_min = None. Default vals when no datapoints in bin. + prev_signi = False # signifies whether the previous bin was significant. If so, we expand into curr bin + ret = [] # Denoised hexel pairings. while df_ptr < len(df) and start_ptr < total_bins: + # Loop until we end of dataframe or exceed the number of bins + end_ms = end_ptr * self._base_bin_sz num_in_bin = 0 cur_bin_min, cur_bin_lat_min = get_default_vals() while df_ptr < len(df) and df.iloc[df_ptr, 0] < end_ms: + # Step 3.1) Loop through the bin and take the min over significant spikes if df.iloc[df_ptr, 1] < cur_bin_min: + # Found significant point, i.e., cur_min > df[df_ptr]['Mag'] cur_bin_min, cur_bin_lat_min = df.iloc[df_ptr, 1], df.iloc[df_ptr, 0] num_in_bin += 1 df_ptr += 1 + if num_in_bin >= self._threshold: - end_ptr += 1 + # Step 3.2) Check to see if curr bin is significant, i.e., # of data points > threshold. If so, we save the minimum. prev_signi = True if cur_bin_min < prev_bin_min: + # If it is significant, save the minimum over this bin. We will compare it to the next minimum (if it is also significant). prev_bin_min, prev_bin_lat_min = cur_bin_min, cur_bin_lat_min else: + # Not a significant bin. Either previous bin was not significant (so no cluster) or it was (so we reached the end of a cluster) if prev_signi: - # start_ptr to end_ptr is significant + # If the previous bin was significant, we take the final minimum (prev_min) and save it. Now we reset the minimums. ret.append((prev_bin_lat_min, prev_bin_min)) prev_bin_min, prev_bin_lat_min = get_default_vals() prev_signi = False + + # Increment start of current bin. start_ptr = end_ptr - end_ptr += 1 + + # Proceed to next bin. + end_ptr += 1 + if prev_signi: - # last bin was significant and we expanded + # The final bin was significant. 
Due to the algo only looping til last bin, we can miss clustering the final bin. + # Hence, we do it now. ret.append((prev_bin_lat_min, prev_bin_min)) prev_bin_min, prev_bin_lat_min = get_default_vals() prev_signi = False + + # Step 4: Overwrite previous noisy hexels hexels = super()._update_pairings(hexels, func, ret, hemi) - return hexels if not posterior_pooling else super()._posterior_pooling(hexels, hemi) - + # Optional Step 5: Take maximum over the function to constrain 1 spike. + if posterior_pooling: + hexels = self._posterior_pooling(hexels, hemi, func) + return hexels + class GMM(DenoisingStrategy): """ - This strategy uses the GaussianMixtureModel algorithm. It attempts to fit a multimodal Gaussian distribution to the data using the EM algorithm. - The primary disadvantage of this model is that the number of Gaussians have to be prespecified. This implementation does a grid search from 1 to max_gaussians - to find the optimal number of Gaussians. Moreover, it does not work well with anomalies. + This strategy uses the GaussianMixtureModel algorithm. It attempts to fit a multimodal Gaussian distribution to the data using the + EM algorithm. The primary disadvantage of this model is that the number of Gaussians have to be prespecified. This implementation + does a grid search from 1 to max_gaussians to find the optimal number of Gaussians. Moreover, it does not work well with anomalies. """ def __init__(self, **kwargs): """ @@ -438,29 +658,28 @@ def __init__(self, **kwargs): self._is_aic = False if 'is_aic' not in kwargs.keys() else kwargs['is_aic'] # default is BIC since it is better for explanatory models, since it assumes reality lies within the hypothesis space. invalid = False - if not isinstance(self._max_gaussians, int): + if not isinstance(self._max_gaussians, int) or isinstance(self._max_gaussians, bool): print('Max Gaussians must be of type int.') invalid = True if self._covariance_type not in ['full', 'tied', 'diag', 'spherical']: print('Covariance type must be one of {full, tied, diag, spherical}') invalid = True - if not isinstance(self._max_iter, int): + if not isinstance(self._max_iter, int) or isinstance(self._max_iter, bool): print('Max iterations must be of type int.') invalid = True - if not isinstance(self._n_init, int): + if not isinstance(self._n_init, int) or isinstance(self._n_init, bool): print('Number of initialisations must be of type int.') invalid = True if self._init_params not in ['kmeans', 'k-means++', 'random', 'random_from_data']: print('Initalisation of parameter strategy must be one of {kmeans, k-means++, random, random_from_data}') invalid = True - if self._random_state is not None and not isinstance(self._random_state, int): + if self._random_state is not None and not isinstance(self._random_state, int) or isinstance(self._random_state, bool): print('Random state must be none or int.') invalid = True if invalid: raise ValueError - def cluster( self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False, cluster_latency: bool = False, posterior_pooling: bool = False @@ -472,42 +691,36 @@ def cluster( Params ------ - Hexels : Dict[function_name, hexel_object] - The data we want to cluster over. - hemi : str - Needs to be leftHemisphere or rightHemisphere. + See DenoisingStrategy() Returns ------- - Hexels that are formatted the same as input hexels but with denoised pairings. 
+ See DenoisingStrategy() """ super()._check_hemi(hemi) hexels = deepcopy(hexels) - for func, df in super()._hexels_to_df(hexels, hemi): + alpha = self._estimate_alpha() + for func, df in super()._hexels_to_df(hexels, hemi, alpha): + # Step 1: Loop through every function time-series in a dataframe format. Each one only contains significant spikes. if len(df) == 0: hexels = super()._update_pairings(hexels, func, [], hemi) continue if len(df) == 1: # no point clustering, just return the single data point. - ret = [] - for _, row in df.iterrows(): - ret.append((row['Latency'], row['Mag'])) + ret = [(df.iloc[0, 'Latency'], df.iloc[0, 'Mag'])] hexels = super()._update_pairings(hexels, func, ret, hemi) continue best_labels = None best_score = np.inf for n in range(1, self._max_gaussians): + # Step 2: Perform a grid-search for the optimal value of n_gaussians (the modality of the mixture of Gaussians) if n > len(df): # the number of gaussians has to be less than the number of datapoints. continue - gmm = GaussianMixture(n_components=n, - covariance_type=self._covariance_type, - max_iter=self._max_iter, - n_init=self._n_init, - init_params=self._init_params, - random_state=self._random_state) + + # Optional Step 3: Do any preprocessing on the dataframe. temp = None if normalise and cluster_latency: temp = np.reshape(normalize(df['Latency']), (-1, 1)) @@ -517,20 +730,33 @@ def cluster( temp = normalize(df) else: temp = df - + + # Step 4: Fit the model + gmm = GaussianMixture(n_components=n, + covariance_type=self._covariance_type, + max_iter=self._max_iter, + n_init=self._n_init, + init_params=self._init_params, + random_state=self._random_state) gmm.fit(temp) + + # Step 5: Evaluate the model and retain the best performing one. score = gmm.aic(temp) if self._is_aic else gmm.bic(temp) labels = gmm.predict(temp) - if score < best_score: # this condition depends on the choice of AIC/BIC best_labels = labels best_score = score - + + # Step 6: Update the hexels with the best clustering we found. df['Label'] = best_labels cluster_mins = super()._get_cluster_mins(df) hexels = super()._update_pairings(hexels, func, cluster_mins, hemi) - return hexels if not posterior_pooling else super()._posterior_pooling(hexels, hemi) + + # Optional Step 7: Max pool over the denoised hexels. 
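Stripped of the hexel bookkeeping, the grid search over the number of Gaussians reduces to roughly the following (illustrative sketch only; `X` is a placeholder (n_samples, 2) array of latency/magnitude pairs and the extra GaussianMixture constructor options are omitted):

    import numpy as np
    from sklearn.mixture import GaussianMixture

    def best_gmm_labels(X: np.ndarray, max_gaussians: int = 5, use_aic: bool = False) -> np.ndarray:
        best_labels, best_score = None, np.inf
        for n in range(1, max_gaussians):
            if n > len(X):
                break  # cannot fit more components than data points
            gmm = GaussianMixture(n_components=n).fit(X)
            score = gmm.aic(X) if use_aic else gmm.bic(X)  # lower is better for both criteria
            if score < best_score:
                best_score, best_labels = score, gmm.predict(X)
        return best_labels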
+ if posterior_pooling: + hexels = self._posterior_pooling(hexels, hemi, func) + return hexels class DBSCAN(DenoisingStrategy): @@ -568,10 +794,10 @@ def __init__(self, **kwargs): n_jobs = -1 if 'n_jobs' not in kwargs.keys() else kwargs['n_jobs'] invalid = False - if not (isinstance(eps, int) or isinstance(eps, float)): + if not (isinstance(eps, int) or isinstance(eps, float)) or isinstance(eps, bool): print('Epsilon must be of numeric type.') invalid = True - if not isinstance(min_samples, int): + if not isinstance(min_samples, int) or isinstance(min_samples, bool): print('Min samples must be of type integer.') invalid = True if not isinstance(metric, str): @@ -583,10 +809,10 @@ def __init__(self, **kwargs): if algorithm not in ['auto', 'ball_tree', 'kd_tree', 'brute']: print('Algorithm must be one of {auto, ball_tree, kd_tree, brute}') invalid = True - if not isinstance(leaf_size, int): + if not isinstance(leaf_size, int) or isinstance(leaf_size, bool): print('leaf_size must be of type int.') invalid = True - if not isinstance(n_jobs, int): + if not isinstance(n_jobs, int) or isinstance(n_jobs, bool): print('The number of jobs must be of type int.') invalid = True @@ -633,16 +859,16 @@ def __init__(self, **kwargs): if not isinstance(cluster_all, bool): print('Cluster_all must be of type bool.') invalid = True - if not isinstance(bandwidth, float) and not isinstance(bandwidth, int) and bandwidth is not None: + if (not isinstance(bandwidth, float) and not isinstance(bandwidth, int) and bandwidth is not None) or isinstance(bandwidth, bool): print('bandwidth must be None or float.') invalid = True if not isinstance(seeds, list) and seeds is not None: print('Seeds must be a list or None.') invalid = True - if not isinstance(min_bin_freq, int): + if not isinstance(min_bin_freq, int or isinstance(min_bin_freq, bool)): print('Mininum bin frequency must be of type int.') invalid = True - if not isinstance(n_jobs, int): + if not isinstance(n_jobs, int) or isinstance(n_jobs, bool): print('Number of jobs must be of type int.') invalid = True diff --git a/kymata/ippm/plotter.py b/kymata/ippm/plotter.py index 5f38bbaa..d142a885 100644 --- a/kymata/ippm/plotter.py +++ b/kymata/ippm/plotter.py @@ -32,12 +32,13 @@ def draw(self, width """ # first lets aggregate all of the information. - hexel_x = [_ for _ in range(len(graph.keys()))] - hexel_y = [_ for _ in range(len(graph.keys()))] - node_colors = [_ for _ in range(len(graph.keys()))] - node_sizes = [_ for _ in range(len(graph.keys()))] - hexel_coordinate_pairs = [] # [[start_coord, end_coord], ..] + # TODO: refactor to generate BSplines in the first loop, so we dont have to loop again. + hexel_x = [_ for _ in range(len(graph.keys()))] # x coordinates for nodes e.g., (x, y) = (hexel_x[i], hexel_y[i]) + hexel_y = [_ for _ in range(len(graph.keys()))] # y coordinates for nodes + node_colors = [_ for _ in range(len(graph.keys()))] # color for nodes + node_sizes = [_ for _ in range(len(graph.keys()))] # size of nodes edge_colors = [] + bsplines = [] for i, node in enumerate(graph.keys()): for function, color in colors.items(): # search for function color. @@ -48,21 +49,22 @@ def draw(self, node_sizes[i] = graph[node].magnitude hexel_x[i] = graph[node].position[0] hexel_y[i] = graph[node].position[1] - + + pairs = [] for inc_edge in graph[node].in_edges: # save edge coordinates and color the edge the same color as the finishing node. 
start = graph[inc_edge].position end = graph[node].position - hexel_coordinate_pairs.append([(start[0], start[1]), (end[0], end[1])]) + pairs.append([(start[0], start[1]), (end[0], end[1])]) edge_colors.append(node_colors[i]) - bspline_path_array = self._make_bspline_paths(hexel_coordinate_pairs) + bsplines += self._make_bspline_paths(pairs) fig, ax = plt.subplots() fig.set_figheight(figheight) fig.set_figwidth(figwidth) plt.axis('on') - for path, color in zip(bspline_path_array, edge_colors): + for path, color in zip(bsplines, edge_colors): ax.plot(path[0], path[1], color=color, linewidth='3', zorder=-1) ax.scatter(x=hexel_x, y=hexel_y, c=node_colors, s=node_sizes, zorder=1) ax.tick_params(bottom=True, labelbottom=True, left=False) diff --git a/tests/test_ippm_builder.py b/tests/test_ippm_builder.py index 78a846ba..91757f64 100644 --- a/tests/test_ippm_builder.py +++ b/tests/test_ippm_builder.py @@ -21,8 +21,8 @@ def test_sort_by_latency(): test_hexels['f1'].right_best_pairings = [(100, 42), (20, 41), (50, 33)] builder = IPPMBuilder() - sorted = builder._sort_by_latency(test_hexels, 'leftHemisphere') - sorted = builder._sort_by_latency(sorted, 'rightHemisphere') + sorted = builder._sort_by_latency(test_hexels, 'leftHemisphere', ['f1']) + sorted = builder._sort_by_latency(sorted, 'rightHemisphere', ['f1']) assert [(-12, 43), (21, 21), (46, 23), (143, 2)] == sorted['f1'].left_best_pairings assert [(20, 41), (50, 33), (100, 42)] == sorted['f1'].right_best_pairings diff --git a/tests/test_ippm_denoiser.py b/tests/test_ippm_denoiser.py index 9287e44c..a0852d8a 100644 --- a/tests/test_ippm_denoiser.py +++ b/tests/test_ippm_denoiser.py @@ -249,7 +249,8 @@ def test_Should_hexelsToDf_When_validInput(): denoiser_ = denoiser.DenoisingStrategy() fs = ['f1', 'f2'] i = 0 - for func, df in denoiser_._hexels_to_df(test_hexels, 'rightHemisphere'): + alpha = denoiser_._estimate_alpha() + for func, df in denoiser_._hexels_to_df(test_hexels, 'rightHemisphere', alpha): assert fs[i] == func if fs[i] == 'f2': assert test_hexels['f2'].right_best_pairings[1][0] == df.iloc[0, 0] @@ -268,7 +269,8 @@ def test_Should_hexelsToDf_When_leftHemisphere(): denoiser_ = denoiser.DenoisingStrategy() fs = ['f1', 'f2'] i = 0 - for func, df in denoiser_._hexels_to_df(test_hexels, 'leftHemisphere'): + alpha = denoiser_._estimate_alpha() + for func, df in denoiser_._hexels_to_df(test_hexels, 'leftHemisphere', alpha): assert fs[i] == func if fs[i] == 'f2': assert test_hexels['f2'].left_best_pairings[1][0] == df.iloc[0, 0] @@ -283,7 +285,8 @@ def test_Should_hexelsToDf_When_leftHemisphere(): def test_Should_hexelsToDf_When_emptyInput(): test_hexels = {'f1': IPPMHexel('f1'), 'f2': IPPMHexel('f2')} denoiser_ = denoiser.DenoisingStrategy() - for _, df in denoiser_._hexels_to_df(test_hexels, 'leftHemisphere'): + alpha = denoiser_._estimate_alpha() + for _, df in denoiser_._hexels_to_df(test_hexels, 'leftHemisphere', alpha): assert len(df) == 0 def test_Should_updatePairings_When_validInput(): @@ -379,9 +382,11 @@ def test_Should_gmmCluster_When_validInputRightHemisphere(): self_test_hexels2 = deepcopy(self_test_hexels) denoised = clusterer.cluster(self_test_hexels2, 'rightHemisphere') f2_expected = self_test_hexels2['func2'].right_best_pairings - assert denoised['func1'].right_best_pairings == self_test_hexels2['func1'].right_best_pairings - assert denoised['func2'].right_best_pairings == f2_expected - assert denoised['func3'].right_best_pairings == self_test_hexels2['func3'].right_best_pairings + + # use set because we don't care about 
ordering of elements + assert set(denoised['func1'].right_best_pairings) == set(self_test_hexels2['func1'].right_best_pairings) + assert set(denoised['func2'].right_best_pairings) == set(f2_expected) + assert set(denoised['func3'].right_best_pairings) == set(self_test_hexels2['func3'].right_best_pairings) def test_Should_gmmCluster_When_validInputLeftHemisphere(): """ @@ -399,9 +404,9 @@ def test_Should_gmmCluster_When_validInputLeftHemisphere(): f3_expected = self_test_hexels2['func3'].left_best_pairings f2_expected.remove((10, 0.01)) f3_expected.remove((130, 0.001)) - assert denoised['func1'].left_best_pairings == self_test_hexels2['func1'].left_best_pairings - assert denoised['func2'].left_best_pairings == self_test_hexels2['func2'].left_best_pairings - assert denoised['func3'].left_best_pairings == self_test_hexels2['func3'].left_best_pairings + assert set(denoised['func1'].left_best_pairings) == set(self_test_hexels2['func1'].left_best_pairings) + assert set(denoised['func2'].left_best_pairings) == set(self_test_hexels2['func2'].left_best_pairings) + assert set(denoised['func3'].left_best_pairings) == set(self_test_hexels2['func3'].left_best_pairings) def test_Should_dbscanCluster_When_validInputRightHemi(): np.random.seed(0) @@ -444,15 +449,15 @@ def test_Should_meanShiftCluster_When_validInputLeftHemi(): f3_expected = self_test_hexels['func3'].left_best_pairings f2_expected.remove((10, 0.01)) f3_expected.remove((130, 0.001)) - assert [(20.0, 1e-66)] == denoised['func1'].left_best_pairings - assert f2_expected == denoised['func2'].left_best_pairings - assert self_test_hexels['func3'].left_best_pairings == denoised['func3'].left_best_pairings + assert set([(20.0, 1e-66)]) == set(denoised['func1'].left_best_pairings) + assert set(f2_expected) == set(denoised['func2'].left_best_pairings) + assert set(self_test_hexels['func3'].left_best_pairings) == set(denoised['func3'].left_best_pairings) def test_Should_getLatencyDim_When_validDfRightHemisphere(): clusterer = denoiser.DenoisingStrategy() self_test_hexels2 = deepcopy(self_test_hexels) latency_dfs = [] - for func, df in clusterer._hexels_to_df(self_test_hexels2, 'rightHemisphere'): + for func, df in clusterer._hexels_to_df(self_test_hexels2, 'rightHemisphere', clusterer._estimate_alpha()): latency_dfs.append(clusterer._get_latency_dim(df)) assert [23, 35, 66] == list(latency_dfs[0].flatten()) @@ -462,11 +467,10 @@ def test_Should_getLatencyDim_When_validDfRightHemisphere(): def test_Should_posteriorPool_When_validHexelsRightHemisphere(): clusterer = denoiser.DenoisingStrategy() self_test_hexels2 = deepcopy(self_test_hexels) - pooled_hexels = clusterer._posterior_pooling(self_test_hexels2, 'rightHemisphere') pooled = [] - for func in pooled_hexels.keys(): - pooled.append(pooled_hexels[func].right_best_pairings) - + for func in self_test_hexels2.keys(): + self_test_hexels2 = clusterer._posterior_pooling(self_test_hexels2, 'rightHemisphere', func) + pooled.append(self_test_hexels2[func].right_best_pairings) assert [(23, 1e-75)] == pooled[0] assert [(45, 1e-60)] == pooled[1] assert [(120, 1e-90)] == pooled[2] @@ -474,10 +478,10 @@ def test_Should_posteriorPool_When_validHexelsRightHemisphere(): def test_Should_posteriorPool_When_validHexelsLeftHemisphere(): clusterer = denoiser.DenoisingStrategy() self_test_hexels2 = deepcopy(self_test_hexels) - pooled_hexels = clusterer._posterior_pooling(self_test_hexels2, 'leftHemisphere') pooled = [] - for func in pooled_hexels.keys(): - pooled.append(pooled_hexels[func].left_best_pairings) + for func 
in self_test_hexels2.keys(): + self_test_hexels2 = clusterer._posterior_pooling(self_test_hexels2, 'leftHemisphere', func) + pooled.append(self_test_hexels2[func].left_best_pairings) assert [(20, 1e-66)] == pooled[0] assert [(75, 1e-80)] == pooled[1] @@ -487,7 +491,7 @@ def test_Should_posteriorPool_When_emptyHexels(): clusterer = denoiser.DenoisingStrategy() self_test_hexels2 = deepcopy(self_test_hexels) self_test_hexels2['func1'].right_best_pairings = [] - pooled_hexels = clusterer._posterior_pooling(self_test_hexels2, 'rightHemisphere') + pooled_hexels = clusterer._posterior_pooling(self_test_hexels2, 'rightHemisphere', 'func1') assert [] == pooled_hexels['func1'].right_best_pairings def test_Should_maxPoolerPool_When_normalised(): diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 062c117d..97278f1e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -4,6 +4,7 @@ from kymata.plot.plot import _get_best_ylim, _MAJOR_TICK_SIZE, _get_yticks, _get_xticks, expression_plot + def test_best_best_ylim_returns_supplied_ylim(): supplied_ylim = 1e-172 data_y_min = 1e-250