From 483a3e8ce443e5773a5db3d32505a5fbd5f2fd5a Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Tue, 10 Oct 2023 21:04:37 -0400 Subject: [PATCH 01/13] add JIT KL-NMF numba to speed up KL-NMF --- src/salamander/nmf_framework/klnmf.py | 55 +++++++++++++++++---------- src/salamander/utils.py | 30 +++++++-------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/src/salamander/nmf_framework/klnmf.py b/src/salamander/nmf_framework/klnmf.py index 5aa377e..f0b636f 100644 --- a/src/salamander/nmf_framework/klnmf.py +++ b/src/salamander/nmf_framework/klnmf.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +from numba import njit from ..utils import kl_divergence, normalize_WH, poisson_llh, samplewise_kl_divergence from .nmf import NMF @@ -7,6 +8,38 @@ EPSILON = np.finfo(np.float32).eps +@njit +def update_W(X, W, H): + """ + The multiplicative update rule of the signature matrix W + derived by Lee and Seung. See Theorem 2 in + "Algorithms for non-negative matrix factorization". + + Clipping the matrix avoids floating point errors. + """ + W *= (X / (W @ H)) @ H.T + W /= np.sum(H, axis=1) + W = W.clip(EPSILON) + + return W + + +@njit +def update_H(X, W, H): + """ + The multiplicative update rule of the exposure matrix H + derived by Lee and Seung. See Theorem 2 in + "Algorithms for non-negative matrix factorization". + + Clipping the matrix avoids floating point errors. + """ + H *= W.T @ (X / (W @ H)) + H /= np.sum(W, axis=0)[:, np.newaxis] + H = H.clip(EPSILON) + + return H + + class KLNMF(NMF): """ Decompose a mutation count matrix X into the product of a signature @@ -37,28 +70,10 @@ def loglikelihood(self) -> float: return poisson_llh(self.X, self.W, self.H) def _update_W(self): - """ - The multiplicative update rule of the signature matrix W - derived by Lee and Seung. See Theorem 2 in - "Algorithms for non-negative matrix factorization". - - Clipping the matrix avoids floating point errors. 
- """ - self.W *= (self.X / (self.W @ self.H)) @ self.H.T - self.W /= np.sum(self.H, axis=1) - self.W = self.W.clip(EPSILON) + self.W = update_W(self.X, self.W, self.H) def _update_H(self): - """ - The multiplicative update rule of the exposure matrix H - derived by Lee and Seung. See Theorem 2 in - "Algorithms for non-negative matrix factorization". - - Clipping the matrix avoids floating point errors. - """ - self.H *= self.W.T @ (self.X / (self.W @ self.H)) - self.H /= np.sum(self.W, axis=0)[:, np.newaxis] - self.H = self.H.clip(EPSILON) + self.H = update_H(self.X, self.W, self.H) def fit( self, diff --git a/src/salamander/utils.py b/src/salamander/utils.py index bccd751..3ae4221 100644 --- a/src/salamander/utils.py +++ b/src/salamander/utils.py @@ -2,6 +2,7 @@ import numpy as np import pandas as pd +from numba import njit from scipy.optimize import linear_sum_assignment from scipy.special import gammaln from scipy.stats import mannwhitneyu @@ -71,30 +72,29 @@ def value_checker(arg_name: str, arg, allowed_values): ) +@njit(fastmath=True) def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: r""" - The generalized Kullback-Leibler divergence D(X || WH). - - \sum_vd X_vd * ln(X_vd / (WH)_vd) - \sum_vd X_vd + \sum_vd (WH)_vd. - - Summands with X_vd = 0 are neglected and WH is clipped to avoid division by zero. + The generalized Kullback-Leibler divergence + D_KL(X || WH) = \sum_vd X_vd * ln(X_vd / (WH)_vd) - \sum_vd X_vd + \sum_vd (WH)_vd. 
""" - indices = X.nonzero() - X_data = X[indices] - WH_data = (W @ H)[indices] - WH_data = WH_data.clip(EPSILON) + V, D = X.shape + WH = W @ H + kl_divergence = 0.0 - s1 = np.dot(X_data, np.log(X_data / WH_data)) - s2 = -np.sum(X_data) - # fast np.sum(W @ H) - s3 = np.dot(np.sum(W, axis=0), np.sum(H, axis=1)) + for v in range(V): + for d in range(D): + if X[v, d] != 0: + kl_divergence += X[v, d] * np.log(X[v, d] / WH[v, d]) + kl_divergence -= X[v, d] + kl_divergence += WH[v, d] - return s1 + s2 + s3 + return kl_divergence def samplewise_kl_divergence(X, W, H): """ - A fast vectorized samplewise KL divergence. + Per sample generalizedKullback-Leibler divergence D_KL(x || Wh). """ X_data = np.copy(X).astype(float) indices = X == 0 From 34ebd0d60a23428b10caa8a73711265cbc4a6481 Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Thu, 12 Oct 2023 13:48:35 -0400 Subject: [PATCH 02/13] add numba optimization - initial JIT implementation of core functions of the NMF framework - removed the "update_W" attribute from CorrNMF due to the mathematical connection between "1999-Lee" and "surrogate" --- src/salamander/nmf_framework/corrnmf.py | 161 ++++++++++++++++-- src/salamander/nmf_framework/corrnmf_det.py | 112 +++--------- .../nmf_framework/multimodal_corrnmf.py | 48 ++++-- src/salamander/utils.py | 48 ++++-- tests/test_corrnmf.py | 25 +-- ...orrnmf_nsigs1_dim1_W_surrogate_updated.npy | Bin 896 -> 0 bytes ....npy => corrnmf_nsigs1_dim1_W_updated.npy} | Bin ...orrnmf_nsigs2_dim2_W_surrogate_updated.npy | Bin 1664 -> 0 bytes ....npy => corrnmf_nsigs2_dim2_W_updated.npy} | Bin .../model0_W_surrogate_updated.npy | Bin 1664 -> 0 bytes ...W_Lee_updated.npy => model0_W_updated.npy} | Bin .../model1_W_surrogate_updated.npy | Bin 2432 -> 0 bytes ...W_Lee_updated.npy => model1_W_updated.npy} | Bin tests/test_multimodal_corrnmf.py | 36 +--- 14 files changed, 255 insertions(+), 175 deletions(-) delete mode 100644 
tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_surrogate_updated.npy rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_W_Lee_updated.npy => corrnmf_nsigs1_dim1_W_updated.npy} (100%) delete mode 100644 tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_surrogate_updated.npy rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_W_Lee_updated.npy => corrnmf_nsigs2_dim2_W_updated.npy} (100%) delete mode 100644 tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_surrogate_updated.npy rename tests/test_data/nmf_framework/multimodal_corrnmf/{model0_W_Lee_updated.npy => model0_W_updated.npy} (100%) delete mode 100644 tests/test_data/nmf_framework/multimodal_corrnmf/model1_W_surrogate_updated.npy rename tests/test_data/nmf_framework/multimodal_corrnmf/{model1_W_Lee_updated.npy => model1_W_updated.npy} (100%) diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index 8cbee5e..dec3a6a 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +from numba import njit from scipy.spatial.distance import squareform from scipy.special import gammaln @@ -32,6 +33,147 @@ EPSILON = np.finfo(np.float32).eps +@njit +def update_alpha(X, L, U): + exp_LTU = np.exp(L.T @ U) + alpha = np.log(np.sum(X, axis=0)) - np.log(np.sum(exp_LTU, axis=0)) + return alpha + + +@njit +def update_sigma_sq(L, U): + dim_embeddings, n_signatures = L.shape + n_samples = U.shape[1] + sum_norm_sigs = np.sum(L**2) + sum_norm_samples = np.sum(U**2) + sigma_sq = (sum_norm_sigs + sum_norm_samples) / ( + dim_embeddings * (n_signatures + n_samples) + ) + return sigma_sq + + +@njit +def update_W(X, W, H): + W *= (X / (W @ H)) @ H.T + W /= np.sum(W, axis=0) + W = W.clip(EPSILON) + return W + + +@njit +def update_p_unnormalized(W, H): + n_features, n_signatures = W.shape + n_samples = H.shape[1] + p 
= np.zeros((n_features, n_signatures, n_samples)) + + for v in range(n_features): + for k in range(n_signatures): + for d in range(n_samples): + p[v, k, d] = W[v, k] * H[k, d] + + return p + + +@njit +def _objective_fun_l(l, U, alpha, sigma_sq, aux_row): + UTl = U.T.dot(l) + s = np.dot(aux_row, UTl) + s -= np.sum(np.exp(alpha + UTl)) + s -= np.dot(l, l) / (2 * sigma_sq) + + return -s + + +@njit +def _gradient_l(l, U, alpha, sigma_sq, s_grad): + s = -np.sum(np.exp(alpha + U.T.dot(l)) * U, axis=1) + s -= l / sigma_sq + + return -(s_grad + s) + + +@njit +def _hessian_l(l, U, alpha, sigma_sq, outer_prods_U): + dim_embeddings, n_samples = U.shape + scalings = np.exp(alpha + U.T.dot(l)) + s = np.zeros((dim_embeddings, dim_embeddings)) + + for m1 in range(dim_embeddings): + for m2 in range(dim_embeddings): + for d in range(n_samples): + s[m1, m2] -= scalings[d] * outer_prods_U[d, m1, m2] + if m1 == m2: + s[m1, m2] -= 1 / sigma_sq + + return -s + + +@njit +def _objective_fun_u( + u: np.ndarray, + L: np.ndarray, + alpha: float, + sigma_sq: float, + aux_col: np.ndarray, + add_penalty_u=True, +): + n_signatures = L.shape[1] + LTu = L.T.dot(u) + s = 0.0 + + # aux_col not contiguous: + # s = np.dot(aux_col, LTu) doesn't work + for k in range(n_signatures): + s += aux_col[k] * LTu[k] + + s -= np.sum(np.exp(alpha + LTu)) + + if add_penalty_u: + s -= np.dot(u, u) / (2 * sigma_sq) + + return -s + + +@njit +def _gradient_u( + u: np.ndarray, + L: np.ndarray, + alpha: float, + sigma_sq: float, + s_grad: np.ndarray, + add_penalty_u=True, +): + s = -np.exp(alpha) * np.sum(np.exp(L.T.dot(u)) * L, axis=1) + + if add_penalty_u: + s -= u / sigma_sq + + return -(s_grad + s) + + +@njit +def _hessian_u( + u: np.ndarray, + L: np.ndarray, + alpha: float, + sigma_sq: float, + outer_prods_L: np.ndarray, + add_penalty_u=True, +): + dim_embeddings, n_signatures = L.shape + scalings = np.exp(alpha + L.T.dot(u)) + s = np.zeros((dim_embeddings, dim_embeddings)) + + for m1 in range(dim_embeddings): + 
for m2 in range(dim_embeddings): + for k in range(n_signatures): + s[m1, m2] -= scalings[k] * outer_prods_L[k, m1, m2] + if add_penalty_u and m1 == m2: + s[m1, m2] -= 1 / sigma_sq + + return -s + + class CorrNMF(SignatureNMF): r""" The abstract class CorrNMF unifies the structure of deterministic and @@ -134,7 +276,6 @@ def __init__( n_signatures=1, dim_embeddings=None, init_method="nndsvd", - update_W="1999-Lee", min_iterations=500, max_iterations=10000, tol=1e-7, @@ -158,11 +299,6 @@ def __init__( "nndsvda", "nndsvdar" "random" and "separableNMF". See the initialization module for further details. - update_W: str, "1999-Lee" or "surrogate" - The signature matrix W can be inferred by either using the Lee and Seung - multiplicative update rules to optimize the objective function or by - maximizing the surrogate objective function. - min_iterations: int The minimum number of iterations to perform during inference @@ -180,8 +316,6 @@ def __init__( dim_embeddings = n_signatures self.dim_embeddings = dim_embeddings - value_checker("update_W", update_W, ["1999-Lee", "surrogate"]) - self.update_W = update_W # initialize data/fitting dependent attributes self.W = None @@ -205,7 +339,7 @@ def exposures(self) -> pd.DataFrame: restructured and determined by the signature and sample embeddings. 
""" exposures = pd.DataFrame( - np.exp(np.tile(self.alpha, (self.n_signatures, 1)) + self.L.T @ self.U), + np.exp(self.alpha + self.L.T @ self.U), index=self.signature_names, columns=self.sample_names, ) @@ -311,13 +445,8 @@ def _update_sigma_sq(self): pass @abstractmethod - def _update_W(self, p): - """ - Input: - ------ - p: np.ndarray - The auxiliary parameters of CorrNMF - """ + def _update_W(self): + pass @abstractmethod def _update_p(self): diff --git a/src/salamander/nmf_framework/corrnmf_det.py b/src/salamander/nmf_framework/corrnmf_det.py index 8d62db8..00905fa 100644 --- a/src/salamander/nmf_framework/corrnmf_det.py +++ b/src/salamander/nmf_framework/corrnmf_det.py @@ -1,13 +1,17 @@ +# This implementation on helper functions in corrnmf.py. +# In particular, functions with leading '_'' are accessed +# pylint: disable=protected-access + import numpy as np import pandas as pd from scipy import optimize -from .corrnmf import CorrNMF +from . import corrnmf EPSILON = np.finfo(np.float32).eps -class CorrNMFDet(CorrNMF): +class CorrNMFDet(corrnmf.CorrNMF): r""" The CorrNMFDet class implements the deterministic batch version of a variant of the correlated NMF (CorrNMF) algorithm devolped in @@ -43,77 +47,47 @@ class CorrNMFDet(CorrNMF): """ def _update_alpha(self): - exp_LTU = np.exp(self.L.T @ self.U) - self.alpha = np.log(np.sum(self.X, axis=0)) - np.log(np.sum(exp_LTU, axis=0)) + self.alpha = corrnmf.update_alpha(self.X, self.L, self.U) def _update_sigma_sq(self): - sum_norm_sigs = np.sum(self.L**2) - sum_norm_samples = np.sum(self.U**2) - - self.sigma_sq = (sum_norm_sigs + sum_norm_samples) / ( - self.dim_embeddings * (self.n_signatures + self.n_samples) - ) + self.sigma_sq = corrnmf.update_sigma_sq(self.L, self.U) self.sigma_sq = np.clip(self.sigma_sq, EPSILON, None) - def _update_W(self, p): - if self.update_W == "1999-Lee": - self.W = self.W * ( - (self.X / (self.W @ self.exposures.values)) @ self.exposures.values.T - ) - - else: - self.W = 
np.einsum("vd,vkd->vk", self.X, p) - - self.W /= np.sum(self.W, axis=0) - self.W = self.W.clip(EPSILON) + def _update_W(self): + self.W = corrnmf.update_W(self.X, self.W, self.exposures.values) def _update_p(self): - p = np.einsum("vk,kd->vkd", self.W, self.exposures.values) + p = corrnmf.update_p_unnormalized(self.W, self.exposures.values) p /= np.sum(p, axis=1, keepdims=True) p = p.clip(EPSILON) - return p - def _objective_fun_l(self, l, aux_row): - UTl = self.U.T.dot(l) - s = np.dot(aux_row, UTl) - s -= np.sum(np.exp(self.alpha + UTl)) - s -= np.dot(l, l) / (2 * self.sigma_sq) - - return -s - - def _gradient_l(self, l, s_grad): - s = -np.sum(np.exp(self.alpha + self.U.T.dot(l)) * self.U, axis=1) - s -= l / self.sigma_sq - - return -(s_grad + s) - - def _hessian_l(self, l, outer_prods_U): - scalings = np.exp(self.alpha + self.U.T.dot(l)) - s = -np.einsum("D,Dmn->mn", scalings, outer_prods_U) - s -= np.diag(np.full(self.dim_embeddings, 1 / self.sigma_sq)) - - return -s - def _update_l(self, index, aux_row, outer_prods_U): def objective_fun(l): - return self._objective_fun_l(l, aux_row) + return corrnmf._objective_fun_l( + l, self.U, self.alpha, self.sigma_sq, aux_row + ) s_grad = np.sum(aux_row * self.U, axis=1) def gradient(l): - return self._gradient_l(l, s_grad) + return corrnmf._gradient_l(l, self.U, self.alpha, self.sigma_sq, s_grad) def hessian(l): - return self._hessian_l(l, outer_prods_U) + return corrnmf._hessian_l( + l, self.U, self.alpha, self.sigma_sq, outer_prods_U + ) - self.L[:, index] = optimize.minimize( + l = optimize.minimize( fun=objective_fun, x0=self.L[:, index], method="Newton-CG", jac=gradient, hess=hessian, ).x + l[(0 < l) & (l < EPSILON)] = EPSILON + l[(-EPSILON < l) & (l < 0)] = -EPSILON + self.L[:, index] = l def _update_L(self, aux, outer_prods_U=None): r""" @@ -131,49 +105,19 @@ def _update_L(self, aux, outer_prods_U=None): for k, aux_row in enumerate(aux): self._update_l(k, aux_row, outer_prods_U) - self.L[(0 < self.L) & (self.L < 
EPSILON)] = EPSILON - self.L[(-EPSILON < self.L) & (self.L < 0)] = -EPSILON - - def _objective_fun_u(self, u, index, aux_col, add_penalty_u=True): - LTu = self.L.T.dot(u) - s = np.dot(aux_col, LTu) - s -= np.sum(np.exp(self.alpha[index] + LTu)) - - if add_penalty_u: - s -= np.dot(u, u) / (2 * self.sigma_sq) - - return -s - - def _gradient_u(self, u, index, s_grad, add_penalty_u=True): - s = -np.exp(self.alpha[index]) * np.sum( - np.exp(self.L.T.dot(u)) * self.L, axis=1 - ) - - if add_penalty_u: - s -= u / self.sigma_sq - - return -(s_grad + s) - - def _hessian_u(self, u, index, outer_prods_L, add_penalty_u=True): - scalings = np.exp(self.alpha[index] + self.L.T.dot(u)) - s = -np.einsum("K,Kmn->mn", scalings, outer_prods_L) - - if add_penalty_u: - s -= np.diag(np.full(self.dim_embeddings, 1 / self.sigma_sq)) - - return -s - def _update_u(self, index, aux_col, outer_prods_L): + alpha = self.alpha[index] + def objective_fun(u): - return self._objective_fun_u(u, index, aux_col) + return corrnmf._objective_fun_u(u, self.L, alpha, self.sigma_sq, aux_col) s_grad = np.sum(aux_col * self.L, axis=1) def gradient(u): - return self._gradient_u(u, index, s_grad) + return corrnmf._gradient_u(u, self.L, alpha, self.sigma_sq, s_grad) def hessian(u): - return self._hessian_u(u, index, outer_prods_L) + return corrnmf._hessian_u(u, self.L, alpha, self.sigma_sq, outer_prods_L) u = optimize.minimize( fun=objective_fun, @@ -275,7 +219,7 @@ def fit( self._update_sigma_sq() if given_signatures is None: - self._update_W(p) + self._update_W() of_values.append(self.objective_function()) prev_sof_value = sof_values[-1] diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index 72ed331..ed89089 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -29,6 +29,7 @@ umap_2d, ) from ..utils import type_checker, value_checker +from . 
import corrnmf from .corrnmf_det import CorrNMFDet EPSILON = np.finfo(np.float32).eps @@ -41,7 +42,6 @@ def __init__( ns_signatures=None, dim_embeddings=None, init_method="nndsvd", - update_W="1999-Lee", min_iterations=500, max_iterations=10000, tol=1e-7, @@ -62,7 +62,7 @@ def __init__( self.max_iterations = max_iterations self.tol = tol self.models = [ - CorrNMFDet(n_signatures, dim_embeddings, init_method, update_W) + CorrNMFDet(n_signatures, dim_embeddings, init_method) for n_signatures in ns_signatures ] @@ -197,46 +197,70 @@ def _update_sigma_sq(self): for model in self.models: model.sigma_sq = sigma_sq - def _update_Ws(self, ps, given_signatures): - for model, p, given_sigs in zip(self.models, ps, given_signatures): + def _update_Ws(self, given_signatures): + for model, given_sigs in zip(self.models, given_signatures): if given_sigs is None: - model._update_W(p) + model._update_W() def _update_ps(self): return [model._update_p() for model in self.models] def _objective_fun_u(self, u, index, aux_cols): + sigma_sq = self.models[0].sigma_sq s = -np.sum( [ - model._objective_fun_u(u, index, aux_col, add_penalty_u=False) + corrnmf._objective_fun_u( + u, + model.L, + model.alpha[index], + sigma_sq, + aux_col, + add_penalty_u=False, + ) for model, aux_col in zip(self.models, aux_cols) ] ) - s -= np.dot(u, u) / (2 * self.models[0].sigma_sq) + s -= np.dot(u, u) / (2 * sigma_sq) return -s def _gradient_u(self, u, index, s_grads): + sigma_sq = self.models[0].sigma_sq s = -np.sum( [ - model._gradient_u(u, index, s_grad, add_penalty_u=False) + corrnmf._gradient_u( + u, + model.L, + model.alpha[index], + sigma_sq, + s_grad, + add_penalty_u=False, + ) for model, s_grad in zip(self.models, s_grads) ], axis=0, ) - s -= u / self.models[0].sigma_sq + s -= u / sigma_sq return -s def _hessian_u(self, u, index, outer_prods_Ls): + sigma_sq = self.models[0].sigma_sq s = -np.sum( [ - model._hessian_u(u, index, outer_prods_L, add_penalty_u=False) + corrnmf._hessian_u( + u, + model.L, 
+ model.alpha[index], + sigma_sq, + outer_prods_L, + add_penalty_u=False, + ) for model, outer_prods_L in zip(self.models, outer_prods_Ls) ], axis=0, ) - s -= np.diag(np.full(self.dim_embeddings, 1 / self.models[0].sigma_sq)) + s -= np.diag(np.full(self.dim_embeddings, 1 / sigma_sq)) return -s @@ -394,7 +418,7 @@ def fit( ps = self._update_ps() self._update_LsU(ps, given_signature_embeddings, given_sample_embeddings) self._update_sigma_sq() - self._update_Ws(ps, given_signatures) + self._update_Ws(given_signatures) of_values.append(self.objective_function()) prev_sof_value = sof_values[-1] diff --git a/src/salamander/utils.py b/src/salamander/utils.py index 3ae4221..fcc1292 100644 --- a/src/salamander/utils.py +++ b/src/salamander/utils.py @@ -80,16 +80,16 @@ def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: """ V, D = X.shape WH = W @ H - kl_divergence = 0.0 + result = 0.0 for v in range(V): for d in range(D): if X[v, d] != 0: - kl_divergence += X[v, d] * np.log(X[v, d] / WH[v, d]) - kl_divergence -= X[v, d] - kl_divergence += WH[v, d] + result += X[v, d] * np.log(X[v, d] / WH[v, d]) + result -= X[v, d] + result += WH[v, d] - return kl_divergence + return result def samplewise_kl_divergence(X, W, H): @@ -111,24 +111,38 @@ def samplewise_kl_divergence(X, W, H): return errors -def poisson_llh(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: +@njit(fastmath=True) +def _poisson_llh_wo_factorial(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: """ The Poisson log-likelihood generalized to X, W and H having - non-negative real numbers. + non-negative real numbers without the summands involving the log-factorial + of elements of X. + Note: + scipy-special, which is required to computed the log-factorial, + is not supported by numba. 
""" - WH_data = W @ H - indices = WH_data.nonzero() - WH_data = WH_data[indices] - X_data = X[indices] + V, D = X.shape + WH = W @ H + result = 0.0 - s1 = np.dot(X_data, np.log(WH_data)) - # fast np.sum(W @ H) - s2 = -np.dot(np.sum(W, axis=0), np.sum(H, axis=1)) - s3 = -np.sum(gammaln(1 + X)) + for v in range(V): + for d in range(D): + if WH[v, d] != 0: + result += X[v, d] * np.log(WH[v, d]) + result -= WH[v, d] - llh = s1 + s2 + s3 + return result + + +def poisson_llh(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: + """ + The Poisson log-likelihood generalized to X, W and H having + non-negative real numbers. + """ + result = _poisson_llh_wo_factorial(X, W, H) + result -= np.sum(gammaln(1 + X)) - return llh + return result def normalize_WH(W, H): diff --git a/tests/test_corrnmf.py b/tests/test_corrnmf.py index 1e8446f..9a8899d 100644 --- a/tests/test_corrnmf.py +++ b/tests/test_corrnmf.py @@ -89,13 +89,8 @@ def surrogate_objective_init(path): @pytest.fixture -def W_updated_Lee(path): - return np.load(f"{path}_W_Lee_updated.npy") - - -@pytest.fixture -def W_updated_surrogate(path): - return np.load(f"{path}_W_surrogate_updated.npy") +def W_updated(path): + return np.load(f"{path}_W_updated.npy") @pytest.fixture @@ -129,15 +124,9 @@ def test_surrogate_objective_function( model_init._surrogate_objective_function(_p), surrogate_objective_init ) - def test_update_W_Lee(self, model_init, _p, W_updated_Lee): - model_init.update_W = "1999-Lee" - model_init._update_W(_p) - assert np.allclose(model_init.W, W_updated_Lee) - - def test_update_W_surrogate(self, model_init, _p, W_updated_surrogate): - model_init.update_W = "surrogate" - model_init._update_W(_p) - assert np.allclose(model_init.W, W_updated_surrogate) + def test_update_W_Lee(self, model_init, W_updated): + model_init._update_W() + assert np.allclose(model_init.W, W_updated) def test_update_alpha(self, model_init, alpha_updated): model_init._update_alpha() @@ -152,7 +141,11 @@ def test_update_L(self, 
model_init, _aux, L_updated): assert np.allclose(model_init.L, L_updated) def test_update_U(self, model_init, _aux, U_updated): + print("\n\n", "BEFORE UPDATE", model_init.U, "\n\n") model_init._update_U(_aux) + print("UPDATED", model_init.U, "\n\n") + print("SOLUTION", U_updated, "\n\n") + print("DIFFERENCE", np.sum(np.abs(model_init.U - U_updated))) assert np.allclose(model_init.U, U_updated) def test_update_sigma_sq(self, model_init, sigma_sq_updated): diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_surrogate_updated.npy b/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_surrogate_updated.npy deleted file mode 100644 index ef32f028a6e0f03bbba4ecb530dafb427492aa5b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 896 zcmbVK{WB8)7*=Rh`ts3nTsy^@rY{|mJY5QvkJ=@vtITXB+QhbIGmU+jGqPkON)cD; zDq1OtBz^DtP;!Z@<%;yBaZcae?fL_HetGYGet6&ad7tNCWOT&F*a;SV3lS-Skxow| z1rSMrOn(xENMdr*xM?)@E>2nk<1ZgVOG#(^;nP{PRK}lv{&HUmaS5416chg!lMBhb za*kR9w`seh>wc1hWZ~ zFPqdjdw6W6E;$b}7jFIS@EmN-x9y$8---K_Fh|3H9F@s4#ET|sP!+x=)IGic)Rpco z4+?2uF8$I)sMUc`_qKveQ-J3ds}D3~{F@&Wsv3VBtHjj%2JU-R76Oi5ro76`K%+;8 zU*N4oltqnZKBmag&J`}Kf0+KSPHeQe&^JPb!&B@0cul+8 z*L5cs>=QvRK2;sW%qmuqLvS%He@L|Uv+{6OCJ61Bnu{*#osq>V9YU->I@|Xvagm(T z^(L_rrg7a3cV-DzoKFe&&piNU+?V$K2^DD3*{_(_o`c?U<7;bIA-=18iWfNUNA38j z2>%x5KRkJpUr5dwt3B0GsseUmE81sIN4EkuG zP@8hNo>>^^oNa9BkV5>_vwW*vCe}A>CX$!QP^6t-a?xIj?vvlgn$GbtT*YYB1_>aE zR_jNPNb&Jp*0YEwGThiP?l-zO+k6ingto-u{+3*`|9mto?TZ-Ro`5N-hO%$dcvwwV zddi>kQKR?TZFPu_?Q=hWh-0$wi|A43ck6Xu^#Zf+GW-TzyoQPZ diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_Lee_updated.npy b/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_updated.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_Lee_updated.npy rename to tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_updated.npy diff --git 
a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_surrogate_updated.npy b/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_surrogate_updated.npy deleted file mode 100644 index 8ff7acda5f8d136983451bc97b96c62aeca5e415..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1664 zcmbV}Z#)zT9>x`=W+!ZAOId$TOkrzmVr}mC+fveOb#Cd{{OM)~DU&hAkR5}}7zUFW zGiFRGe`<`FRSH?MqVliS${e#5>TIc<4{Nk3HFa+9Uby%7ynCL{^Zh(89_fLD`vbpT zWRh&cAw@;AXo;lH9Z0@0pOM@hNHL7Wq(o}`VMbz9^nY@HDxDSmN6w0+CPe??H~M(F zJAAsy-GS@y|BB$|VM6{Mg-or9&L+j*GN#hEJn3Q*NpXp*>B4eq+{HO&#U z1^!ZKhvDRa7?qkYq`lTPV9?yYVyG4&BaADj8F_GF=4h{1J|7*7ZVv{1ib3e6+hS=$ zMVV=m=lI!tARf9etWxIU&-Y{&-QndhG$p*q_Lkw~@4M&MeO?Ey?A;lUS4i;eHy{6= zWzh-us7kuUwg2r`!=}??no_)6rT5TF)evRt!p*Q3BD$<6C)swxfFkXu(LZtVw%+NC zjIV}-r{rIj8^Uq7c+d7f&#^FQah>+6R~}BtWuu&s8!)>*KR!WB$CFjR9=}SvfC9EL z`MOsUUaVdc`P^QDPY5e5XD%PZed!_29e&;Lu6gvK$~*~|_t&iXh^RqM(GMx@R60su zx!7He6XP&p_ODrxj03An*JcP>z?-&nXk)-ZJa_67tz}>vyqtGS$$7@Y^n%cmkzESR zQQlHJJ!E5tvff7YTQ`2HEG%}cO+f9TBZ9@D?U*~cVWa0< zl3jLAwo(6qR@E=%_2vqcR8<57T$4cWu%|ZkeGcx((f1SX@PK&pi(`J|%7y<@ezg$u zR!MMU&g<59UDXB)S`KcVN$o!Hya;T@!i{4M={RtE@1^ji6)<8o zsQX(U54XJEY9X?vd#03WB~ZR^nz)Z4!%_KE@h)Q( zj^glDb#0-A&vz0cr7CY*9j1b zQ;xYpHLPYoF33HWjXo1ro7xwZfZf}7u9GKHu}XDd?;B8o^_h?rXS_KWwB=Ew?`j!v z<(d5ZO9j|QsgEt`DTiS3^+7Xg4tlhH(2x};1bsZ|mx5v*ss>XIJ|&ic)|fqSXidc> zR-1@TkI#YJfv)Z9$;ab)Ns7v4sr0vFNoK8>u_TssA7@&_7Rnp>c zH&4?(zqSi1`xV)bh7wQ`Yr!;qA;&o1UqwO`{;E#HZZ%0TI ze8}(!dbKDLx3$<)>MmS>0G}t;$x9NkKz&npAdH8KCf-=T>Ieq({I-qOB}DK09BGCq z1?$Jc?gU@^9t7TvXJh>6*lz7-#^BaLQGT-cP#Xt5{L9u6Iy=EGB#2aPo`dJLkqv`2 GYWO##=jD3< diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_Lee_updated.npy b/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_updated.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_Lee_updated.npy rename to tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_updated.npy diff --git 
a/tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_surrogate_updated.npy b/tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_surrogate_updated.npy deleted file mode 100644 index 57d3796a0bd73a0bae30419e6a4375f3a88756c4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1664 zcmbV||34E51IK+Q;weNqX_bt}@YvH4I-h6Rh*~RO(zN-u&3p?p?M7^En6O0TX_70Y zI|<=ZkMrVCN)ENFPFmluTwEt5`sQ5sFFfzx-tQmZ<=evO+nL&0BCUNyE>FlwCwdcz z8{;+*Nd#hCYI;U`OiEO0I+yo%_m4>y@;-NAd`ueevpafukO*Wa5+R51|6-i4D(zd7 zxF3?T?tcCLBB)jg7*u%{B25O#-AV;MTDOdKQzFqc^v|re$q-y=S*mC+-48PxN%LFD zeh4`}O*v6nNFLw3OVE~wIl0zO9bOu;{T6E1`{u%=?n0?`I|bLe2yey`SZFKJeX5Bn z!K&1%2%W)5cv?4DfALa*Z3nuzBhwr>)-=_6&dMMh^XXf0kO{gbq`q@y6e#~NQkAb} zpj}o%uKqR~g<^6m<9aF*G$u9PFDo(p`dG-55Eb4%&7>+f=E7d;wR%#g6#Pwh;>gPd zFrgnTw$$^$tDANIvp8FY9}5hr_1AOp_Nr;u=HPu;#*iL0b(3Pnf}?JliwI}tR=zc& z$3gG8U)ySfI2^FJJjuOALAd>rUJ0=PSG$f53Ae^!oOsRj*eN+~%c9vj+0oDrYCI%8 znGVu#)G{;og*f^D%ywkDcP8Q)XM+|+hXe79XxzXu z6{Zt*$9O40Xt1M|*rW(SuI_NXzjh~BrYWC{uTbGz8Nq$FhKe#bqZh;OGK7c}GS9s~ zfbmK=6ViQ-??%p4xhwZl#gbA z!bGIp^Zc{~oo5}_Ssf~cGbN+Gq9+6F#JyjdP_nSDdP}%Ch>8BrB0)Rr1O|=SPh?LE zU_iOv|2~T%%$XC5iz9MyhUfiV?{ER`mHoatF z=O8fOqM--@a2EG}%X}n7;G3FVJKk-^nlQC`+FOZ@@-^G~x;Mdgtsiap%XCy<^wUl# zk3nS5$u|GdXx!LQf8pUBDVpt;M@+rThHWA@kZK-^`S3Gr?ssws+v&lq`^gyhxy*l0 zOGD$IzckPPn-2TIMt2%j0Nqf^55DFM1lFu4^t`9y;*c9Xp+*R!fsDlGto=ChCcaAeYVro-DqSEZ<(_dpg+@18cStkWq`x0N#1s>jr#s{ygC6Y+Mv}O z3NX;S)o^M>62wn09U@68P+`%>YT@qy!{YG3&ZMuQI&I)jo!JAI#m57K22$bZaIS|o zAVaafR#TiuIZhNj9@f{yAmG8DIws$^gLW=^>Sb3BcF!vcZW;K(PN_IH=9P>8l#D-k zT_FIEJ0r2UvIjY4`-Y}(QlTumVC55(j|KgOt&25sw0^@7w4P$Z@^;}@k62lFVX!nl zMrSP?e*Kr(y?GH_lNKf|Z^vS%6Y-2CO$ev3WtCKp2u_RS{O2cRSQ4%uTE>k5GeSIZ zcaac=Wa^Sxh7hM8O=x<0LRKQw8a|J%ord`2yIxxy z1kgL)ThX*38m_;+s100Ag(T~p(ZYjhREO5t=oz|UOgG9~)tH4k7x&qbx#gHJA1(>4 z3P$oT1h`z?gJN;ZQAI%#mbG_od!ikSJmQolz$^u%Jm#}{z9GU&*WEuZ-U|!5mHou! 
VV5oaPxpw(SBVq3fO|~=-e*sex%x(Yx diff --git a/tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_Lee_updated.npy b/tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_updated.npy similarity index 100% rename from tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_Lee_updated.npy rename to tests/test_data/nmf_framework/multimodal_corrnmf/model0_W_updated.npy diff --git a/tests/test_data/nmf_framework/multimodal_corrnmf/model1_W_surrogate_updated.npy b/tests/test_data/nmf_framework/multimodal_corrnmf/model1_W_surrogate_updated.npy deleted file mode 100644 index aad58f25db882e3c69dbfda99fc57e4e33e23092..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2432 zcmbW2`9Bm01IAHurIw^n{NcW=rk6V(?2UC9VU>{Nt<%d^4Xs}07Ph*c_q~iZ8(OX8HQ}+uw zI2>nTJSmU_VJpq94@c7AG9^O3W*r8WsqLy+FbCQrWw)tz2}obg`l)o#3DU}agM86M zT-3XG;IgC#_UUUY2*yQY@>k|{n<6%@o;mGyh8K->^0(5{XOr+xNHejN8UnGr?(zN; ziSSweVh|Tbf%;T?hAJ->RaAcooqBIrS2Rw#>~cqa$KJp6Y>uI>)nS@+CkXustDnEk zxM5aOy=Y7+64lF9+$ZdG)CQD9`Z^?m{cLak+Vn}Nt5JPSJg=bh<@VAE4g*%F&&pL; zhhS+*^?cd~FX+|GoKY&%f!zn)72LlJ+k?fwN&S(cx#APbnA3V zdSdEva95vw4D8=Wl+q*OP_d@IEN(#qd4FQJ`jrq+aZnorL}hd&klgENd2 z9v4Yk>)G#)$HD98l(UA}VUSBT$am8O9#g%>`70EFXUPG9Fl9p@#p0{Il6N?N;l|rH0X1|i%mGu8C+WR4q^|yZ|{v+%9?@ z2L&w?*%OZGXwy6@&DhL=gZC4?S#u6}wjU34T5xe+k3yA6<$x^S^;%Dbf%<#J%<`A6 zkeP@G%o)vwkax?;x#O46{hYf#_ihYG%LXm>Ct{KQk8`qrYYZZ~>TX18T*G7sm#?<0 z6DtqoeT^bpp-G_?yGw;)ntf}zxU3SZ4cdPRzfD2on?~V(pI?NLcDDJ<_FP!2R90?? z$b`6no!q0zCJd`&#xV_xu~t7_vTk2Bgbi|(RE5vs*0CV2VgVa-X^uCY5@^tV8F|tu zJP6@GHSbIBr+~l8kKf-i6#uxTjMnw~fw##_G}n#}qqnch;QT6;u)A~~+?379wDYEg0_H)9E zjhmS0mZ$&9Iw4 za)k_~-W8*^b5v|UDCwFh>Vk{|t4CfJc))yq_q(EeBH}YW1l)NJ$ocux>q-qT+$ro? 
zl%&$}uUf~>r=gapdG_6Ri=G~wvOG?Nm9gMi`}pXjq9@uDe=1rsi@^AojgN)}I3S8t zP<%v!5UFv>XecTOJmcNk)@}xfG8i0sVHAs}oq`(R_k)Ku^=+Cn3ELyS&$|rSAS_^S z)^Eqd5Y$3AbhP`&o^}6{wR)U@#)XRuJ~q*a)tQ;TelH#pUCy`tkLSY9-^(F>i4D&D zm#ny18jK5sWw-psKt<@#*;!>O_*AS$c$c{-Rv#HSWfqR8;iiloF%h_KDa%1A^pF+E-*a#cGB**fmC)<5!{ma#G#lwo2aSd?`-nFEip zySl1>hJz#^>&KP#N7d6Ehc{XKBW1~T*R%EE=v5XP-!~eLsks*i+AHZe%UatO85xWo z!A;K#grbqz$z-MN`89tUwk~|*xs`>Ya@`5O7laTP*D*=527YMq zPSZ&?C*jPi?*UiLQ}D$h;hW)Y8cN6skrqo+vpbpRl{f%c^d*6 z?+a2oJTE}VmPasV#A952gm~%44^B*#i;C&`WA=Ngeb*LmSO$JRDns+YM!`(iVC&N$ z^h{q{7h;ZYI;zqf-vHzthAeq!xJv6%9S+uooaU^82W&* z-%6l7_7JE_+_7b6Zwx={c5ArniS#F#xj7L*fRnA(yo5dG+l5`U0#4v_a61HH0ryjr zT+wVM`fFTgG{w%tbERdeaw{7(@q71NHKt-=p|hdxB^}zmM-*&CeUZ1;|A$nyBU*@- z>OGlG;LfzTyD8WrCANg3cGDUBTQ>|k2s%R1;$aZm<0P)=&u=TNp`+x}q3F$IADC)n z*WNfwM{4*Ar-*b4ZE>}Q#P2~U+vaBQP|_8h;s$qov}uqwuD5F#3InnAhU35^8rb7g zwGvt`Xuh1(AbG?M+g{1M&=B>*mjS~AOTT;L;Y3$;avKAL Date: Thu, 12 Oct 2023 17:29:26 -0400 Subject: [PATCH 03/13] restructure numba code - corrnmf numba code is now in a separate file called _ultils_corrnmf because MultimodalCorrNMF also uses it - the code for the objective function, gradient and hessian of a signature/sample embedding in CorrNMF was unified - the update of the model variance sigma^2 in CorrNMF was simplified --- .pre-commit-config.yaml | 1 + .../nmf_framework/_utils_corrnmf.py | 252 ++++++++++++++++++ src/salamander/nmf_framework/corrnmf.py | 142 ---------- src/salamander/nmf_framework/corrnmf_det.py | 38 ++- .../nmf_framework/multimodal_corrnmf.py | 23 +- 5 files changed, 287 insertions(+), 169 deletions(-) create mode 100644 src/salamander/nmf_framework/_utils_corrnmf.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 50e6ca5..bf8a5f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,7 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: mixed-line-ending + - id: trailing-whitespace - repo: https://github.com/python-poetry/poetry rev: 1.6.1 hooks: diff --git 
a/src/salamander/nmf_framework/_utils_corrnmf.py b/src/salamander/nmf_framework/_utils_corrnmf.py new file mode 100644 index 0000000..910d2e9 --- /dev/null +++ b/src/salamander/nmf_framework/_utils_corrnmf.py @@ -0,0 +1,252 @@ +import numpy as np +from numba import njit + +EPSILON = np.finfo(np.float32).eps + + +@njit +def update_alpha(X: np.ndarray, L: np.ndarray, U: np.ndarray) -> np.ndarray: + """ + Compute the new sample biases alpha according to the update rule of CorrNMF. + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + a: asdf + asdf + + L : np.ndarray of shape (dim_embeddings, n_signatures) + signature embeddings + + U : np.ndarray of shape (dim_embeddings, n_samples) + sample embeddings + + Returns + ------- + alpha : np.ndarray of shape (n_samples,) + The new sample biases alpha + """ + exp_LTU = np.exp(L.T @ U) + alpha = np.log(np.sum(X, axis=0)) - np.log(np.sum(exp_LTU, axis=0)) + return alpha + + +@njit +def update_W(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray: + """ + Compute the new signatures according to the update rule of NMF with + the generalized Kullback-Leibler divergence. An additional normalization + step is performed to guarantee that the signatures are probability distributions + over the mutation types. + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + W : np.ndarray of shape (n_features, n_signatures) + updated signature matrix + + References + ---------- + D. Lee, H. 
Seung: Algorithms for Non-negative Matrix Factorization + - Advances in neural information processing systems, 2000 + https://proceedings.neurips.cc/paper_files/paper/2000/file/f9d1152547c0bde01830b7e8bd60024c-Paper.pdf + """ + W *= (X / (W @ H)) @ H.T + W /= np.sum(W, axis=0) + W = W.clip(EPSILON) + return W + + +@njit +def update_p_unnormalized(W: np.ndarray, H: np.ndarray) -> np.ndarray: + """ + Compute the new auxiliary parameters according to the update rule of CorrNMF. + The normalization per mutation type and sample is not performed yet. + + Parameters + ---------- + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + p: np.ndarray of shape (n_features, n_signatures, n_samples) + """ + n_features, n_signatures = W.shape + n_samples = H.shape[1] + p = np.zeros((n_features, n_signatures, n_samples)) + + for v in range(n_features): + for k in range(n_signatures): + for d in range(n_samples): + p[v, k, d] = W[v, k] * H[k, d] + + return p + + +@njit +def objective_function_embedding( + embedding, embeddings_other, alpha, sigma_sq, aux_vec, add_penalty=True +): + r""" + The objective function of a signature or sample embedding in CorrNMF. + + Parameters + ---------- + embedding : np.ndarray of shape (dim_embeddings,) + The signature or sample embedding + + embeddings_other : np.ndarray of shape (dim_embeddings, n_samples | n_signatures) + If 'embedding' is a signature embedding, 'embeddings_other' are + all sample embeddings. If 'embedding' is a sample embedding, + 'embeddings_other' are all signature embeddings. + + alpha : float | np.ndarray of shape (n_samples,) + If 'embedding' is a signature embedding, 'alpha' are + all sample biases. If 'embedding' is a sample embedding, + 'alpha' is the bias of the corresponding sample. 
+ + sigma_sq : float + model variance + + aux_vec : np.ndarray of shape (n_samples | n_signatures,) + A row or column of + aux[k, d] = \sum_v X_vd * p_vkd, + where X is the data matrix and p are the auxiliary parameters of CorrNMF. + If 'embedding' is a signature embedding, the corresponding row is provided. + If 'embedding' is a sample embedding, the corresponding column is provided. + + add_penalty : bool, default=True + Set to True, the norm of the embedding will be penalized. + This argument is useful for the implementation of multimodal CorrNMF. + """ + n_embeddings_other = embeddings_other.shape[1] + of_value = 0.0 + scalar_products = embeddings_other.T.dot(embedding) + + # aux_vec not necessarily contiguous: + # np.dot(scalar_products, aux_vec) doesn't work + for i in range(n_embeddings_other): + of_value += scalar_products[i] * aux_vec[i] + + # works for alpha being a scalar or vector + of_value -= np.sum(np.exp(alpha + scalar_products)) + + if add_penalty: + of_value -= np.dot(embedding, embedding) / (2 * sigma_sq) + + return -of_value + + +@njit +def gradient_embedding( + embedding, embeddings_other, alpha, sigma_sq, summand_grad, add_penalty=True +): + r""" + The gradient of the objective function w.r.t. a signature or sample embedding + in CorrNMF. + + Parameters + ---------- + embedding : np.ndarray of shape (dim_embeddings,) + The signature or sample embedding + + embeddings_other : np.ndarray of shape (dim_embeddings, n_samples | n_signatures) + If 'embedding' is a signature embedding, 'embeddings_other' are + all sample embeddings. If 'embedding' is a sample embedding, + 'embeddings_other' are all signature embeddings. + + alpha : float | np.ndarray of shape (n_samples,) + If 'embedding' is a signature embedding, 'alpha' are + all sample biases. If 'embedding' is a sample embedding, + 'alpha' is the bias of the corresponding sample. 
+ + sigma_sq : float + model variance + + summand_grad : np.ndarray of shape (dim_embeddings,) + A signature/sample-independent summand of the gradient. + + add_penalty : bool, default=True + Set to True, the norm of the embedding will be penalized. + This argument is useful for the implementation of multimodal CorrNMF. + """ + scalar_products = embeddings_other.T.dot(embedding) + gradient = -np.sum(np.exp(alpha + scalar_products) * embeddings_other, axis=1) + gradient += summand_grad + + if add_penalty: + gradient -= embedding / sigma_sq + + return -gradient + + +@njit +def hessian_embedding( + embedding, + embeddings_other, + alpha, + sigma_sq, + outer_prods_embeddings_other, + add_penalty=True, +): + r""" + The Hessian of the objective function w.r.t. a signature or sample embedding + in CorrNMF. + + Parameters + ---------- + embedding : np.ndarray of shape (dim_embeddings,) + The signature or sample embedding + + embeddings_other : np.ndarray of shape (dim_embeddings, n_samples | n_signatures) + If 'embedding' is a signature embedding, 'embeddings_other' are + all sample embeddings. If 'embedding' is a sample embedding, + 'embeddings_other' are all signature embeddings. + + alpha : float | np.ndarray of shape (n_samples,) + If 'embedding' is a signature embedding, 'alpha' are + all sample biases. If 'embedding' is a sample embedding, + 'alpha' is the bias of the corresponding sample. + + sigma_sq : float + model variance + + outer_prods_embeddings_other : np.ndarray of shape + (n_samples | n_signatures, dim_embeddings, dim_embeddings) + Indexed as [i, m1, m2] with i running over 'embeddings_other'. + Judging by the name, entry i is the outer product of the i-th column + of 'embeddings_other' with itself, precomputed by the caller. + Entry i is scaled by exp(alpha + embeddings_other.T @ embedding)[i]. + + add_penalty : bool, default=True + Set to True, the norm of the embedding will be penalized. + This argument is useful for the implementation of multimodal CorrNMF. 
+ """ + dim_embeddings, n_embeddings_other = embeddings_other.shape + scalings = np.exp(alpha + embeddings_other.T.dot(embedding)) + hessian = np.zeros((dim_embeddings, dim_embeddings)) + + for m1 in range(dim_embeddings): + for m2 in range(dim_embeddings): + for i in range(n_embeddings_other): + hessian[m1, m2] -= scalings[i] * outer_prods_embeddings_other[i, m1, m2] + if add_penalty and m1 == m2: + hessian[m1, m2] -= 1 / sigma_sq + + return -hessian diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index dec3a6a..fecc9e4 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -from numba import njit from scipy.spatial.distance import squareform from scipy.special import gammaln @@ -33,147 +32,6 @@ EPSILON = np.finfo(np.float32).eps -@njit -def update_alpha(X, L, U): - exp_LTU = np.exp(L.T @ U) - alpha = np.log(np.sum(X, axis=0)) - np.log(np.sum(exp_LTU, axis=0)) - return alpha - - -@njit -def update_sigma_sq(L, U): - dim_embeddings, n_signatures = L.shape - n_samples = U.shape[1] - sum_norm_sigs = np.sum(L**2) - sum_norm_samples = np.sum(U**2) - sigma_sq = (sum_norm_sigs + sum_norm_samples) / ( - dim_embeddings * (n_signatures + n_samples) - ) - return sigma_sq - - -@njit -def update_W(X, W, H): - W *= (X / (W @ H)) @ H.T - W /= np.sum(W, axis=0) - W = W.clip(EPSILON) - return W - - -@njit -def update_p_unnormalized(W, H): - n_features, n_signatures = W.shape - n_samples = H.shape[1] - p = np.zeros((n_features, n_signatures, n_samples)) - - for v in range(n_features): - for k in range(n_signatures): - for d in range(n_samples): - p[v, k, d] = W[v, k] * H[k, d] - - return p - - -@njit -def _objective_fun_l(l, U, alpha, sigma_sq, aux_row): - UTl = U.T.dot(l) - s = np.dot(aux_row, UTl) - s -= np.sum(np.exp(alpha + UTl)) - s -= np.dot(l, l) / (2 * sigma_sq) - - return -s - - -@njit -def 
_gradient_l(l, U, alpha, sigma_sq, s_grad): - s = -np.sum(np.exp(alpha + U.T.dot(l)) * U, axis=1) - s -= l / sigma_sq - - return -(s_grad + s) - - -@njit -def _hessian_l(l, U, alpha, sigma_sq, outer_prods_U): - dim_embeddings, n_samples = U.shape - scalings = np.exp(alpha + U.T.dot(l)) - s = np.zeros((dim_embeddings, dim_embeddings)) - - for m1 in range(dim_embeddings): - for m2 in range(dim_embeddings): - for d in range(n_samples): - s[m1, m2] -= scalings[d] * outer_prods_U[d, m1, m2] - if m1 == m2: - s[m1, m2] -= 1 / sigma_sq - - return -s - - -@njit -def _objective_fun_u( - u: np.ndarray, - L: np.ndarray, - alpha: float, - sigma_sq: float, - aux_col: np.ndarray, - add_penalty_u=True, -): - n_signatures = L.shape[1] - LTu = L.T.dot(u) - s = 0.0 - - # aux_col not contiguous: - # s = np.dot(aux_col, LTu) doesn't work - for k in range(n_signatures): - s += aux_col[k] * LTu[k] - - s -= np.sum(np.exp(alpha + LTu)) - - if add_penalty_u: - s -= np.dot(u, u) / (2 * sigma_sq) - - return -s - - -@njit -def _gradient_u( - u: np.ndarray, - L: np.ndarray, - alpha: float, - sigma_sq: float, - s_grad: np.ndarray, - add_penalty_u=True, -): - s = -np.exp(alpha) * np.sum(np.exp(L.T.dot(u)) * L, axis=1) - - if add_penalty_u: - s -= u / sigma_sq - - return -(s_grad + s) - - -@njit -def _hessian_u( - u: np.ndarray, - L: np.ndarray, - alpha: float, - sigma_sq: float, - outer_prods_L: np.ndarray, - add_penalty_u=True, -): - dim_embeddings, n_signatures = L.shape - scalings = np.exp(alpha + L.T.dot(u)) - s = np.zeros((dim_embeddings, dim_embeddings)) - - for m1 in range(dim_embeddings): - for m2 in range(dim_embeddings): - for k in range(n_signatures): - s[m1, m2] -= scalings[k] * outer_prods_L[k, m1, m2] - if add_penalty_u and m1 == m2: - s[m1, m2] -= 1 / sigma_sq - - return -s - - class CorrNMF(SignatureNMF): r""" The abstract class CorrNMF unifies the structure of deterministic and diff --git a/src/salamander/nmf_framework/corrnmf_det.py b/src/salamander/nmf_framework/corrnmf_det.py 
index 00905fa..21bcc5f 100644 --- a/src/salamander/nmf_framework/corrnmf_det.py +++ b/src/salamander/nmf_framework/corrnmf_det.py @@ -6,12 +6,13 @@ import pandas as pd from scipy import optimize -from . import corrnmf +from . import _utils_corrnmf +from .corrnmf import CorrNMF EPSILON = np.finfo(np.float32).eps -class CorrNMFDet(corrnmf.CorrNMF): +class CorrNMFDet(CorrNMF): r""" The CorrNMFDet class implements the deterministic batch version of a variant of the correlated NMF (CorrNMF) algorithm devolped in @@ -47,34 +48,37 @@ class CorrNMFDet(corrnmf.CorrNMF): """ def _update_alpha(self): - self.alpha = corrnmf.update_alpha(self.X, self.L, self.U) + self.alpha = _utils_corrnmf.update_alpha(self.X, self.L, self.U) def _update_sigma_sq(self): - self.sigma_sq = corrnmf.update_sigma_sq(self.L, self.U) + embeddings = np.concatenate([self.L, self.U], axis=1) + self.sigma_sq = np.mean(embeddings**2) self.sigma_sq = np.clip(self.sigma_sq, EPSILON, None) def _update_W(self): - self.W = corrnmf.update_W(self.X, self.W, self.exposures.values) + self.W = _utils_corrnmf.update_W(self.X, self.W, self.exposures.values) def _update_p(self): - p = corrnmf.update_p_unnormalized(self.W, self.exposures.values) + p = _utils_corrnmf.update_p_unnormalized(self.W, self.exposures.values) p /= np.sum(p, axis=1, keepdims=True) p = p.clip(EPSILON) return p def _update_l(self, index, aux_row, outer_prods_U): def objective_fun(l): - return corrnmf._objective_fun_l( + return _utils_corrnmf.objective_function_embedding( l, self.U, self.alpha, self.sigma_sq, aux_row ) - s_grad = np.sum(aux_row * self.U, axis=1) + summand_grad = np.sum(aux_row * self.U, axis=1) def gradient(l): - return corrnmf._gradient_l(l, self.U, self.alpha, self.sigma_sq, s_grad) + return _utils_corrnmf.gradient_embedding( + l, self.U, self.alpha, self.sigma_sq, summand_grad + ) def hessian(l): - return corrnmf._hessian_l( + return _utils_corrnmf.hessian_embedding( l, self.U, self.alpha, self.sigma_sq, outer_prods_U ) @@ 
-109,15 +113,21 @@ def _update_u(self, index, aux_col, outer_prods_L): alpha = self.alpha[index] def objective_fun(u): - return corrnmf._objective_fun_u(u, self.L, alpha, self.sigma_sq, aux_col) + return _utils_corrnmf.objective_function_embedding( + u, self.L, alpha, self.sigma_sq, aux_col + ) - s_grad = np.sum(aux_col * self.L, axis=1) + summand_grad = np.sum(aux_col * self.L, axis=1) def gradient(u): - return corrnmf._gradient_u(u, self.L, alpha, self.sigma_sq, s_grad) + return _utils_corrnmf.gradient_embedding( + u, self.L, alpha, self.sigma_sq, summand_grad + ) def hessian(u): - return corrnmf._hessian_u(u, self.L, alpha, self.sigma_sq, outer_prods_L) + return _utils_corrnmf.hessian_embedding( + u, self.L, alpha, self.sigma_sq, outer_prods_L + ) u = optimize.minimize( fun=objective_fun, diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index ed89089..deb3114 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -29,7 +29,7 @@ umap_2d, ) from ..utils import type_checker, value_checker -from . import corrnmf +from . 
import _utils_corrnmf from .corrnmf_det import CorrNMFDet EPSILON = np.finfo(np.float32).eps @@ -186,12 +186,9 @@ def _update_alphas(self): model._update_alpha() def _update_sigma_sq(self): - sum_norm_sigs = np.sum([np.sum(model.L**2) for model in self.models]) - sum_norm_samples = np.sum(self.models[0].U ** 2) - - sigma_sq = (sum_norm_sigs + sum_norm_samples) / ( - self.dim_embeddings * (np.sum(self.ns_signatures) + self.n_samples) - ) + embeddings = np.concatenate([model.L for model in self.models], axis=1) + embeddings = np.concatenate([embeddings, self.models[0].U], axis=1) + sigma_sq = np.mean(embeddings**2) sigma_sq = np.clip(sigma_sq, EPSILON, None) for model in self.models: @@ -209,13 +206,13 @@ def _objective_fun_u(self, u, index, aux_cols): sigma_sq = self.models[0].sigma_sq s = -np.sum( [ - corrnmf._objective_fun_u( + _utils_corrnmf.objective_function_embedding( u, model.L, model.alpha[index], sigma_sq, aux_col, - add_penalty_u=False, + add_penalty=False, ) for model, aux_col in zip(self.models, aux_cols) ] @@ -228,13 +225,13 @@ def _gradient_u(self, u, index, s_grads): sigma_sq = self.models[0].sigma_sq s = -np.sum( [ - corrnmf._gradient_u( + _utils_corrnmf.gradient_embedding( u, model.L, model.alpha[index], sigma_sq, s_grad, - add_penalty_u=False, + add_penalty=False, ) for model, s_grad in zip(self.models, s_grads) ], @@ -248,13 +245,13 @@ def _hessian_u(self, u, index, outer_prods_Ls): sigma_sq = self.models[0].sigma_sq s = -np.sum( [ - corrnmf._hessian_u( + _utils_corrnmf.hessian_embedding( u, model.L, model.alpha[index], sigma_sq, outer_prods_L, - add_penalty_u=False, + add_penalty=False, ) for model, outer_prods_L in zip(self.models, outer_prods_Ls) ], From fcad1fd452263c361c61b73dcb3bb645f615058c Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Thu, 12 Oct 2023 17:41:25 -0400 Subject: [PATCH 04/13] remove hyperparameter selectors Selecting a model hyperparameter in an automatic way should not be part of Salamander --- 
src/salamander/nmf_framework/corrnmf.py | 132 ----------------- src/salamander/nmf_framework/mvnmf.py | 183 +----------------------- src/salamander/utils.py | 46 ------ 3 files changed, 2 insertions(+), 359 deletions(-) diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index fecc9e4..ae3b2d4 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -1,8 +1,5 @@ -import multiprocessing -import os import warnings from abc import abstractmethod -from copy import deepcopy import matplotlib.pyplot as plt import numpy as np @@ -597,132 +594,3 @@ def plot_embeddings( plt.savefig(outfile, bbox_inches="tight") return ax - - -class CorrNMFHyperparameterSelector: - """ - The embedding dimension of samples and signatures is - the only hyperparameter of correlated NMF. - This class implements methods to select the "optimal" embedding dimension. - The framework of hyperparameter selectors allows to implement - a denovo signature analysis pipeline in an NMF algorithm agnostic manner: - A dictionary can be used to set all hyperparameters, - irrespective of the NMF algorithm and its arbitrary number of hyperparameters. - """ - - def __init__(self, method="unbiased", method_kwargs=None): - value_checker("method", method, ["BIC", "proportional", "unbiased"]) - self.method = method - self.method_kwargs = {} if method_kwargs is None else method_kwargs.copy() - - # initialize selection dependent attributes - self.corrnmf_algorithm = None - self.dims_embeddings = np.empty(0, dtype=int) - self.data = None - self.given_signatures = None - self.init_kwargs = None - self.verbose = 0 - self.models = [] - - def _job_select_bic(self, dim_embeddings): - """ - Apply CorrNMF for a single embedding dimension. 
- """ - model = deepcopy(self.corrnmf_algorithm) - model.dim_embeddings = dim_embeddings - model.fit( - data=self.data, - given_signatures=self.given_signatures, - init_kwargs=self.init_kwargs, - verbose=0, - ) - - if self.verbose: - print(f"CorrNMF with dim_embeddings = {dim_embeddings} finished.") - - return model - - def select_bic(self, ncpu=None): - """ - Select the best embedding dimension based - on the Bayesian Information Criterion (BIC). - """ - if ncpu is None: - ncpu = os.cpu_count() - - if ncpu > 1: - workers = multiprocessing.Pool(ncpu) - models = workers.map(self._job_select_bic, self.dims_embeddings) - workers.close() - workers.join() - - else: - models = [ - self._job_select_bic(dim_embeddings) - for dim_embeddings in self.dims_embeddings - ] - - self.models = models - bics = np.array([model.bic for model in models]) - best_index = np.argmin(bics) - best_model = models[best_index] - - return best_model.dim_embeddings - - def select_proportional(self, proportion=0.75): - """ - The embedding dimension is set to a proportion of the number of signatures. - """ - n_signatures = self.corrnmf_algorithm.n_signatures - dim_embeddings = int(proportion * n_signatures) if n_signatures > 1 else 1 - - return dim_embeddings - - def select_unbiased(self, normalized=True): - """ - The embedding dimension is set to the number of signatures - if 'normalized' is false. If 'normalized' is true, the embedding - dimension is set to the number of signatures minus one. - - Input: - ------ - normalized: bool - If the input count matrix will be normalized, the number of free - parameters for each sample exposure is 'n_signatures - 1'. - Without the normalization, there are 'n_signatures' many free parameters. 
- """ - n_signatures = self.corrnmf_algorithm.n_signatures - - if not normalized: - return n_signatures - - return max(1, n_signatures - 1) - - def select( - self, - corrnmf_algorithm, - data: pd.DataFrame, - given_signatures=None, - init_kwargs=None, - ncpu=None, - verbose=0, - ): - self.corrnmf_algorithm = corrnmf_algorithm - self.dims_embeddings = np.arange(1, corrnmf_algorithm.n_signatures + 1) - self.data = data - self.given_signatures = given_signatures - self.init_kwargs = init_kwargs - self.verbose = verbose - - if self.method == "BIC": - dim_embeddings = self.select_bic(ncpu=ncpu, **self.method_kwargs) - - elif self.method == "proportional": - dim_embeddings = self.select_proportional(**self.method_kwargs) - - elif self.method == "unbiased": - dim_embeddings = self.select_unbiased(**self.method_kwargs) - - hyperparameters = {"dim_embeddings": dim_embeddings} - - return hyperparameters diff --git a/src/salamander/nmf_framework/mvnmf.py b/src/salamander/nmf_framework/mvnmf.py index b96f7e2..ddc8ecd 100644 --- a/src/salamander/nmf_framework/mvnmf.py +++ b/src/salamander/nmf_framework/mvnmf.py @@ -1,19 +1,7 @@ -import multiprocessing -import os -import warnings -from copy import deepcopy - import numpy as np import pandas as pd -from scipy import stats - -from ..utils import ( - differential_tail_test, - kl_divergence, - normalize_WH, - poisson_llh, - samplewise_kl_divergence, -) + +from ..utils import kl_divergence, normalize_WH, poisson_llh, samplewise_kl_divergence from .nmf import NMF EPSILON = np.finfo(np.float32).eps @@ -237,170 +225,3 @@ def fit( self.history["objective_function"] = of_values[1:] return self - - -class MvNMFHyperparameterSelector: - """ - The volume-regularization weight is the only hyperparameter of mvNMF. - This class implements methods to select the "optimal" volume-regularization weight. 
- The framework of hyperparameter selectors allow to implement - a denovo signature analysis in an NMF algorithm agnostic manner: A dictionary can - be used to set all hyperparameters, irrespective of the NMF algorithm - and its arbitrary number of hyperparameters. - - The best model is defined as the model with the strongest volume regularization - such that the samplewise reconstruction errors are still (approximately) identically - distributed to the model with the lowest volume regularization. - The distributions are compared with the Mann-Whitney-U test. - """ - - # fmt: off - default_lambda_tildes = ( - 1e-10, 2e-10, 5e-10, 1e-9, 2e-9, 5e-9, - 1e-8, 2e-8, 5e-8, 1e-7, 2e-7, 5e-7, - 1e-6, 2e-6, 5e-6, 1e-5, 2e-5, 5e-5, - 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, - 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 5e-1, - 1.0, 2.0, - ) - # fmt: on - - def __init__(self, lambda_tildes=default_lambda_tildes, pthresh=0.05): - """ - Inputs: - ------ - lambda_tildes: tuple - An ordered list of possible volume-regularization parameters. - - pthresh: float - The distribution of samplewise reconstruction errors between two - mvNMF fitted models is considered different if the Mann-Whitney-U - test pvalue of comparing their, and their tail, distribution is lower - than pthresh. - """ - self.lambda_tildes = lambda_tildes - self.pthresh = pthresh - - # initialize selection dependent attributes - self.mvnmf_algorithm = None - self.data = None - self.given_signatures = None - self.init_kwargs = None - self.verbose = 0 - - def _job(self, lambda_tilde): - """ - Apply mvNMF for a single lambda_tilde volume regularization. 
- """ - model = deepcopy(self.mvnmf_algorithm) - model.lambda_tilde = lambda_tilde - model.fit( - data=self.data, - given_signatures=self.given_signatures, - init_kwargs=self.init_kwargs, - verbose=0, - ) - - if self.verbose: - print(f"mvNMF with lambda_tilde = {lambda_tilde:.2E} finished.") - - return model - - def _indicate_good_models(self, rerrors_base, rerrors_rest): - """ - Compare the distributions of the samplewise baseline reconstruction errors - and the samplewise model reconstruction errors. - - Output: - ------- - indicators: np.ndarray - One-dimensional boolean array indicating all models having - samplewise reconstruction errors similar to the baseline errors. - """ - n_models_rest = len(rerrors_rest) - - pvalue_indicators = np.empty(n_models_rest, dtype=bool) - pvalue_tail_indicators = np.empty(n_models_rest, dtype=bool) - - for i, rerrors in enumerate(rerrors_rest): - # Turn everything non-negative for the differential tail test. - # Note: The Mann-Whitney U test statistic is shift-invariant - shift = np.min([rerrors_base, rerrors]) - re_base = rerrors_base - shift - re = rerrors - shift - - pvalue = stats.mannwhitneyu(re_base, re, alternative="less")[1] - pvalue_indicators[i] = pvalue > self.pthresh - - pvalue_tail = differential_tail_test( - re_base, re, percentile=90, alternative="less" - )[1] - pvalue_tail_indicators[i] = pvalue_tail > self.pthresh - - indicators = pvalue_indicators & pvalue_tail_indicators - - return indicators - - def _get_best_lambda_tilde(self, indicators): - # np.argmin returns the first "bad" model index - # Note: self.lambda_tildes[index] will be a "good" model because the number of - # possible volume regularizations and the length of indicators differs by one. - index = np.argmin(indicators) - - if all(indicators): - index = len(indicators) - warnings.warn( - "For all lambda_tilde, the sample-wise reconstruction errors are " - "comparable to the reconstruction errors with no regularization. 
" - "The model with the strongest volume regularization is selected.", - UserWarning, - ) - - if index == 0: - warnings.warn( - "The smallest lambda_tilde is selected. The optimal lambda_tilde " - "might be smaller. We suggest to extend the grid to smaller " - "lambda_tilde values to validate.", - UserWarning, - ) - - best_lambda_tilde = self.lambda_tildes[index] - - return best_lambda_tilde - - def select( - self, - mvnmf_algorithm, - data: pd.DataFrame, - given_signatures=None, - init_kwargs=None, - ncpu=1, - verbose=0, - ): - self.mvnmf_algorithm = mvnmf_algorithm - self.data = data - self.given_signatures = given_signatures - self.init_kwargs = init_kwargs - self.verbose = verbose - - if ncpu is None: - ncpu = os.cpu_count() - - workers = multiprocessing.Pool(ncpu) - models = workers.map(self._job, self.lambda_tildes) - workers.close() - workers.join() - - samplewise_rerrors_all = np.array( - [model.samplewise_reconstruction_error for model in models] - ) - rerrors_base, rerrors_rest = ( - samplewise_rerrors_all[0], - samplewise_rerrors_all[1:], - ) - - indicators = self._indicate_good_models(rerrors_base, rerrors_rest) - best_lambda_tilde = self._get_best_lambda_tilde(indicators) - hyperparameters = {"lambda_tilde": best_lambda_tilde} - - return hyperparameters diff --git a/src/salamander/utils.py b/src/salamander/utils.py index fcc1292..96e9b46 100644 --- a/src/salamander/utils.py +++ b/src/salamander/utils.py @@ -1,11 +1,8 @@ -import warnings - import numpy as np import pandas as pd from numba import njit from scipy.optimize import linear_sum_assignment from scipy.special import gammaln -from scipy.stats import mannwhitneyu from sklearn.metrics import pairwise_distances EPSILON = np.finfo(np.float32).eps @@ -182,46 +179,3 @@ def match_signatures_pair( reordered_indices = linear_sum_assignment(pdist)[1] return reordered_indices - - -def differential_tail_test(a, b, percentile=90, alternative="two-sided"): - """ - Test if distribution tails are different 
(pubmed: 18655712) - - Input - ------ - a, b : array-like - must be positive. - - percentile : float - Percentile threshold above which data points are considered tails. - - alternative : {'two-sided', 'less', 'greater'} - Defines the alternative hypothesis. For example, when set to 'greater', - the alternative hypothesis is that the tail of a is greater than the tail - of b. - """ - a, b = np.array(a), np.array(b) - - if len(a) != len(b): - warnings.warn( - "Lengths of a and b are different. " - "The differential tail test could lose power.", - UserWarning, - ) - - both = np.concatenate([a, b]) - thresh = np.percentile(both, percentile) - za, zb = a * (a > thresh), b * (b > thresh) - - # If za and zb contain identical values, e.g., both za and zb are all zeros. - if len(set(np.concatenate((za, zb)))) == 1: - if alternative == "two-sided": - return np.nan, 1.0 - - else: - return np.nan, 0.5 - - statistic, pvalue = mannwhitneyu(za, zb, alternative=alternative) - - return statistic, pvalue From 826aebeb2ce049337918925f33cc2d354af2ee7a Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Thu, 12 Oct 2023 17:47:37 -0400 Subject: [PATCH 05/13] rename plotting decorator --- .../nmf_framework/multimodal_corrnmf.py | 12 +++++----- src/salamander/nmf_framework/signature_nmf.py | 6 ++--- src/salamander/plot.py | 24 +++++++++---------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index deb3114..af74549 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -20,8 +20,8 @@ from ..plot import ( corr_plot, - paper_style, pca_2d, + salamander_style, scatter_1d, scatter_2d, signatures_plot, @@ -431,7 +431,7 @@ def fit( return self - @paper_style + @salamander_style def plot_signatures( self, colors=None, @@ -468,7 +468,7 @@ def plot_signatures( return axes - @paper_style + @salamander_style def 
plot_exposures( self, reorder_signatures=True, @@ -538,7 +538,7 @@ def corr_signatures(self) -> pd.DataFrame: def corr_samples(self) -> pd.DataFrame: return self.models[0].corr_samples - @paper_style + @salamander_style def plot_correlation(self, data="signatures", annot=False, outfile=None, **kwargs): """ Plot the correlation matrix of the signatures or samples. @@ -579,7 +579,7 @@ def _get_embedding_annotations(self, annotate_signatures, annotate_samples): return annotations - @paper_style + @salamander_style def plot_embeddings( self, method="umap", @@ -679,7 +679,7 @@ def feature_change(self, in_modality=None, out_modalities="all", normalize=True) return results - @paper_style + @salamander_style def plot_feature_change( self, in_modality=None, diff --git a/src/salamander/nmf_framework/signature_nmf.py b/src/salamander/nmf_framework/signature_nmf.py index 6c54594..c0d536c 100644 --- a/src/salamander/nmf_framework/signature_nmf.py +++ b/src/salamander/nmf_framework/signature_nmf.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from ..plot import corr_plot, exposures_plot, paper_style, signatures_plot +from ..plot import corr_plot, exposures_plot, salamander_style, signatures_plot from ..utils import type_checker, value_checker @@ -303,7 +303,7 @@ def fit(self, data: pd.DataFrame, given_signatures=None): """ pass - @paper_style + @salamander_style def plot_signatures( self, catalog=None, @@ -335,7 +335,7 @@ def plot_signatures( return axes - @paper_style + @salamander_style def plot_exposures( self, reorder_signatures=True, diff --git a/src/salamander/plot.py b/src/salamander/plot.py index 1221713..26dec0a 100644 --- a/src/salamander/plot.py +++ b/src/salamander/plot.py @@ -17,7 +17,7 @@ from .utils import match_to_catalog -def paper_style(func): +def salamander_style(func): @wraps(func) def rc_wrapper(*args, **kwargs): sns.set_context("notebook") @@ -60,7 +60,7 @@ def _annotate_plot( ) -@paper_style +@salamander_style def scatter_1d( data: np.ndarray, 
annotations=None, annotation_kwargs=None, ax=None, **kwargs ): @@ -86,7 +86,7 @@ def scatter_1d( return ax -@paper_style +@salamander_style def scatter_2d(data, annotations=None, annotation_kwargs=None, ax=None, **kwargs): """ The rows (!) of 'data' are assumed to be the data points. @@ -108,7 +108,7 @@ def scatter_2d(data, annotations=None, annotation_kwargs=None, ax=None, **kwargs return ax -@paper_style +@salamander_style def pca_2d(data, annotations=None, annotation_kwargs=None, ax=None, **kwargs): """ The rows (!) of 'data' are assumed to be the data points. @@ -128,7 +128,7 @@ def pca_2d(data, annotations=None, annotation_kwargs=None, ax=None, **kwargs): return ax -@paper_style +@salamander_style def tsne_2d( data, perplexity=30, annotations=None, annotation_kwargs=None, ax=None, **kwargs ): @@ -153,7 +153,7 @@ def tsne_2d( return ax -@paper_style +@salamander_style def umap_2d( data, n_neighbors=15, @@ -185,7 +185,7 @@ def umap_2d( return ax -@paper_style +@salamander_style def plot_history(function_values, figtitle="", ax=None, **kwargs): if ax is None: ax = plt.gca() @@ -196,7 +196,7 @@ def plot_history(function_values, figtitle="", ax=None, **kwargs): return ax -@paper_style +@salamander_style def corr_plot( corr: pd.DataFrame, figsize=(6, 6), cmap="vlag", annot=True, fmt=".2f", **kwargs ): @@ -258,7 +258,7 @@ def _get_colors_signature_plot(colors, mutation_types): return colors -@paper_style +@salamander_style def _signature_plot( signature, colors=None, annotate_mutation_types=False, ax=None, **kwargs ): @@ -313,7 +313,7 @@ def _signature_plot( return ax -@paper_style +@salamander_style def signature_plot( signature, catalog=None, @@ -376,7 +376,7 @@ def signature_plot( return axes -@paper_style +@salamander_style def signatures_plot( signatures, catalog=None, @@ -459,7 +459,7 @@ def _reorder_exposures(exposures: pd.DataFrame, reorder_signatures=True): return exposures_reordered -@paper_style +@salamander_style def exposures_plot( exposures: 
pd.DataFrame, reorder_signatures=True, From e7f9dde8d044c72ee1f557a9856b0bdfa2738010 Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Fri, 13 Oct 2023 16:13:38 -0400 Subject: [PATCH 06/13] refactor embedding plot --- poetry.lock | 6 +- src/salamander/nmf_framework/corrnmf.py | 123 ++---------------- .../nmf_framework/multimodal_corrnmf.py | 117 +++++++---------- src/salamander/nmf_framework/nmf.py | 106 ++------------- src/salamander/nmf_framework/signature_nmf.py | 96 +++++++++++++- src/salamander/plot.py | 83 +++++++++++- 6 files changed, 248 insertions(+), 283 deletions(-) diff --git a/poetry.lock b/poetry.lock index 434f571..b621e1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -792,13 +792,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "pre-commit" -version = "3.4.0" +version = "3.5.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." optional = false python-versions = ">=3.8" files = [ - {file = "pre_commit-3.4.0-py2.py3-none-any.whl", hash = "sha256:96d529a951f8b677f730a7212442027e8ba53f9b04d217c4c67dc56c393ad945"}, - {file = "pre_commit-3.4.0.tar.gz", hash = "sha256:6bbd5129a64cad4c0dfaeeb12cd8f7ea7e15b77028d985341478c8af3c759522"}, + {file = "pre_commit-3.5.0-py2.py3-none-any.whl", hash = "sha256:841dc9aef25daba9a0238cd27984041fa0467b4199fc4852e27950664919f660"}, + {file = "pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32"}, ] [package.dependencies] diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index ae3b2d4..40ebc1d 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -1,13 +1,10 @@ -import warnings from abc import abstractmethod -import matplotlib.pyplot as plt import numpy as np import pandas as pd from scipy.spatial.distance import squareform from scipy.special import gammaln -from ..plot import pca_2d, scatter_1d, scatter_2d, tsne_2d, umap_2d 
from ..utils import ( kl_divergence, match_signatures_pair, @@ -15,7 +12,6 @@ samplewise_kl_divergence, shape_checker, type_checker, - value_checker, ) from .initialization import ( init_custom, @@ -483,114 +479,19 @@ def reorder(self, other_signatures, metric="cosine", keep_names=False): return reordered_indices - def _get_embedding_annotations(self, annotate_signatures, annotate_samples): - # Only annotate with the first 20 characters of names - annotations = np.empty(self.n_signatures + self.n_samples, dtype="U20") - - if annotate_signatures: - annotations[: self.n_signatures] = self.signature_names - - if annotate_samples: - annotations[-self.n_samples :] = self.sample_names - - return annotations - - def plot_embeddings( - self, - method="umap", - annotate_signatures=True, - annotate_samples=False, - annotation_kwargs=None, - normalize=False, - ax=None, - outfile=None, - **kwargs, - ): + def _get_embedding_data(self) -> np.ndarray: """ - Plot the signature and sample embeddings. If the embedding dimension is two, - the embeddings will be plotted directly, ignoring the chosen method. - See plot.py for the implementation of scatter_2d, tsne_2d, pca_2d, umap_2d. - - Input: - ------ - methdod: str - Either 'tsne', 'pca' or 'umap'. The respective dimensionality reduction - will be applied to plot the signature and sample embeddings in 2D. - - annotate_signatures: bool - - annotate_samples: bool - - normalize: bool - Normalize the embeddings before applying the dimensionality reduction. - - *args, **kwargs: - arguments to be passed to scatter_2d, tsne_2d, pca_2d or umap_2d + In CorrNMF, the data for the embedding plot are the (transposed) signature and + sample embeddings. 
""" - value_checker("method", method, ["pca", "tsne", "umap"]) - annotations = self._get_embedding_annotations( - annotate_signatures, annotate_samples - ) - - data = np.concatenate([self.L, self.U], axis=1).T - - if normalize: - data /= np.sum(data, axis=0) + return np.concatenate([self.L, self.U], axis=1).T.copy() - if self.dim_embeddings in [1, 2]: - warnings.warn( - f"The embedding dimension is {self.dim_embeddings}. " - f"The method argument '{method}' will be ignored " - "and the embeddings are plotted directly.", - UserWarning, - ) - - if self.dim_embeddings == 1: - ax = scatter_1d( - data[:, 0], - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - elif self.dim_embeddings == 2: - ax = scatter_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - elif method == "tsne": - ax = tsne_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - elif method == "pca": - ax = pca_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - else: - ax = umap_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - if outfile is not None: - plt.savefig(outfile, bbox_inches="tight") + def _get_default_embedding_annotations(self) -> np.ndarray: + """ + The embedding plot defaults to annotating the signature embeddings. + """ + # Only annotate with the first 20 characters of names + annotations = np.empty(self.n_signatures + self.n_samples, dtype="U20") + annotations[: self.n_signatures] = self.signature_names - return ax + return annotations diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index af74549..749c9dd 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -10,24 +10,13 @@ # are accessed. 
# pylint: disable=protected-access -import warnings - import matplotlib.pyplot as plt import numpy as np import pandas as pd from scipy import optimize from scipy.spatial.distance import squareform -from ..plot import ( - corr_plot, - pca_2d, - salamander_style, - scatter_1d, - scatter_2d, - signatures_plot, - tsne_2d, - umap_2d, -) +from ..plot import corr_plot, embeddings_plot, salamander_style, signatures_plot from ..utils import type_checker, value_checker from . import _utils_corrnmf from .corrnmf_det import CorrNMFDet @@ -186,8 +175,8 @@ def _update_alphas(self): model._update_alpha() def _update_sigma_sq(self): - embeddings = np.concatenate([model.L for model in self.models], axis=1) - embeddings = np.concatenate([embeddings, self.models[0].U], axis=1) + Ls = np.concatenate([model.L for model in self.models], axis=1) + embeddings = np.concatenate([Ls, self.models[0].U], axis=1) sigma_sq = np.mean(embeddings**2) sigma_sq = np.clip(sigma_sq, EPSILON, None) @@ -564,28 +553,22 @@ def plot_correlation(self, data="signatures", annot=False, outfile=None, **kwarg return clustergrid - def _get_embedding_annotations(self, annotate_signatures, annotate_samples): + def _get_default_embedding_annotations(self): # Only annotate with the first 20 characters of names annotations = np.empty(np.sum(self.ns_signatures) + self.n_samples, dtype="U20") - - if annotate_signatures: - signature_names = np.concatenate( - [model.signature_names for model in self.models] - ) - annotations[: len(signature_names)] = signature_names - - if annotate_samples: - annotations[-self.n_samples :] = self.models[0].sample_names + signature_names = np.concatenate( + [model.signature_names for model in self.models] + ) + annotations[: len(signature_names)] = signature_names return annotations - @salamander_style def plot_embeddings( self, method="umap", - annotate_signatures=True, - annotate_samples=False, normalize=False, + annotations=None, + annotation_kwargs=None, ax=None, outfile=None, 
**kwargs, @@ -595,55 +578,55 @@ def plot_embeddings( is two, the embeddings will be plotted directly, ignoring the chosen method. See plot.py for the implementation of scatter_2d, tsne_2d, pca_2d, umap_2d. - Input: - ------ - methdod: str + Parameters + ---------- + method : str, default='umap' Either 'tsne', 'pca' or 'umap'. The respective dimensionality reduction will be applied to plot the signature and sample embeddings in 2D space. - annotate_signatures: bool + normalize : bool, default=False + If True, normalize the embeddings before applying the dimensionality + reduction. - annotate_samples: bool + annotations : list[str], default=None + Annotations per data point, e.g. the sample names. If None, + all signatures are annotated. + Note that there are sum('ns_signatures') + 'n_samples' data points, + i.e. the first sum('ns_signatures') elements in 'annotations' + are the signature annotations, not any sample annotations. - normalize: bool - Normalize the embeddings before applying the dimensionality reduction. + annotation_kwargs : dict, default=None + keyword arguments to pass to matplotlib's plt.text() - *args, **kwargs: - arguments to be passed to scatter_2d, tsne_2d, pca_2d or umap_2d - """ - value_checker("method", method, ["pca", "tsne", "umap"]) - annotations = self._get_embedding_annotations( - annotate_signatures, annotate_samples - ) - - Ls = np.concatenate([model.L for model in self.models], axis=1) - data = np.concatenate([Ls, self.models[0].U], axis=1).T - - if normalize: - data /= np.sum(data, axis=0) - - if self.dim_embeddings in [1, 2]: - warnings.warn( - f"The embedding dimension is {self.dim_embeddings}. " - f"The method argument '{method}' will be ignored " - "and the embeddings are plotted directly.", - UserWarning, - ) + ax : matplotlib.axes.Axes, default=None + Pre-existing axes for the plot. Otherwise, an axes is created. 
- if self.dim_embeddings == 1: - ax = scatter_1d(data[:, 0], annotations=annotations, ax=ax, **kwargs) + outfile : str, default=None + If not None, the figure will be saved in the specified file path. - elif self.dim_embeddings == 2: - ax = scatter_2d(data, annotations=annotations, ax=ax, **kwargs) + **kwargs : + keyword arguments to pass to seaborn's scatterplot - elif method == "tsne": - ax = tsne_2d(data, annotations=annotations, ax=ax, **kwargs) - - elif method == "pca": - ax = pca_2d(data, annotations=annotations, ax=ax, **kwargs) - - else: - ax = umap_2d(data, annotations=annotations, ax=ax, **kwargs) + Returns + ------- + ax : matplotlib.axes.Axes + The matplotlib axes containing the plot. + """ + Ls = np.concatenate([model.L for model in self.models], axis=1) + embedding_data = np.concatenate([Ls, self.models[0].U], axis=1).T.copy() + + if annotations is None: + annotations = self._get_default_embedding_annotations() + + ax = embeddings_plot( + embedding_data, + method, + normalize, + annotations, + annotation_kwargs, + ax, + **kwargs, + ) if outfile is not None: plt.savefig(outfile, bbox_inches="tight") diff --git a/src/salamander/nmf_framework/nmf.py b/src/salamander/nmf_framework/nmf.py index 4f916e9..e884e22 100644 --- a/src/salamander/nmf_framework/nmf.py +++ b/src/salamander/nmf_framework/nmf.py @@ -1,12 +1,9 @@ -import warnings from abc import abstractmethod -import matplotlib.pyplot as plt import numpy as np import pandas as pd -from ..plot import pca_2d, scatter_1d, scatter_2d, tsne_2d, umap_2d -from ..utils import match_signatures_pair, normalize_WH, value_checker +from ..utils import match_signatures_pair, normalize_WH from .initialization import ( init_custom, init_flat, @@ -249,98 +246,15 @@ def reorder(self, other_signatures, metric="cosine", keep_names=False): return reordered_indices - def _get_embedding_annotations(self, annotate_samples) -> np.ndarray: - # Only annotate with the first 20 characters of names - annotations = 
np.empty(self.n_samples, dtype="U20") - - if annotate_samples: - annotations[:] = self.sample_names - - return annotations - - def plot_embeddings( - self, - method="umap", - annotate_samples=False, - annotation_kwargs=None, - ax=None, - outfile=None, - **kwargs, - ): + def _get_embedding_data(self): """ - Plot the sample embeddings using the exposure matrix H. - If the embedding dimension is set to two, the embeddings will - be plotted directly, ignoring the method chosen. - See plot.py for the implementation of scatter_2d, tsne_2d, pca_2d, umap_2d. - - Input: - ------ - methdod: str - Either 'tsne', 'pca' or 'umap'. The respective dimensionality reduction - will be applied to plot signature and sample embeddings in 2D. - - **kwargs: - Arguments to be passed to scatter_2d, tsne_2d, pca_2d or umap_2d + In most NMF models like KL-NMF or mvNMF, the data for the embedding plot + are just the (transposed) exposures. """ - value_checker("method", method, ["pca", "tsne", "umap"]) - - data = self.H.T - annotations = self._get_embedding_annotations(annotate_samples) - - if self.n_signatures in [1, 2]: - warnings.warn( - f"The number of signatures is {self.n_signatures}. 
" - f"The method argument '{method}' will be ignored " - "and the embeddings are plotted directly.", - UserWarning, - ) - - if self.n_signatures == 1: - ax = scatter_1d( - data[:, 0], - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - elif self.n_signatures == 2: - ax = scatter_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - elif method == "tsne": - ax = tsne_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - elif method == "pca": - ax = pca_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) + return self.H.T.copy() - else: - ax = umap_2d( - data, - annotations=annotations, - annotation_kwargs=annotation_kwargs, - ax=ax, - **kwargs, - ) - - if outfile is not None: - plt.savefig(outfile, bbox_inches="tight") - - return ax + def _get_default_embedding_annotations(self) -> np.ndarray: + """ + The embedding plot defaults to no annotations. + """ + return np.empty(self.n_samples, dtype=str) diff --git a/src/salamander/nmf_framework/signature_nmf.py b/src/salamander/nmf_framework/signature_nmf.py index c0d536c..0b2c6e7 100644 --- a/src/salamander/nmf_framework/signature_nmf.py +++ b/src/salamander/nmf_framework/signature_nmf.py @@ -4,7 +4,13 @@ import numpy as np import pandas as pd -from ..plot import corr_plot, exposures_plot, salamander_style, signatures_plot +from ..plot import ( + corr_plot, + embeddings_plot, + exposures_plot, + salamander_style, + signatures_plot, +) from ..utils import type_checker, value_checker @@ -414,8 +420,90 @@ def plot_correlation(self, data="signatures", annot=False, outfile=None, **kwarg return clustergrid @abstractmethod - def plot_embeddings(self): + def _get_embedding_data(self) -> np.ndarray: """ - Plot the sample (and potentially the signature) embeddings in 2D. 
+ Get the data points for the dimensionality reduction / embedding plot. + One data point corresponds to a row of the embedding data. + Usually, these are the transposed exposures. """ - pass + + @abstractmethod + def _get_default_embedding_annotations(self) -> np.ndarray: + """ + Get the annotations of the data points in the embedding plot. + """ + + def plot_embeddings( + self, + method="umap", + normalize=False, + annotations=None, + annotation_kwargs=None, + ax=None, + outfile=None, + **kwargs, + ): + """ + Plot a dimensionality reduction of the exposure representation. + In most NMF algorithms, this is just the exposures of the samples. + In CorrNMF, the exposures matrix is refactored, and there are both + sample and signature exposures in a shared embedding space. + + If the embedding dimension is one or two, the embeddings are plotted + directly, ignoring the chosen method. + See plot.py for the implementation of scatter_2d, tsne_2d, pca_2d, umap_2d. + + Parameters + ---------- + method : str, default='umap' + Either 'tsne', 'pca' or 'umap'. The respective dimensionality reduction + will be applied to plot the data in 2D space. + + normalize : bool, default=False + If True, normalize the data before applying the dimensionality reduction. + + annotations : list[str], default=None + Annotations per data point, e.g. the sample names. If None, + the algorithm-specific default annotations are used. + For example, CorrNMF annotates the signature embeddings by default. + Note that there are 'n_signatures' + 'n_samples' data points in CorrNMF, + i.e. the first 'n_signatures' elements in 'annotations' + are the signature annotations, not any sample annotations. + + annotation_kwargs : dict, default=None + keyword arguments to pass to matplotlib's plt.text() + + ax : matplotlib.axes.Axes, default=None + Pre-existing axes for the plot. Otherwise, an axes is created. + + outfile : str, default=None + If not None, the figure will be saved in the specified file path. 
+ + **kwargs : + keyword arguments to pass to seaborn's scatterplot + + Returns + ------- + ax : matplotlib.axes.Axes + The matplotlib axes containing the plot. + """ + # one data point corresponds to a row of embedding_data + embedding_data = self._get_embedding_data() + + if annotations is None: + annotations = self._get_default_embedding_annotations() + + ax = embeddings_plot( + embedding_data, + method, + normalize, + annotations, + annotation_kwargs, + ax, + **kwargs, + ) + + if outfile is not None: + plt.savefig(outfile, bbox_inches="tight") + + return ax diff --git a/src/salamander/plot.py b/src/salamander/plot.py index 26dec0a..8beea5d 100644 --- a/src/salamander/plot.py +++ b/src/salamander/plot.py @@ -14,7 +14,7 @@ from sklearn.manifold import TSNE from .consts import COLORS_INDEL83, COLORS_SBS96, INDEL_TYPES_83, SBS_TYPES_96 -from .utils import match_to_catalog +from .utils import match_to_catalog, value_checker def salamander_style(func): @@ -37,7 +37,6 @@ def rc_wrapper(*args, **kwargs): "xtick.labelsize": 12, "ytick.labelsize": 12, } - mpl.rcParams.update(params) return func(*args, **kwargs) @@ -185,6 +184,86 @@ def umap_2d( return ax +@salamander_style +def embeddings_plot( + data: np.ndarray, + method="umap", + normalize=False, + annotations=None, + annotation_kwargs=None, + ax=None, + **kwargs, +): + """ + The rows (!) of 'data' are assumed to be the single data points. + """ + value_checker("method", method, ["pca", "tsne", "umap"]) + + if normalize: + data /= data.sum(axis=1)[:, np.newaxis] + + if ax is None: + _, ax = plt.subplots(figsize=(6, 6)) + + annotation_kwargs = {} if annotation_kwargs is None else annotation_kwargs.copy() + n_dimensions = data.shape[0] + + if n_dimensions in [1, 2]: + warnings.warn( + f"The dimension of the data points is {n_dimensions}. 
" + f"The method argument '{method}' will be ignored " + "and the embeddings are plotted directly.", + UserWarning, + ) + + if n_dimensions == 1: + ax = scatter_1d( + data[:, 0], + annotations=annotations, + annotation_kwargs=annotation_kwargs, + ax=ax, + **kwargs, + ) + + elif n_dimensions == 2: + ax = scatter_2d( + data, + annotations=annotations, + annotation_kwargs=annotation_kwargs, + ax=ax, + **kwargs, + ) + + elif method == "tsne": + ax = tsne_2d( + data, + annotations=annotations, + annotation_kwargs=annotation_kwargs, + ax=ax, + **kwargs, + ) + + elif method == "pca": + ax = pca_2d( + data, + annotations=annotations, + annotation_kwargs=annotation_kwargs, + ax=ax, + **kwargs, + ) + + else: + ax = umap_2d( + data, + annotations=annotations, + annotation_kwargs=annotation_kwargs, + ax=ax, + **kwargs, + ) + + return ax + + @salamander_style def plot_history(function_values, figtitle="", ax=None, **kwargs): if ax is None: From 703fa57b08a6be42e11ec1c1a2e3ecdeedd80bac Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Sat, 14 Oct 2023 22:53:12 -0400 Subject: [PATCH 07/13] add jit optimization mvNMF Additional changes: - the mvNMF implementation is now identical to the pseudocode in the original paper - slightly improved the documentation --- src/salamander/nmf_framework/mvnmf.py | 249 ++++++++++++++------------ src/salamander/utils.py | 1 + tests/test_mvnmf.py | 6 +- 3 files changed, 139 insertions(+), 117 deletions(-) diff --git a/src/salamander/nmf_framework/mvnmf.py b/src/salamander/nmf_framework/mvnmf.py index ddc8ecd..edf367e 100644 --- a/src/salamander/nmf_framework/mvnmf.py +++ b/src/salamander/nmf_framework/mvnmf.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +from numba import njit from ..utils import kl_divergence, normalize_WH, poisson_llh, samplewise_kl_divergence from .nmf import NMF @@ -7,70 +8,142 @@ EPSILON = np.finfo(np.float32).eps +@njit +def volume_logdet(W: np.ndarray, delta: float) -> float: + n_signatures = W.shape[1] + diag = 
np.diag(np.full(n_signatures, delta)) + volume = np.log(np.linalg.det(W.T @ W + diag)) + + return volume + + +@njit +def kl_divergence_penalized( + X: np.ndarray, W: np.ndarray, H: np.ndarray, lam: float, delta: float +) -> float: + reconstruction_error = kl_divergence(X, W, H) + volume = volume_logdet(W, delta) + loss = reconstruction_error + lam * volume + + return loss + + +@njit +def update_H(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray: + """ + The multiplicative update rule of the exposure matrix H + derived by Lee and Seung. See Theorem 2 in + "Algorithms for non-negative matrix factorization". + + Clipping the matrix avoids floating point errors. + """ + H *= W.T @ (X / (W @ H)) + H /= np.sum(W, axis=0)[:, np.newaxis] + H = H.clip(EPSILON) + + return H + + +@njit +def update_W_unconstrained( + X: np.ndarray, W: np.ndarray, H: np.ndarray, lam: float, delta: float +) -> np.ndarray: + n_signatures = W.shape[1] + diag = np.diag(np.full(n_signatures, delta)) + Y = np.linalg.inv(W.T @ W + diag) + Y_minus = np.maximum(0, -Y) + Y_abs = np.abs(Y) + WY_minus = W @ Y_minus + WY_abs = W @ Y_abs + + rowsums_H = np.sum(H, axis=1) + discriminant_s1 = (rowsums_H - 4 * lam * WY_minus) ** 2 + discriminant_s2 = 8 * lam * WY_abs * ((X / (W @ H)) @ H.T) + numerator_s1 = np.sqrt(discriminant_s1 + discriminant_s2) + numerator_s2 = -rowsums_H + 4 * lam * WY_minus + numerator = numerator_s1 + numerator_s2 + denominator = 4 * lam * WY_abs + W_unconstrained = W * numerator / denominator + W_unconstrained = W_unconstrained.clip(EPSILON) + + return W_unconstrained + + +@njit +def line_search( + X: np.ndarray, + W: np.ndarray, + H: np.ndarray, + lam: float, + delta: float, + gamma: float, + W_unconstrained: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, float]: + prev_of_value = kl_divergence_penalized(X, W, H, lam, delta) + W_new, H_new = normalize_WH(W_unconstrained, H) + W_new, H_new = W_new.clip(EPSILON), H_new.clip(EPSILON) + of_value = kl_divergence_penalized(X, 
W_new, H_new, lam, delta) + + while of_value > prev_of_value and gamma > 1e-16: + gamma *= 0.8 + W_new = (1 - gamma) * W + gamma * W_unconstrained + W_new, H_new = normalize_WH(W_new, H) + W_new, H_new = W_new.clip(EPSILON), H_new.clip(EPSILON) + of_value = kl_divergence_penalized(X, W_new, H_new, lam, delta) + + gamma = min(1.0, 1.2 * gamma) + + return W_new, H_new, gamma + + class MvNMF(NMF): """ - Min-volume non-negative matrix factorization. Based on Algorithm 1 of + Min-volume non-negative matrix factorization. See Algorithm 1 in Leplat, V., Gillis, N. and Ang, A.M., 2020. Blind audio source separation with minimum-volume beta-divergence NMF. IEEE Transactions on Signal Processing, 68, pp.3400-3410. - Input: - ------ + Parameters + ---------- n_signatures: int Number of signatures to decipher. - init_method: str + init_method : str, default=nndsvd One of "custom", "flat", "hierarchical_cluster", "nndsvd", "nndsvda", "nndsvdar" "random" and "separableNMF". Please see the initialization module for further details on each method. - lambda_tilde : float - Objective function hyperparameter. + lam : float, default=1.0 + Objective function volume penalty weight. - delta : float + delta : float, default=1.0 Objective function hyperparameter. - min_iterations : int, default=200 + min_iterations : int, default=500 Minimum number of iterations. - max_iterations : int, default=400 + max_iterations : int, default=10000 Maximum number of iterations. tol : float, default=1e-7 Tolerance of the stopping condition. - - Note: - ----- - The algorithm should work better when the initial guesses are better. - One reason lies in lambda and lambda_tilde. - Lambda is calculated in a way such that the two terms - in the objective function are comparable. Ideally, lambda should be set to - kl_divergence(X, W_true @ H_true)/abs(volume(W_true)) * lambda_tilde. - In our code, the true W and H are replaced by the initial guesses. 
- So if the initial guesses are good, then indeed the two terms will be - comparable. If the initial guesses are far off, then the kl_divergence - part will be far over-estimated. As a result, the two terms are - not comparable anymore. One potential improvement is to first run - a small number of NMF iterations, and then use the NMF results - as hot starts for the mvNMF algorithm. """ def __init__( self, n_signatures=1, init_method="nndsvd", - lambda_tilde=1e-5, + lam=1.0, delta=1.0, min_iterations=500, max_iterations=10000, tol=1e-7, ): super().__init__(n_signatures, init_method, min_iterations, max_iterations, tol) - self.lambda_tilde = lambda_tilde - self.lam = lambda_tilde + self.lam = lam self.delta = delta - self.gamma = 1.0 + self._gamma = None @property def reconstruction_error(self): @@ -80,26 +153,8 @@ def reconstruction_error(self): def samplewise_reconstruction_error(self): return samplewise_kl_divergence(self.X, self.W, self.H) - @staticmethod - def _volume_logdet(W, delta) -> float: - n_signatures = W.shape[1] - diag = delta * np.identity(n_signatures) - volume = np.log(np.linalg.det(W.T @ W + diag)) - - return volume - - @staticmethod - def _objective_function( - X: np.ndarray, W: np.ndarray, H: np.ndarray, lam: float, delta: float - ) -> float: - reconstruction_error = kl_divergence(X, W, H) - volume = MvNMF._volume_logdet(W, delta) - loss = reconstruction_error + lam * volume - - return loss - def objective_function(self): - return self._objective_function(self.X, self.W, self.H, self.lam, self.delta) + return kl_divergence_penalized(self.X, self.W, self.H, self.lam, self.delta) @property def objective(self) -> str: @@ -109,69 +164,25 @@ def loglikelihood(self) -> float: return poisson_llh(self.X, self.W, self.H) def _update_H(self): - self.H *= self.W.T @ (self.X / (self.W @ self.H)) - self.H /= np.sum(self.W, axis=0)[:, np.newaxis] - self.H = self.H.clip(EPSILON) + self.H = update_H(self.X, self.W, self.H) def _update_W_unconstrained(self): - 
diag = np.diag(np.full(self.n_signatures, self.delta)) - Y = np.linalg.inv(self.W.T @ self.W + diag) - - Y_minus = np.maximum(0, -Y) - Y_abs = np.abs(Y) - - WY_minus = self.W @ Y_minus - WY_abs = self.W @ Y_abs - - rowsums_H = np.sum(self.H, axis=1) - - discriminant_s1 = (rowsums_H - 4 * self.lam * WY_minus) ** 2 - discriminant_s2 = ( - 8 * self.lam * WY_abs * ((self.X / (self.W @ self.H)) @ self.H.T) + return update_W_unconstrained(self.X, self.W, self.H, self.lam, self.delta) + + def _line_search(self, W_unconstrained): + self.W, self.H, self._gamma = line_search( + self.X, + self.W, + self.H, + self.lam, + self.delta, + self._gamma, + W_unconstrained, ) - numerator_s1 = np.sqrt(discriminant_s1 + discriminant_s2) - numerator_s2 = -rowsums_H + 4 * self.lam * WY_minus - numerator = numerator_s1 + numerator_s2 - - denominator = 4 * self.lam * WY_abs - - W_uc = self.W * numerator / denominator - W_uc = W_uc.clip(EPSILON) - - return W_uc - - def _line_search(self, W_uc, loss_prev): - W_new = self.W + self.gamma * (W_uc - self.W) - W_new, H_new = normalize_WH(W_new, self.H) - W_new, H_new = W_new.clip(EPSILON), H_new.clip(EPSILON) - - loss = self._objective_function(self.X, W_new, H_new, self.lam, self.delta) - - while loss > loss_prev and self.gamma > 1e-16: - self.gamma *= 0.8 - - W_new = self.W + self.gamma * (W_uc - self.W) - W_new, H_new = normalize_WH(W_new, self.H) - W_new, H_new = W_new.clip(EPSILON), H_new.clip(EPSILON) - - loss = self._objective_function(self.X, W_new, H_new, self.lam, self.delta) - - self.gamma = min(1.0, 2 * self.gamma) - self.W, self.H = W_new, H_new - - # pylint: disable-next=W0221 - def _update_W(self, loss_prev): - W_uc = self._update_W_unconstrained() - self._line_search(W_uc, loss_prev) - - def _initialize_mvnmf_parameters(self): - # lambda is chosen s.t. 
both loss summands - # approximately contribute equally for lambda_tilde = 1 - init_reconstruction_error = self.reconstruction_error - init_volume = self._volume_logdet(self.W, self.delta) - self.lam = self.lambda_tilde * init_reconstruction_error / abs(init_volume) - self.gamma = 1.0 + def _update_W(self): + W_unconstrained = self._update_W_unconstrained() + self._line_search(W_unconstrained) def fit( self, @@ -182,23 +193,33 @@ def fit( verbose=0, ): """ - Input: - ------ - data : array-like of shape (n_features, n_samples) - The mutation count data. + Parameters + ---------- + data : pd.DataFrame + Input count data. + + given_signatures : pd.DataFrame, default=None + Signatures to fix during the inference. init_kwargs: dict Any further keyword arguments to be passed to the initialization method. This includes, for example, a possible 'seed' keyword argument for all stochastic methods. + history : bool, default=True + If true, the objective function value of each iteration is saved. + verbose : int, default=0 Verbosity level. + + Returns + ------- + self : object + Returns the instance itself. 
""" self._setup_data_parameters(data) self._initialize(given_signatures, init_kwargs) - self._initialize_mvnmf_parameters() - + self._gamma = 1.0 of_values = [self.objective_function()] n_iteration = 0 converged = False @@ -210,11 +231,11 @@ def fit( print(f"iteration {n_iteration}") self._update_H() - prev_of_value = of_values[-1] if given_signatures is None: - self._update_W(prev_of_value) + self._update_W() + prev_of_value = of_values[-1] of_values.append(self.objective_function()) rel_change = (prev_of_value - of_values[-1]) / prev_of_value converged = ( diff --git a/src/salamander/utils.py b/src/salamander/utils.py index 96e9b46..7fd2efa 100644 --- a/src/salamander/utils.py +++ b/src/salamander/utils.py @@ -142,6 +142,7 @@ def poisson_llh(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: return result +@njit def normalize_WH(W, H): normalization_factor = np.sum(W, axis=0) return W / normalization_factor, H * normalization_factor[:, None] diff --git a/tests/test_mvnmf.py b/tests/test_mvnmf.py index 1787b0f..412957f 100644 --- a/tests/test_mvnmf.py +++ b/tests/test_mvnmf.py @@ -40,7 +40,7 @@ def model_init(model, counts, W_init, H_init): model.H = H_init model.lam = 1.0 model.delta = 1.0 - model.gamma = 1.0 + model._gamma = 1.0 return model @@ -63,8 +63,8 @@ class TestMVNMF: def test_objective_function(self, model_init, objective_init): assert np.allclose(model_init.objective_function(), objective_init) - def test_update_W(self, model_init, objective_init, W_updated): - model_init._update_W(objective_init) + def test_update_W(self, model_init, W_updated): + model_init._update_W() assert np.allclose(model_init.W, W_updated) def test_update_H(self, model_init, H_updated): From a088aa841b29fcbe6b1042c8a208d74902beff8a Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Fri, 20 Oct 2023 14:19:55 -0400 Subject: [PATCH 08/13] add variable given signatures Every algorithm now supports a variable amount of given signatures less or equal to the number of signatures 
specified in the algorithm instance. I also restructured the testing code and renamed the testing data files. --- poetry.lock | 110 ++++---- .../nmf_framework/_utils_corrnmf.py | 36 --- src/salamander/nmf_framework/_utils_klnmf.py | 264 ++++++++++++++++++ src/salamander/nmf_framework/corrnmf.py | 31 +- src/salamander/nmf_framework/corrnmf_det.py | 14 +- src/salamander/nmf_framework/klnmf.py | 151 +++++----- .../nmf_framework/multimodal_corrnmf.py | 16 +- src/salamander/nmf_framework/mvnmf.py | 51 ++-- src/salamander/nmf_framework/nmf.py | 19 +- src/salamander/nmf_framework/signature_nmf.py | 34 +-- src/salamander/utils.py | 74 ----- tests/test_corrnmf.py | 191 +++++++------ ...dim1_L_init.npy => L_init_nsigs1_dim1.npy} | Bin ...dim2_L_init.npy => L_init_nsigs2_dim2.npy} | Bin ..._updated.npy => L_updated_nsigs1_dim1.npy} | Bin ..._updated.npy => L_updated_nsigs2_dim2.npy} | Bin ...dim1_U_init.npy => U_init_nsigs1_dim1.npy} | Bin ...dim2_U_init.npy => U_init_nsigs2_dim2.npy} | Bin ..._updated.npy => U_updated_nsigs1_dim1.npy} | Bin ..._updated.npy => U_updated_nsigs2_dim2.npy} | Bin ...dim1_W_init.npy => W_init_nsigs1_dim1.npy} | Bin ...dim2_W_init.npy => W_init_nsigs2_dim2.npy} | Bin ..._updated.npy => W_updated_nsigs1_dim1.npy} | Bin ..._updated.npy => W_updated_nsigs2_dim2.npy} | Bin ...ha_init.npy => alpha_init_nsigs1_dim1.npy} | Bin ...ha_init.npy => alpha_init_nsigs2_dim2.npy} | Bin ...ated.npy => alpha_updated_nsigs1_dim1.npy} | Bin ...ated.npy => alpha_updated_nsigs2_dim2.npy} | Bin .../nmf_framework/corrnmf/counts.csv | 97 +++++++ ...nit.npy => objective_init_nsigs1_dim1.npy} | Bin ...nit.npy => objective_init_nsigs2_dim2.npy} | Bin ...mf_nsigs1_dim1_p.npy => p_nsigs1_dim1.npy} | Bin ...mf_nsigs2_dim2_p.npy => p_nsigs2_dim2.npy} | Bin ...init.npy => sigma_sq_init_nsigs1_dim1.npy} | Bin ...init.npy => sigma_sq_init_nsigs2_dim2.npy} | Bin ...d.npy => sigma_sq_updated_nsigs1_dim1.npy} | Bin ...d.npy => sigma_sq_updated_nsigs2_dim2.npy} | Bin ... 
surrogate_objective_init_nsigs1_dim1.npy} | Bin ... surrogate_objective_init_nsigs2_dim2.npy} | Bin ...mf_nsigs1_H_init.npy => H_init_nsigs1.npy} | Bin ...mf_nsigs2_H_init.npy => H_init_nsigs2.npy} | Bin .../klnmf/WH_updated_mu-joint_nsigs1.pkl | Bin 0 -> 1007 bytes .../klnmf/WH_updated_mu-joint_nsigs2.pkl | Bin 0 -> 1857 bytes .../klnmf/WH_updated_mu-standard_nsigs1.pkl | Bin 0 -> 1007 bytes .../klnmf/WH_updated_mu-standard_nsigs2.pkl | Bin 0 -> 1857 bytes ...mf_nsigs1_W_init.npy => W_init_nsigs1.npy} | Bin ...mf_nsigs2_W_init.npy => W_init_nsigs2.npy} | Bin .../test_data/nmf_framework/klnmf/counts.csv | 97 +++++++ .../klnmf/klnmf_nsigs1_W_updated.npy | Bin 896 -> 0 bytes .../klnmf/klnmf_nsigs2_H_updated.npy | Bin 288 -> 0 bytes .../klnmf/klnmf_nsigs2_W_updated.npy | Bin 1664 -> 0 bytes ...ive_init.npy => objective_init_nsigs1.npy} | Bin ...ive_init.npy => objective_init_nsigs2.npy} | Bin ...mf_nsigs1_H_init.npy => H_init_nsigs1.npy} | Bin ...mf_nsigs2_H_init.npy => H_init_nsigs2.npy} | Bin ...gs1_H_updated.npy => H_updated_nsigs1.npy} | Bin ...gs2_H_updated.npy => H_updated_nsigs2.npy} | Bin ...mf_nsigs1_W_init.npy => W_init_nsigs1.npy} | Bin ...mf_nsigs2_W_init.npy => W_init_nsigs2.npy} | Bin ...gs1_W_updated.npy => W_updated_nsigs1.npy} | Bin ...gs2_W_updated.npy => W_updated_nsigs2.npy} | Bin .../test_data/nmf_framework/mvnmf/counts.csv | 97 +++++++ ...ive_init.npy => objective_init_nsigs1.npy} | Bin ...ive_init.npy => objective_init_nsigs2.npy} | Bin .../utils_klnmf/H_nsigs1.npy} | Bin .../utils_klnmf/H_nsigs2.npy} | Bin .../H_updated_mu-joint_nsigs1.npy} | Bin 208 -> 208 bytes .../utils_klnmf/H_updated_mu-joint_nsigs2.npy | Bin 0 -> 288 bytes .../H_updated_mu-standard_nsigs1.npy | Bin 0 -> 208 bytes .../H_updated_mu-standard_nsigs2.npy | Bin 0 -> 288 bytes .../utils_klnmf/W_nsigs1.npy} | Bin .../utils_klnmf/W_nsigs2.npy} | Bin .../utils_klnmf/W_updated_mu-joint_nsigs1.npy | Bin 0 -> 896 bytes .../utils_klnmf/W_updated_mu-joint_nsigs2.npy | Bin 0 -> 1664 
bytes .../W_updated_mu-standard_nsigs1.npy | Bin 0 -> 896 bytes .../W_updated_mu-standard_nsigs2.npy | Bin 0 -> 1664 bytes .../utils_klnmf}/counts.csv | 0 .../utils_klnmf/kl_divergence_nsigs1.npy} | Bin .../utils_klnmf/kl_divergence_nsigs2.npy} | Bin .../utils_klnmf/poisson_llh_nsigs1.npy} | Bin .../utils_klnmf/poisson_llh_nsigs2.npy} | Bin .../samplewise_kl_divergence_nsigs1.npy} | Bin .../samplewise_kl_divergence_nsigs2.npy} | Bin tests/test_klnmf.py | 82 +++--- tests/test_multimodal_corrnmf.py | 187 +++++++------ tests/test_mvnmf.py | 67 +++-- tests/test_utils.py | 62 ---- tests/test_utils_klnmf.py | 127 +++++++++ 88 files changed, 1180 insertions(+), 627 deletions(-) create mode 100644 src/salamander/nmf_framework/_utils_klnmf.py rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_L_init.npy => L_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_L_init.npy => L_init_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_L_updated.npy => L_updated_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_L_updated.npy => L_updated_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_U_init.npy => U_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_U_init.npy => U_init_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_U_updated.npy => U_updated_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_U_updated.npy => U_updated_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_W_init.npy => W_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_W_init.npy => W_init_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_W_updated.npy => W_updated_nsigs1_dim1.npy} (100%) 
rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_W_updated.npy => W_updated_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_alpha_init.npy => alpha_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_alpha_init.npy => alpha_init_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_alpha_updated.npy => alpha_updated_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_alpha_updated.npy => alpha_updated_nsigs2_dim2.npy} (100%) create mode 100644 tests/test_data/nmf_framework/corrnmf/counts.csv rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_objective_init.npy => objective_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_objective_init.npy => objective_init_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_p.npy => p_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_p.npy => p_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_sigma_sq_init.npy => sigma_sq_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_sigma_sq_init.npy => sigma_sq_init_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_sigma_sq_updated.npy => sigma_sq_updated_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_sigma_sq_updated.npy => sigma_sq_updated_nsigs2_dim2.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs1_dim1_surrogate_objective_init.npy => surrogate_objective_init_nsigs1_dim1.npy} (100%) rename tests/test_data/nmf_framework/corrnmf/{corrnmf_nsigs2_dim2_surrogate_objective_init.npy => surrogate_objective_init_nsigs2_dim2.npy} (100%) rename 
tests/test_data/nmf_framework/klnmf/{klnmf_nsigs1_H_init.npy => H_init_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/klnmf/{klnmf_nsigs2_H_init.npy => H_init_nsigs2.npy} (100%) create mode 100644 tests/test_data/nmf_framework/klnmf/WH_updated_mu-joint_nsigs1.pkl create mode 100644 tests/test_data/nmf_framework/klnmf/WH_updated_mu-joint_nsigs2.pkl create mode 100644 tests/test_data/nmf_framework/klnmf/WH_updated_mu-standard_nsigs1.pkl create mode 100644 tests/test_data/nmf_framework/klnmf/WH_updated_mu-standard_nsigs2.pkl rename tests/test_data/nmf_framework/klnmf/{klnmf_nsigs1_W_init.npy => W_init_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/klnmf/{klnmf_nsigs2_W_init.npy => W_init_nsigs2.npy} (100%) create mode 100644 tests/test_data/nmf_framework/klnmf/counts.csv delete mode 100644 tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_W_updated.npy delete mode 100644 tests/test_data/nmf_framework/klnmf/klnmf_nsigs2_H_updated.npy delete mode 100644 tests/test_data/nmf_framework/klnmf/klnmf_nsigs2_W_updated.npy rename tests/test_data/nmf_framework/klnmf/{klnmf_nsigs1_objective_init.npy => objective_init_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/klnmf/{klnmf_nsigs2_objective_init.npy => objective_init_nsigs2.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs1_H_init.npy => H_init_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs2_H_init.npy => H_init_nsigs2.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs1_H_updated.npy => H_updated_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs2_H_updated.npy => H_updated_nsigs2.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs1_W_init.npy => W_init_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs2_W_init.npy => W_init_nsigs2.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs1_W_updated.npy => W_updated_nsigs1.npy} (100%) rename 
tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs2_W_updated.npy => W_updated_nsigs2.npy} (100%) create mode 100644 tests/test_data/nmf_framework/mvnmf/counts.csv rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs1_objective_init.npy => objective_init_nsigs1.npy} (100%) rename tests/test_data/nmf_framework/mvnmf/{mvnmf_nsigs2_objective_init.npy => objective_init_nsigs2.npy} (100%) rename tests/test_data/{utils/objective_input_nsigs1_H.npy => nmf_framework/utils_klnmf/H_nsigs1.npy} (100%) rename tests/test_data/{utils/objective_input_nsigs2_H.npy => nmf_framework/utils_klnmf/H_nsigs2.npy} (100%) rename tests/test_data/nmf_framework/{klnmf/klnmf_nsigs1_H_updated.npy => utils_klnmf/H_updated_mu-joint_nsigs1.npy} (61%) create mode 100644 tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-joint_nsigs2.npy create mode 100644 tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-standard_nsigs1.npy create mode 100644 tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-standard_nsigs2.npy rename tests/test_data/{utils/objective_input_nsigs1_W.npy => nmf_framework/utils_klnmf/W_nsigs1.npy} (100%) rename tests/test_data/{utils/objective_input_nsigs2_W.npy => nmf_framework/utils_klnmf/W_nsigs2.npy} (100%) create mode 100644 tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-joint_nsigs1.npy create mode 100644 tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-joint_nsigs2.npy create mode 100644 tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-standard_nsigs1.npy create mode 100644 tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-standard_nsigs2.npy rename tests/test_data/{utils => nmf_framework/utils_klnmf}/counts.csv (100%) rename tests/test_data/{utils/kl_divergence_nsigs1_result.npy => nmf_framework/utils_klnmf/kl_divergence_nsigs1.npy} (100%) rename tests/test_data/{utils/kl_divergence_nsigs2_result.npy => nmf_framework/utils_klnmf/kl_divergence_nsigs2.npy} (100%) rename tests/test_data/{utils/poisson_llh_nsigs1_result.npy => 
nmf_framework/utils_klnmf/poisson_llh_nsigs1.npy} (100%) rename tests/test_data/{utils/poisson_llh_nsigs2_result.npy => nmf_framework/utils_klnmf/poisson_llh_nsigs2.npy} (100%) rename tests/test_data/{utils/samplewise_kl_divergence_nsigs1_result.npy => nmf_framework/utils_klnmf/samplewise_kl_divergence_nsigs1.npy} (100%) rename tests/test_data/{utils/samplewise_kl_divergence_nsigs2_result.npy => nmf_framework/utils_klnmf/samplewise_kl_divergence_nsigs2.npy} (100%) delete mode 100644 tests/test_utils.py create mode 100644 tests/test_utils_klnmf.py diff --git a/poetry.lock b/poetry.lock index b621e1c..5fbcb65 100644 --- a/poetry.lock +++ b/poetry.lock @@ -695,65 +695,65 @@ test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] [[package]] name = "pillow" -version = "10.0.1" +version = "10.1.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.8" files = [ - {file = "Pillow-10.0.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:8f06be50669087250f319b706decf69ca71fdecd829091a37cc89398ca4dc17a"}, - {file = "Pillow-10.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50bd5f1ebafe9362ad622072a1d2f5850ecfa44303531ff14353a4059113b12d"}, - {file = "Pillow-10.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6a90167bcca1216606223a05e2cf991bb25b14695c518bc65639463d7db722d"}, - {file = "Pillow-10.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f11c9102c56ffb9ca87134bd025a43d2aba3f1155f508eff88f694b33a9c6d19"}, - {file = "Pillow-10.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:186f7e04248103482ea6354af6d5bcedb62941ee08f7f788a1c7707bc720c66f"}, - {file = "Pillow-10.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0462b1496505a3462d0f35dc1c4d7b54069747d65d00ef48e736acda2c8cbdff"}, - {file = "Pillow-10.0.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d889b53ae2f030f756e61a7bff13684dcd77e9af8b10c6048fb2c559d6ed6eaf"}, - 
{file = "Pillow-10.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:552912dbca585b74d75279a7570dd29fa43b6d93594abb494ebb31ac19ace6bd"}, - {file = "Pillow-10.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:787bb0169d2385a798888e1122c980c6eff26bf941a8ea79747d35d8f9210ca0"}, - {file = "Pillow-10.0.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:fd2a5403a75b54661182b75ec6132437a181209b901446ee5724b589af8edef1"}, - {file = "Pillow-10.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2d7e91b4379f7a76b31c2dda84ab9e20c6220488e50f7822e59dac36b0cd92b1"}, - {file = "Pillow-10.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19e9adb3f22d4c416e7cd79b01375b17159d6990003633ff1d8377e21b7f1b21"}, - {file = "Pillow-10.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93139acd8109edcdeffd85e3af8ae7d88b258b3a1e13a038f542b79b6d255c54"}, - {file = "Pillow-10.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:92a23b0431941a33242b1f0ce6c88a952e09feeea9af4e8be48236a68ffe2205"}, - {file = "Pillow-10.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:cbe68deb8580462ca0d9eb56a81912f59eb4542e1ef8f987405e35a0179f4ea2"}, - {file = "Pillow-10.0.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:522ff4ac3aaf839242c6f4e5b406634bfea002469656ae8358644fc6c4856a3b"}, - {file = "Pillow-10.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:84efb46e8d881bb06b35d1d541aa87f574b58e87f781cbba8d200daa835b42e1"}, - {file = "Pillow-10.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:898f1d306298ff40dc1b9ca24824f0488f6f039bc0e25cfb549d3195ffa17088"}, - {file = "Pillow-10.0.1-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:bcf1207e2f2385a576832af02702de104be71301c2696d0012b1b93fe34aaa5b"}, - {file = "Pillow-10.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5d6c9049c6274c1bb565021367431ad04481ebb54872edecfcd6088d27edd6ed"}, - {file = 
"Pillow-10.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28444cb6ad49726127d6b340217f0627abc8732f1194fd5352dec5e6a0105635"}, - {file = "Pillow-10.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de596695a75496deb3b499c8c4f8e60376e0516e1a774e7bc046f0f48cd620ad"}, - {file = "Pillow-10.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:2872f2d7846cf39b3dbff64bc1104cc48c76145854256451d33c5faa55c04d1a"}, - {file = "Pillow-10.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4ce90f8a24e1c15465048959f1e94309dfef93af272633e8f37361b824532e91"}, - {file = "Pillow-10.0.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ee7810cf7c83fa227ba9125de6084e5e8b08c59038a7b2c9045ef4dde61663b4"}, - {file = "Pillow-10.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1be1c872b9b5fcc229adeadbeb51422a9633abd847c0ff87dc4ef9bb184ae08"}, - {file = "Pillow-10.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:98533fd7fa764e5f85eebe56c8e4094db912ccbe6fbf3a58778d543cadd0db08"}, - {file = "Pillow-10.0.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:764d2c0daf9c4d40ad12fbc0abd5da3af7f8aa11daf87e4fa1b834000f4b6b0a"}, - {file = "Pillow-10.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fcb59711009b0168d6ee0bd8fb5eb259c4ab1717b2f538bbf36bacf207ef7a68"}, - {file = "Pillow-10.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:697a06bdcedd473b35e50a7e7506b1d8ceb832dc238a336bd6f4f5aa91a4b500"}, - {file = "Pillow-10.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f665d1e6474af9f9da5e86c2a3a2d2d6204e04d5af9c06b9d42afa6ebde3f21"}, - {file = "Pillow-10.0.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:2fa6dd2661838c66f1a5473f3b49ab610c98a128fc08afbe81b91a1f0bf8c51d"}, - {file = "Pillow-10.0.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:3a04359f308ebee571a3127fdb1bd01f88ba6f6fb6d087f8dd2e0d9bff43f2a7"}, - {file = 
"Pillow-10.0.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:723bd25051454cea9990203405fa6b74e043ea76d4968166dfd2569b0210886a"}, - {file = "Pillow-10.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:71671503e3015da1b50bd18951e2f9daf5b6ffe36d16f1eb2c45711a301521a7"}, - {file = "Pillow-10.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:44e7e4587392953e5e251190a964675f61e4dae88d1e6edbe9f36d6243547ff3"}, - {file = "Pillow-10.0.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:3855447d98cced8670aaa63683808df905e956f00348732448b5a6df67ee5849"}, - {file = "Pillow-10.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ed2d9c0704f2dc4fa980b99d565c0c9a543fe5101c25b3d60488b8ba80f0cce1"}, - {file = "Pillow-10.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5bb289bb835f9fe1a1e9300d011eef4d69661bb9b34d5e196e5e82c4cb09b37"}, - {file = "Pillow-10.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a0d3e54ab1df9df51b914b2233cf779a5a10dfd1ce339d0421748232cea9876"}, - {file = "Pillow-10.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:2cc6b86ece42a11f16f55fe8903595eff2b25e0358dec635d0a701ac9586588f"}, - {file = "Pillow-10.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ca26ba5767888c84bf5a0c1a32f069e8204ce8c21d00a49c90dabeba00ce0145"}, - {file = "Pillow-10.0.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f0b4b06da13275bc02adfeb82643c4a6385bd08d26f03068c2796f60d125f6f2"}, - {file = "Pillow-10.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bc2e3069569ea9dbe88d6b8ea38f439a6aad8f6e7a6283a38edf61ddefb3a9bf"}, - {file = "Pillow-10.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8b451d6ead6e3500b6ce5c7916a43d8d8d25ad74b9102a629baccc0808c54971"}, - {file = "Pillow-10.0.1-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:32bec7423cdf25c9038fef614a853c9d25c07590e1a870ed471f47fb80b244db"}, - {file = 
"Pillow-10.0.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7cf63d2c6928b51d35dfdbda6f2c1fddbe51a6bc4a9d4ee6ea0e11670dd981e"}, - {file = "Pillow-10.0.1-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f6d3d4c905e26354e8f9d82548475c46d8e0889538cb0657aa9c6f0872a37aa4"}, - {file = "Pillow-10.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:847e8d1017c741c735d3cd1883fa7b03ded4f825a6e5fcb9378fd813edee995f"}, - {file = "Pillow-10.0.1-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:7f771e7219ff04b79e231d099c0a28ed83aa82af91fd5fa9fdb28f5b8d5addaf"}, - {file = "Pillow-10.0.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459307cacdd4138edee3875bbe22a2492519e060660eaf378ba3b405d1c66317"}, - {file = "Pillow-10.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b059ac2c4c7a97daafa7dc850b43b2d3667def858a4f112d1aa082e5c3d6cf7d"}, - {file = "Pillow-10.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:d6caf3cd38449ec3cd8a68b375e0c6fe4b6fd04edb6c9766b55ef84a6e8ddf2d"}, - {file = "Pillow-10.0.1.tar.gz", hash = "sha256:d72967b06be9300fed5cfbc8b5bafceec48bf7cdc7dab66b1d2549035287191d"}, + {file = "Pillow-10.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1ab05f3db77e98f93964697c8efc49c7954b08dd61cff526b7f2531a22410106"}, + {file = "Pillow-10.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6932a7652464746fcb484f7fc3618e6503d2066d853f68a4bd97193a3996e273"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f63b5a68daedc54c7c3464508d8c12075e56dcfbd42f8c1bf40169061ae666"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0949b55eb607898e28eaccb525ab104b2d86542a85c74baf3a6dc24002edec2"}, + {file = "Pillow-10.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ae88931f93214777c7a3aa0a8f92a683f83ecde27f65a45f95f22d289a69e593"}, + {file = 
"Pillow-10.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b0eb01ca85b2361b09480784a7931fc648ed8b7836f01fb9241141b968feb1db"}, + {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d27b5997bdd2eb9fb199982bb7eb6164db0426904020dc38c10203187ae2ff2f"}, + {file = "Pillow-10.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7df5608bc38bd37ef585ae9c38c9cd46d7c81498f086915b0f97255ea60c2818"}, + {file = "Pillow-10.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:41f67248d92a5e0a2076d3517d8d4b1e41a97e2df10eb8f93106c89107f38b57"}, + {file = "Pillow-10.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1fb29c07478e6c06a46b867e43b0bcdb241b44cc52be9bc25ce5944eed4648e7"}, + {file = "Pillow-10.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2cdc65a46e74514ce742c2013cd4a2d12e8553e3a2563c64879f7c7e4d28bce7"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50d08cd0a2ecd2a8657bd3d82c71efd5a58edb04d9308185d66c3a5a5bed9610"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:062a1610e3bc258bff2328ec43f34244fcec972ee0717200cb1425214fe5b839"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:61f1a9d247317fa08a308daaa8ee7b3f760ab1809ca2da14ecc88ae4257d6172"}, + {file = "Pillow-10.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a646e48de237d860c36e0db37ecaecaa3619e6f3e9d5319e527ccbc8151df061"}, + {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:47e5bf85b80abc03be7455c95b6d6e4896a62f6541c1f2ce77a7d2bb832af262"}, + {file = "Pillow-10.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a92386125e9ee90381c3369f57a2a50fa9e6aa8b1cf1d9c4b200d41a7dd8e992"}, + {file = "Pillow-10.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0f7c276c05a9767e877a0b4c5050c8bee6a6d960d7f0c11ebda6b99746068c2a"}, + {file = "Pillow-10.1.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = 
"sha256:a89b8312d51715b510a4fe9fc13686283f376cfd5abca8cd1c65e4c76e21081b"}, + {file = "Pillow-10.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:00f438bb841382b15d7deb9a05cc946ee0f2c352653c7aa659e75e592f6fa17d"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d929a19f5469b3f4df33a3df2983db070ebb2088a1e145e18facbc28cae5b27"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a92109192b360634a4489c0c756364c0c3a2992906752165ecb50544c251312"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0248f86b3ea061e67817c47ecbe82c23f9dd5d5226200eb9090b3873d3ca32de"}, + {file = "Pillow-10.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9882a7451c680c12f232a422730f986a1fcd808da0fd428f08b671237237d651"}, + {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1c3ac5423c8c1da5928aa12c6e258921956757d976405e9467c5f39d1d577a4b"}, + {file = "Pillow-10.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:806abdd8249ba3953c33742506fe414880bad78ac25cc9a9b1c6ae97bedd573f"}, + {file = "Pillow-10.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:eaed6977fa73408b7b8a24e8b14e59e1668cfc0f4c40193ea7ced8e210adf996"}, + {file = "Pillow-10.1.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:fe1e26e1ffc38be097f0ba1d0d07fcade2bcfd1d023cda5b29935ae8052bd793"}, + {file = "Pillow-10.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7e3daa202beb61821c06d2517428e8e7c1aab08943e92ec9e5755c2fc9ba5e"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24fadc71218ad2b8ffe437b54876c9382b4a29e030a05a9879f615091f42ffc2"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1d323703cfdac2036af05191b969b910d8f115cf53093125e4058f62012c9a"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = 
"sha256:912e3812a1dbbc834da2b32299b124b5ddcb664ed354916fd1ed6f193f0e2d01"}, + {file = "Pillow-10.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7dbaa3c7de82ef37e7708521be41db5565004258ca76945ad74a8e998c30af8d"}, + {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9d7bc666bd8c5a4225e7ac71f2f9d12466ec555e89092728ea0f5c0c2422ea80"}, + {file = "Pillow-10.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:baada14941c83079bf84c037e2d8b7506ce201e92e3d2fa0d1303507a8538212"}, + {file = "Pillow-10.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:2ef6721c97894a7aa77723740a09547197533146fba8355e86d6d9a4a1056b14"}, + {file = "Pillow-10.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0a026c188be3b443916179f5d04548092e253beb0c3e2ee0a4e2cdad72f66099"}, + {file = "Pillow-10.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:04f6f6149f266a100374ca3cc368b67fb27c4af9f1cc8cb6306d849dcdf12616"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb40c011447712d2e19cc261c82655f75f32cb724788df315ed992a4d65696bb"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a8413794b4ad9719346cd9306118450b7b00d9a15846451549314a58ac42219"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c9aeea7b63edb7884b031a35305629a7593272b54f429a9869a4f63a1bf04c34"}, + {file = "Pillow-10.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b4005fee46ed9be0b8fb42be0c20e79411533d1fd58edabebc0dd24626882cfd"}, + {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4d0152565c6aa6ebbfb1e5d8624140a440f2b99bf7afaafbdbf6430426497f28"}, + {file = "Pillow-10.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d921bc90b1defa55c9917ca6b6b71430e4286fc9e44c55ead78ca1a9f9eba5f2"}, + {file = "Pillow-10.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cfe96560c6ce2f4c07d6647af2d0f3c54cc33289894ebd88cfbb3bcd5391e256"}, + {file = 
"Pillow-10.1.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:937bdc5a7f5343d1c97dc98149a0be7eb9704e937fe3dc7140e229ae4fc572a7"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c25762197144e211efb5f4e8ad656f36c8d214d390585d1d21281f46d556ba"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:afc8eef765d948543a4775f00b7b8c079b3321d6b675dde0d02afa2ee23000b4"}, + {file = "Pillow-10.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:883f216eac8712b83a63f41b76ddfb7b2afab1b74abbb413c5df6680f071a6b9"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b920e4d028f6442bea9a75b7491c063f0b9a3972520731ed26c83e254302eb1e"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c41d960babf951e01a49c9746f92c5a7e0d939d1652d7ba30f6b3090f27e412"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1fafabe50a6977ac70dfe829b2d5735fd54e190ab55259ec8aea4aaea412fa0b"}, + {file = "Pillow-10.1.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3b834f4b16173e5b92ab6566f0473bfb09f939ba14b23b8da1f54fa63e4b623f"}, + {file = "Pillow-10.1.0.tar.gz", hash = "sha256:e6bf8de6c36ed96c86ea3b6e1d5273c53f46ef518a062464cd7ef5dd2cf92e38"}, ] [package.extras] diff --git a/src/salamander/nmf_framework/_utils_corrnmf.py b/src/salamander/nmf_framework/_utils_corrnmf.py index 910d2e9..b4faf5c 100644 --- a/src/salamander/nmf_framework/_utils_corrnmf.py +++ b/src/salamander/nmf_framework/_utils_corrnmf.py @@ -33,42 +33,6 @@ def update_alpha(X: np.ndarray, L: np.ndarray, U: np.ndarray) -> np.ndarray: return alpha -@njit -def update_W(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray: - """ - Compute the new signatures according to the update rule of NMF with - the generalized Kullback-Leibler divergence. 
An additional normalization - step is performed to guarantee that the signatures are probability distributions - over the mutation types. - - Parameters - ---------- - X : np.ndarray of shape (n_features, n_samples) - data matrix - - W : np.ndarray of shape (n_features, n_signatures) - signature matrix - - H : np.ndarray of shape (n_signatures, n_samples) - exposure matrix - - Returns - ------- - W : np.ndarray of shape (n_features, n_signatures) - updated signature matrix - - References - ---------- - D. Lee, H. Seung: Algorithms for Non-negative Matrix Factorization - - Advances in neural information processing systems, 2000 - https://proceedings.neurips.cc/paper_files/paper/2000/file/f9d1152547c0bde01830b7e8bd60024c-Paper.pdf - """ - W *= (X / (W @ H)) @ H.T - W /= np.sum(W, axis=0) - W = W.clip(EPSILON) - return W - - @njit def update_p_unnormalized(W: np.ndarray, H: np.ndarray) -> np.ndarray: """ diff --git a/src/salamander/nmf_framework/_utils_klnmf.py b/src/salamander/nmf_framework/_utils_klnmf.py new file mode 100644 index 0000000..70fe9aa --- /dev/null +++ b/src/salamander/nmf_framework/_utils_klnmf.py @@ -0,0 +1,264 @@ +import numpy as np +from numba import njit +from scipy.special import gammaln + +EPSILON = np.finfo(np.float32).eps + + +@njit(fastmath=True) +def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: + r""" + The generalized Kullback-Leibler divergence + D_KL(X || WH) = \sum_vd X_vd * ln(X_vd / (WH)_vd) - \sum_vd X_vd + \sum_vd (WH)_vd. 
+ + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + result : float + """ + V, D = X.shape + WH = W @ H + result = 0.0 + + for v in range(V): + for d in range(D): + if X[v, d] != 0: + result += X[v, d] * np.log(X[v, d] / WH[v, d]) + result -= X[v, d] + result += WH[v, d] + + return result + + +def samplewise_kl_divergence(X, W, H): + """ + Per sample generalized Kullback-Leibler divergence D_KL(x || Wh). + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + errors : np.ndarray of shape (n_samples,) + """ + X_data = np.copy(X).astype(float) + indices = X == 0 + X_data[indices] = EPSILON + WH_data = W @ H + WH_data[indices] = EPSILON + + s1 = np.einsum("vd,vd->d", X_data, np.log(X_data / WH_data)) + s2 = -np.sum(X, axis=0) + s3 = np.dot(H.T, np.sum(W, axis=0)) + + errors = s1 + s2 + s3 + + return errors + + +@njit(fastmath=True) +def _poisson_llh_wo_factorial(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: + """ + The Poisson log-likelihood generalized to X, W and H having + non-negative real numbers without the summands involving the log-factorial + of elements of X. + Note: + scipy-special, which is required to compute the log-factorial, + is not supported by numba. 
+ + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + result : float + """ + V, D = X.shape + WH = W @ H + result = 0.0 + + for v in range(V): + for d in range(D): + if WH[v, d] != 0: + result += X[v, d] * np.log(WH[v, d]) + result -= WH[v, d] + + return result + + +def poisson_llh(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: + """ + The Poisson log-likelihood generalized to X, W and H having + non-negative real numbers. + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + result : float + """ + result = _poisson_llh_wo_factorial(X, W, H) + result -= np.sum(gammaln(1 + X)) + + return result + + +@njit +def update_W( + X: np.ndarray, W: np.ndarray, H: np.ndarray, n_given_signatures: int = 0 +) -> np.ndarray: + """ + The multiplicative update rule of the signature matrix W + under the constraint of normalized signatures. It can be shown + that the generalized KL-divergence D_KL(X || WH) is decreasing + under the implemented update rule. + + Clipping the matrix avoids floating point errors. + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + n_given_signatures : int + The number of known signatures, which will not be updated. 
+ + Returns + ------- + W : np.ndarray of shape (n_features, n_signatures) + updated signature matrix + """ + W_updated = W * ((X / (W @ H)) @ H.T) + W_updated /= W_updated.sum(axis=0) + W_updated[:, :n_given_signatures] = W[:, :n_given_signatures].copy() + W_updated[:, n_given_signatures:] = W_updated[:, n_given_signatures:].clip(EPSILON) + + return W_updated + + +@njit +def update_H(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray: + """ + The multiplicative update rule of the exposure matrix H + under the constraint of normalized signatures. + + Clipping the matrix avoids floating point errors. + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + Returns + ------- + H : np.ndarray of shape (n_signatures, n_samples) + updated exposure matrix + + Reference + --------- + D. Lee, H. Seung: Algorithms for Non-negative Matrix Factorization + - Advances in neural information processing systems, 2000 + https://proceedings.neurips.cc/paper_files/paper/2000/file/f9d1152547c0bde01830b7e8bd60024c-Paper.pdf + """ + H *= W.T @ (X / (W @ H)) + H = H.clip(EPSILON) + + return H + + +@njit +def update_WH( + X: np.ndarray, W: np.ndarray, H: np.ndarray, n_given_signatures: int = 0 +) -> np.ndarray: + """ + A joint update rule for the signature matrix W and + the exposure matrix H under the constraint of normalized + signatures. + + Clipping the matrix avoids floating point errors. + + Parameters + ---------- + X : np.ndarray of shape (n_features, n_samples) + data matrix + + W : np.ndarray of shape (n_features, n_signatures) + signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + exposure matrix + + n_given_signatures : int + The number of known signatures, which will not be updated. 
+ + Returns + ------- + W : np.ndarray of shape (n_features, n_signatures) + updated signature matrix + + H : np.ndarray of shape (n_signatures, n_samples) + updated exposure matrix + """ + n_signatures = W.shape[1] + aux = X / (W @ H) + + if n_given_signatures < n_signatures: + # the old signatures are needed for updating H + W_updated = W * (aux @ H.T) + W_updated /= np.sum(W_updated, axis=0) + W_updated[:, :n_given_signatures] = W[:, :n_given_signatures].copy() + W_updated = W_updated.clip(EPSILON) + else: + W_updated = W + + H *= W.T @ aux + H = H.clip(EPSILON) + + return W_updated, H diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index 40ebc1d..332eebd 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -5,14 +5,8 @@ from scipy.spatial.distance import squareform from scipy.special import gammaln -from ..utils import ( - kl_divergence, - match_signatures_pair, - poisson_llh, - samplewise_kl_divergence, - shape_checker, - type_checker, -) +from ..utils import match_signatures_pair, shape_checker, type_checker +from ._utils_klnmf import kl_divergence, poisson_llh, samplewise_kl_divergence from .initialization import ( init_custom, init_flat, @@ -180,7 +174,6 @@ def signatures(self) -> pd.DataFrame: signatures = pd.DataFrame( self.W, index=self.mutation_types, columns=self.signature_names ) - return signatures @property @@ -194,7 +187,6 @@ def exposures(self) -> pd.DataFrame: index=self.signature_names, columns=self.sample_names, ) - return exposures @property @@ -380,6 +372,9 @@ def _initialize( """ if given_signatures is not None: self._check_given_signatures(given_signatures) + self.n_given_signatures = len(given_signatures.columns) + else: + self.n_given_signatures = 0 if given_signature_embeddings is not None: self._check_given_signature_embeddings(given_signature_embeddings) @@ -407,8 +402,20 @@ def _initialize( self.W = init_separableNMF(self.X, 
self.n_signatures) if given_signatures is not None: - self.W = given_signatures.copy().values - self.signature_names = given_signatures.columns.to_numpy(dtype=str) + self.W[:, : self.n_given_signatures] = given_signatures.copy().values + given_signatures_names = given_signatures.columns.to_numpy(dtype=" float: - return kl_divergence(self.X, self.W, self.H) + return _utils_klnmf.kl_divergence(self.X, self.W, self.H) @property def samplewise_reconstruction_error(self) -> np.ndarray: - return samplewise_kl_divergence(self.X, self.W, self.H) + return _utils_klnmf.samplewise_kl_divergence(self.X, self.W, self.H) def objective_function(self) -> float: return self.reconstruction_error @@ -67,13 +76,23 @@ def objective(self) -> str: return "minimize" def loglikelihood(self) -> float: - return poisson_llh(self.X, self.W, self.H) + return _utils_klnmf.poisson_llh(self.X, self.W, self.H) def _update_W(self): - self.W = update_W(self.X, self.W, self.H) + self.W = _utils_klnmf.update_W(self.X, self.W, self.H, self.n_given_signatures) def _update_H(self): - self.H = update_H(self.X, self.W, self.H) + self.H = _utils_klnmf.update_H(self.X, self.W, self.H) + + def _update_WH(self): + if self.update_method == "mu-standard": + self._update_H() + if self.n_given_signatures < self.n_signatures: + self._update_W() + else: + self.W, self.H = _utils_klnmf.update_WH( + self.X, self.W, self.H, self.n_given_signatures + ) def fit( self, @@ -86,29 +105,34 @@ def fit( """ Minimize the generalized Kullback-Leibler divergence D_KL(X || WH) between the mutation count matrix X and product of the signature matrix W and - exposure matrix H by altering the multiplicative update steps for W and H. + exposure matrix H under the constraint of normalized signatures. - Input: - ------ - data: pd.DataFrame + Parameters + ---------- + data : pd.DataFrame The mutation count data - given_signatures: pd.DataFrame, default=None - In the case of refitting, a priori known signatures have to be provided. 
The - number of signatures has to match to the NMF object and the mutation type - names have to match to the mutation count matrix + given_signatures : pd.DataFrame, default=None + Known signatures that should be fixed by the algorithm. + The number of known signatures can be less or equal to the + number of signatures specified in the algorithm instance. - init_kwargs: dict + init_kwargs : dict, default=None Any further keyword arguments to be passed to the initialization method. This includes, for example, a possible 'seed' keyword argument for all stochastic methods. - history: bool - When set to true, the history of the objective function - will be stored in a dictionary. + history : bool, default=False + If True, the objective function value will be stored after every + iteration. - verbose: int - Every 100th iteration number will be printed when set unequal to zero. + verbose : int, default=0 + verbosity level + + Returns + ------- + self : object + Returns the instance itself. """ self._setup_data_parameters(data) self._initialize(given_signatures, init_kwargs) @@ -122,14 +146,7 @@ def fit( if verbose and n_iteration % 100 == 0: print(f"iteration {n_iteration}") - self._update_H() - - if given_signatures is None: - self._update_W() - - self.W, self.H = normalize_WH(self.W, self.H) - self.W, self.H = self.W.clip(EPSILON), self.H.clip(EPSILON) - + self._update_WH() prev_of_value = of_values[-1] of_values.append(self.objective_function()) rel_change = (prev_of_value - of_values[-1]) / prev_of_value diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index 749c9dd..7d2aa82 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -183,9 +183,9 @@ def _update_sigma_sq(self): for model in self.models: model.sigma_sq = sigma_sq - def _update_Ws(self, given_signatures): - for model, given_sigs in zip(self.models, given_signatures): - if 
given_sigs is None: + def _update_Ws(self): + for model in self.models: + if model.n_given_signatures < model.n_signatures: model._update_W() def _update_ps(self): @@ -353,17 +353,15 @@ def _initialize( given_signatures, given_signature_embeddings, ): - if given_sigs is None: - model.signature_names = np.char.add( - modality_name + " ", model.signature_names - ) - model._initialize( given_signatures=given_sigs, given_signature_embeddings=given_sig_embs, given_sample_embeddings=U, init_kwargs=init_kwargs, ) + model.signature_names[model.n_given_signatures :] = np.char.add( + modality_name + " ", model.signature_names[model.n_given_signatures :] + ) def fit( self, @@ -404,7 +402,7 @@ def fit( ps = self._update_ps() self._update_LsU(ps, given_signature_embeddings, given_sample_embeddings) self._update_sigma_sq() - self._update_Ws(given_signatures) + self._update_Ws() of_values.append(self.objective_function()) prev_sof_value = sof_values[-1] diff --git a/src/salamander/nmf_framework/mvnmf.py b/src/salamander/nmf_framework/mvnmf.py index edf367e..9c9bee0 100644 --- a/src/salamander/nmf_framework/mvnmf.py +++ b/src/salamander/nmf_framework/mvnmf.py @@ -2,7 +2,8 @@ import pandas as pd from numba import njit -from ..utils import kl_divergence, normalize_WH, poisson_llh, samplewise_kl_divergence +from ..utils import normalize_WH +from ._utils_klnmf import kl_divergence, poisson_llh, samplewise_kl_divergence, update_H from .nmf import NMF EPSILON = np.finfo(np.float32).eps @@ -28,25 +29,14 @@ def kl_divergence_penalized( return loss -@njit -def update_H(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray: - """ - The multiplicative update rule of the exposure matrix H - derived by Lee and Seung. See Theorem 2 in - "Algorithms for non-negative matrix factorization". - - Clipping the matrix avoids floating point errors. 
- """ - H *= W.T @ (X / (W @ H)) - H /= np.sum(W, axis=0)[:, np.newaxis] - H = H.clip(EPSILON) - - return H - - @njit def update_W_unconstrained( - X: np.ndarray, W: np.ndarray, H: np.ndarray, lam: float, delta: float + X: np.ndarray, + W: np.ndarray, + H: np.ndarray, + lam: float, + delta: float, + n_given_signatures: int = 0, ) -> np.ndarray: n_signatures = W.shape[1] diag = np.diag(np.full(n_signatures, delta)) @@ -64,7 +54,10 @@ def update_W_unconstrained( numerator = numerator_s1 + numerator_s2 denominator = 4 * lam * WY_abs W_unconstrained = W * numerator / denominator - W_unconstrained = W_unconstrained.clip(EPSILON) + W_unconstrained[:, :n_given_signatures] = W[:, :n_given_signatures].copy() + W_unconstrained[:, n_given_signatures:] = W_unconstrained[ + :, n_given_signatures: + ].clip(EPSILON) return W_unconstrained @@ -98,11 +91,9 @@ def line_search( class MvNMF(NMF): """ - Min-volume non-negative matrix factorization. See Algorithm 1 in - - Leplat, V., Gillis, N. and Ang, A.M., 2020. - Blind audio source separation with minimum-volume beta-divergence NMF. - IEEE Transactions on Signal Processing, 68, pp.3400-3410. + Min-volume non-negative matrix factorization. This algorithms is a volume- + regularized version of NMF with the generalized Kullback-Leibler (KL) + divergence. Parameters ---------- @@ -128,6 +119,12 @@ class MvNMF(NMF): tol : float, default=1e-7 Tolerance of the stopping condition. + + Reference + --------- + Leplat, V., Gillis, N. and Ang, A.M., 2020. + Blind audio source separation with minimum-volume beta-divergence NMF. + IEEE Transactions on Signal Processing, 68, pp.3400-3410. 
""" def __init__( @@ -167,7 +164,9 @@ def _update_H(self): self.H = update_H(self.X, self.W, self.H) def _update_W_unconstrained(self): - return update_W_unconstrained(self.X, self.W, self.H, self.lam, self.delta) + return update_W_unconstrained( + self.X, self.W, self.H, self.lam, self.delta, self.n_given_signatures + ) def _line_search(self, W_unconstrained): self.W, self.H, self._gamma = line_search( @@ -232,7 +231,7 @@ def fit( self._update_H() - if given_signatures is None: + if self.n_given_signatures < self.n_signatures: self._update_W() prev_of_value = of_values[-1] diff --git a/src/salamander/nmf_framework/nmf.py b/src/salamander/nmf_framework/nmf.py index e884e22..d83a33b 100644 --- a/src/salamander/nmf_framework/nmf.py +++ b/src/salamander/nmf_framework/nmf.py @@ -184,6 +184,9 @@ def _initialize(self, given_signatures=None, init_kwargs=None): """ if given_signatures is not None: self._check_given_signatures(given_signatures) + self.n_given_signatures = len(given_signatures.columns) + else: + self.n_given_signatures = 0 init_kwargs = {} if init_kwargs is None else init_kwargs.copy() @@ -205,8 +208,20 @@ def _initialize(self, given_signatures=None, init_kwargs=None): self.W = init_separableNMF(self.X, self.n_signatures) if given_signatures is not None: - self.W = given_signatures.copy().values - self.signature_names = given_signatures.columns.to_numpy(dtype=str) + self.W[:, : self.n_given_signatures] = given_signatures.copy().values + given_signatures_names = given_signatures.columns.to_numpy(dtype=str) + n_new_signatures = self.n_signatures - self.n_given_signatures + new_signatures_names = np.array( + [f"Sig{k+1}" for k in range(n_new_signatures)] + ) + self.signature_names = np.concatenate( + [given_signatures_names, new_signatures_names] + ) + + else: + self.signature_names = np.array( + [f"Sig{k+1}" for k in range(self.n_signatures)] + ) if not hasattr(self, "H"): _, self.H = init_random(self.X, self.n_signatures) diff --git 
a/src/salamander/nmf_framework/signature_nmf.py b/src/salamander/nmf_framework/signature_nmf.py index 0b2c6e7..8aca3ac 100644 --- a/src/salamander/nmf_framework/signature_nmf.py +++ b/src/salamander/nmf_framework/signature_nmf.py @@ -145,7 +145,7 @@ def __init__( value_checker("init_method", init_method, init_methods) self.n_signatures = n_signatures - self.signature_names = np.array([f"Sig{k+1}" for k in range(n_signatures)]) + self.signature_names = None self.init_method = init_method self.min_iterations = min_iterations self.max_iterations = max_iterations @@ -154,6 +154,7 @@ def __init__( # initialize data/fitting dependent attributes self.X = None self.n_features = 0 + self.n_given_signatures = 0 self.n_samples = 0 self.mutation_types = np.empty(0, dtype=str) self.sample_names = np.empty(0, dtype=str) @@ -240,21 +241,23 @@ def _check_given_signatures(self, given_signatures: pd.DataFrame): """ Check if the given signatures are compatible with the number of signatures of the algorithm and the - mutation types of the input data and. + mutation types of the input data. given_signatures: pd.DataFrame Known signatures that should be fixed by the algorithm. + The number of known signatures can be less or equal to the + number of signatures specified in the algorithm instance. """ type_checker("given_signatures", given_signatures, pd.DataFrame) given_mutation_types = given_signatures.index.to_numpy(dtype=str) compatible = ( np.array_equal(given_mutation_types, self.mutation_types) - and given_signatures.shape[1] == self.n_signatures + and given_signatures.shape[1] <= self.n_signatures ) if not compatible: raise ValueError( - f"You have to provide {self.n_signatures} signatures with " + f"You have to provide at most {self.n_signatures} signatures with " f"mutation types matching to your data." ) @@ -271,7 +274,6 @@ def _initialize(self): decompose the mutation count matrix X into a signature matrix W and an exposure matrix H, both W and H have to be initialized. 
""" - pass def _setup_data_parameters(self, data: pd.DataFrame): """ @@ -301,13 +303,11 @@ def fit(self, data: pd.DataFrame, given_signatures=None): The named mutation count data of shape (n_features, n_samples). given_signatures: pd.DataFrame, by default None - In the case of refitting, 'given_signatures' - are the a priori known signatures. - The number of signatures has to match to the NMF algorithm - instance and the mutation type names have to match to the names - of the mutation count data. + A priori known signatures. The number of given signatures has + to be less or equal to the number of signatures of NMF + algorithm instance, and the mutation type names have to match + the mutation types of the count data. """ - pass @salamander_style def plot_signatures( @@ -321,11 +321,6 @@ def plot_signatures( ): """ Plot the signatures, see plot.py for the implementation of signatures_plot. - - Input: - ------ - **kwargs: - arguments to be passed to signatures_plot """ axes = signatures_plot( self.signatures, @@ -355,11 +350,6 @@ def plot_exposures( """ Visualize the exposures as a stacked bar chart, see plot.py for the implementation. - - Input: - ------ - **kwargs: - arguments to be passed to exposure_plot """ ax = exposures_plot( exposures=self.exposures, @@ -383,7 +373,6 @@ def corr_signatures(self) -> pd.DataFrame: Every child class of SignatureNMF has to implement a function that returns the signature correlation matrix as a pandas dataframe. """ - pass @property @abstractmethod @@ -392,7 +381,6 @@ def corr_samples(self) -> pd.DataFrame: Every child class of SignatureNMF has to implement a function that returns the sample correlation matrix as a pandas dataframe. 
""" - pass def plot_correlation(self, data="signatures", annot=False, outfile=None, **kwargs): """ diff --git a/src/salamander/utils.py b/src/salamander/utils.py index 7fd2efa..ecfc0a6 100644 --- a/src/salamander/utils.py +++ b/src/salamander/utils.py @@ -2,7 +2,6 @@ import pandas as pd from numba import njit from scipy.optimize import linear_sum_assignment -from scipy.special import gammaln from sklearn.metrics import pairwise_distances EPSILON = np.finfo(np.float32).eps @@ -69,79 +68,6 @@ def value_checker(arg_name: str, arg, allowed_values): ) -@njit(fastmath=True) -def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: - r""" - The generalized Kullback-Leibler divergence - D_KL(X || WH) = \sum_vd X_vd * ln(X_vd / (WH)_vd) - \sum_vd X_vd + \sum_vd (WH)_vd. - """ - V, D = X.shape - WH = W @ H - result = 0.0 - - for v in range(V): - for d in range(D): - if X[v, d] != 0: - result += X[v, d] * np.log(X[v, d] / WH[v, d]) - result -= X[v, d] - result += WH[v, d] - - return result - - -def samplewise_kl_divergence(X, W, H): - """ - Per sample generalizedKullback-Leibler divergence D_KL(x || Wh). - """ - X_data = np.copy(X).astype(float) - indices = X == 0 - X_data[indices] = EPSILON - WH_data = W @ H - WH_data[indices] = EPSILON - - s1 = np.einsum("vd,vd->d", X_data, np.log(X_data / WH_data)) - s2 = -np.sum(X, axis=0) - s3 = np.dot(H.T, np.sum(W, axis=0)) - - errors = s1 + s2 + s3 - - return errors - - -@njit(fastmath=True) -def _poisson_llh_wo_factorial(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: - """ - The Poisson log-likelihood generalized to X, W and H having - non-negative real numbers without the summands involving the log-factorial - of elements of X. - Note: - scipy-special, which is required to computed the log-factorial, - is not supported by numba. 
- """ - V, D = X.shape - WH = W @ H - result = 0.0 - - for v in range(V): - for d in range(D): - if WH[v, d] != 0: - result += X[v, d] * np.log(WH[v, d]) - result -= WH[v, d] - - return result - - -def poisson_llh(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float: - """ - The Poisson log-likelihood generalized to X, W and H having - non-negative real numbers. - """ - result = _poisson_llh_wo_factorial(X, W, H) - result -= np.sum(gammaln(1 + X)) - - return result - - @njit def normalize_WH(W, H): normalization_factor = np.sum(W, axis=0) diff --git a/tests/test_corrnmf.py b/tests/test_corrnmf.py index 9a8899d..cbc3353 100644 --- a/tests/test_corrnmf.py +++ b/tests/test_corrnmf.py @@ -10,36 +10,37 @@ @pytest.fixture def counts(): - return pd.read_csv(f"{PATH}/nmf_framework/counts.csv", index_col=0) + return pd.read_csv(f"{PATH_TEST_DATA}/counts.csv", index_col=0) -@pytest.fixture(params=[(1, 1), (2, 2)]) -def model(request): - param = request.param - return corrnmf_det.CorrNMFDet(n_signatures=param[0], dim_embeddings=param[1]) +@pytest.fixture(params=[1, 2]) +def n_signatures(request): + return request.param @pytest.fixture -def path(model): - return ( - f"{PATH_TEST_DATA}/" - f"corrnmf_nsigs{model.n_signatures}_dim{model.dim_embeddings}" - ) +def dim_embeddings(n_signatures): + return n_signatures + + +@pytest.fixture +def path_suffix(n_signatures, dim_embeddings): + return f"nsigs{n_signatures}_dim{dim_embeddings}.npy" @pytest.fixture -def W_init(path): - return np.load(f"{path}_W_init.npy") +def W_init(path_suffix): + return np.load(f"{PATH_TEST_DATA}/W_init_{path_suffix}") @pytest.fixture -def alpha_init(path): - return np.load(f"{path}_alpha_init.npy") +def alpha_init(path_suffix): + return np.load(f"{PATH_TEST_DATA}/alpha_init_{path_suffix}") @pytest.fixture -def _p(path): - return np.load(f"{path}_p.npy") +def _p(path_suffix): + return np.load(f"{PATH_TEST_DATA}/p_{path_suffix}") @pytest.fixture @@ -48,22 +49,26 @@ def _aux(counts, _p): @pytest.fixture -def 
L_init(path): - return np.load(f"{path}_L_init.npy") +def L_init(path_suffix): + return np.load(f"{PATH_TEST_DATA}/L_init_{path_suffix}") @pytest.fixture -def U_init(path): - return np.load(f"{path}_U_init.npy") +def U_init(path_suffix): + return np.load(f"{PATH_TEST_DATA}/U_init_{path_suffix}") @pytest.fixture -def sigma_sq_init(path): - return np.load(f"{path}_sigma_sq_init.npy") +def sigma_sq_init(path_suffix): + return np.load(f"{PATH_TEST_DATA}/sigma_sq_init_{path_suffix}") @pytest.fixture -def model_init(model, counts, W_init, alpha_init, L_init, U_init, sigma_sq_init): +def model_init(counts, W_init, alpha_init, L_init, U_init, sigma_sq_init): + n_signatures, dim_embeddings = L_init.shape + model = corrnmf_det.CorrNMFDet( + n_signatures=n_signatures, dim_embeddings=dim_embeddings + ) model.X = counts.values model.W = W_init model.alpha = alpha_init @@ -71,7 +76,7 @@ def model_init(model, counts, W_init, alpha_init, L_init, U_init, sigma_sq_init) model.U = U_init model.sigma_sq = sigma_sq_init model.mutation_types = counts.index - model.signature_names = ["_" for _ in range(model.n_signatures)] + model.signature_names = ["_" for _ in range(n_signatures)] model.sample_names = counts.columns model.n_samples = len(counts.columns) model.given_signature_embeddings = None @@ -79,52 +84,52 @@ def model_init(model, counts, W_init, alpha_init, L_init, U_init, sigma_sq_init) @pytest.fixture -def objective_init(path): - return np.load(f"{path}_objective_init.npy") +def objective_init(path_suffix): + return np.load(f"{PATH_TEST_DATA}/objective_init_{path_suffix}") -@pytest.fixture -def surrogate_objective_init(path): - return np.load(f"{path}_surrogate_objective_init.npy") +def test_objective_function(model_init, objective_init): + assert np.allclose(model_init.objective_function(), objective_init) @pytest.fixture -def W_updated(path): - return np.load(f"{path}_W_updated.npy") +def surrogate_objective_init(path_suffix): + return 
np.load(f"{PATH_TEST_DATA}/surrogate_objective_init_{path_suffix}") + + +def test_surrogate_objective_function(model_init, _p, surrogate_objective_init): + assert np.allclose( + model_init._surrogate_objective_function(_p), surrogate_objective_init + ) @pytest.fixture -def alpha_updated(path): - return np.load(f"{path}_alpha_updated.npy") +def W_updated(path_suffix): + return np.load(f"{PATH_TEST_DATA}/W_updated_{path_suffix}") @pytest.fixture -def L_updated(path): - return np.load(f"{path}_L_updated.npy") +def alpha_updated(path_suffix): + return np.load(f"{PATH_TEST_DATA}/alpha_updated_{path_suffix}") @pytest.fixture -def U_updated(path): - return np.load(f"{path}_U_updated.npy") +def L_updated(path_suffix): + return np.load(f"{PATH_TEST_DATA}/L_updated_{path_suffix}") @pytest.fixture -def sigma_sq_updated(path): - return np.load(f"{path}_sigma_sq_updated.npy") +def U_updated(path_suffix): + return np.load(f"{PATH_TEST_DATA}/U_updated_{path_suffix}") -class TestCorrNMFDet: - def test_objective_function(self, model_init, objective_init): - assert np.allclose(model_init.objective_function(), objective_init) +@pytest.fixture +def sigma_sq_updated(path_suffix): + return np.load(f"{PATH_TEST_DATA}/sigma_sq_updated_{path_suffix}") - def test_surrogate_objective_function( - self, model_init, _p, surrogate_objective_init - ): - assert np.allclose( - model_init._surrogate_objective_function(_p), surrogate_objective_init - ) - def test_update_W_Lee(self, model_init, W_updated): +class TestUpdatesCorrNMFDet: + def test_update_W(self, model_init, W_updated): model_init._update_W() assert np.allclose(model_init.W, W_updated) @@ -141,54 +146,54 @@ def test_update_L(self, model_init, _aux, L_updated): assert np.allclose(model_init.L, L_updated) def test_update_U(self, model_init, _aux, U_updated): - print("\n\n", "BEFORE UPDATE", model_init.U, "\n\n") model_init._update_U(_aux) - print("UPDATED", model_init.U, "\n\n") - print("SOLUTION", U_updated, "\n\n") - print("DIFFERENCE", 
np.sum(np.abs(model_init.U - U_updated))) assert np.allclose(model_init.U, U_updated) def test_update_sigma_sq(self, model_init, sigma_sq_updated): model_init._update_sigma_sq() assert np.allclose(model_init.sigma_sq, sigma_sq_updated) - -@pytest.mark.parametrize("n_signatures", [1, 2]) -def test_given_signatures(counts, n_signatures): - given_signatures = counts.iloc[:, :n_signatures].astype(float).copy() - given_signatures /= given_signatures.sum(axis=0) - model = corrnmf_det.CorrNMFDet( - n_signatures=n_signatures, - dim_embeddings=n_signatures, - min_iterations=3, - max_iterations=3, - ) - model.fit(counts, given_signatures=given_signatures) - assert np.allclose(given_signatures, model.signatures) - - -@pytest.mark.parametrize("n_signatures,dim_embeddings", [(1, 1), (2, 1), (2, 2)]) -def test_given_signature_embeddings(counts, n_signatures, dim_embeddings): - given_signature_embeddings = np.random.uniform(size=(dim_embeddings, n_signatures)) - model = corrnmf_det.CorrNMFDet( - n_signatures=n_signatures, - dim_embeddings=dim_embeddings, - min_iterations=3, - max_iterations=3, - ) - model.fit(counts, given_signature_embeddings=given_signature_embeddings) - assert np.allclose(given_signature_embeddings, model.L) - - -@pytest.mark.parametrize("n_signatures,dim_embeddings", [(1, 1), (2, 1), (2, 2)]) -def test_given_sample_embeddings(counts, n_signatures, dim_embeddings): - n_samples = len(counts.columns) - given_sample_embeddings = np.random.uniform(size=(dim_embeddings, n_samples)) - model = corrnmf_det.CorrNMFDet( - n_signatures=n_signatures, - dim_embeddings=dim_embeddings, - min_iterations=3, - max_iterations=3, - ) - model.fit(counts, given_sample_embeddings=given_sample_embeddings) - assert np.allclose(given_sample_embeddings, model.U) + def test_given_signatures(self, n_signatures, counts): + for n_given_signatures in range(1, n_signatures + 1): + given_signatures = counts.iloc[:, :n_given_signatures].astype(float).copy() + given_signatures /= 
given_signatures.sum(axis=0) + model = corrnmf_det.CorrNMFDet( + n_signatures=n_signatures, + dim_embeddings=n_signatures, + min_iterations=3, + max_iterations=3, + ) + model.fit(counts, given_signatures=given_signatures) + assert np.allclose( + given_signatures, model.signatures.iloc[:, :n_given_signatures] + ) + + def test_given_signature_embeddings(self, n_signatures, counts): + for dim_embeddings in range(1, n_signatures + 1): + given_signature_embeddings = np.random.uniform( + size=(dim_embeddings, n_signatures) + ) + model = corrnmf_det.CorrNMFDet( + n_signatures=n_signatures, + dim_embeddings=dim_embeddings, + min_iterations=3, + max_iterations=3, + ) + model.fit(counts, given_signature_embeddings=given_signature_embeddings) + assert np.allclose(given_signature_embeddings, model.L) + + def test_given_sample_embeddings(self, n_signatures, counts): + n_samples = len(counts.columns) + + for dim_embeddings in range(1, n_signatures + 1): + given_sample_embeddings = np.random.uniform( + size=(dim_embeddings, n_samples) + ) + model = corrnmf_det.CorrNMFDet( + n_signatures=n_signatures, + dim_embeddings=dim_embeddings, + min_iterations=3, + max_iterations=3, + ) + model.fit(counts, given_sample_embeddings=given_sample_embeddings) + assert np.allclose(given_sample_embeddings, model.U) diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_L_init.npy b/tests/test_data/nmf_framework/corrnmf/L_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_L_init.npy rename to tests/test_data/nmf_framework/corrnmf/L_init_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_L_init.npy b/tests/test_data/nmf_framework/corrnmf/L_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_L_init.npy rename to tests/test_data/nmf_framework/corrnmf/L_init_nsigs2_dim2.npy diff --git 
a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_L_updated.npy b/tests/test_data/nmf_framework/corrnmf/L_updated_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_L_updated.npy rename to tests/test_data/nmf_framework/corrnmf/L_updated_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_L_updated.npy b/tests/test_data/nmf_framework/corrnmf/L_updated_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_L_updated.npy rename to tests/test_data/nmf_framework/corrnmf/L_updated_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_U_init.npy b/tests/test_data/nmf_framework/corrnmf/U_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_U_init.npy rename to tests/test_data/nmf_framework/corrnmf/U_init_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_U_init.npy b/tests/test_data/nmf_framework/corrnmf/U_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_U_init.npy rename to tests/test_data/nmf_framework/corrnmf/U_init_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_U_updated.npy b/tests/test_data/nmf_framework/corrnmf/U_updated_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_U_updated.npy rename to tests/test_data/nmf_framework/corrnmf/U_updated_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_U_updated.npy b/tests/test_data/nmf_framework/corrnmf/U_updated_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_U_updated.npy rename to tests/test_data/nmf_framework/corrnmf/U_updated_nsigs2_dim2.npy diff --git 
a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_init.npy b/tests/test_data/nmf_framework/corrnmf/W_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_init.npy rename to tests/test_data/nmf_framework/corrnmf/W_init_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_init.npy b/tests/test_data/nmf_framework/corrnmf/W_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_init.npy rename to tests/test_data/nmf_framework/corrnmf/W_init_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_updated.npy b/tests/test_data/nmf_framework/corrnmf/W_updated_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_W_updated.npy rename to tests/test_data/nmf_framework/corrnmf/W_updated_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_updated.npy b/tests/test_data/nmf_framework/corrnmf/W_updated_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_W_updated.npy rename to tests/test_data/nmf_framework/corrnmf/W_updated_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_alpha_init.npy b/tests/test_data/nmf_framework/corrnmf/alpha_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_alpha_init.npy rename to tests/test_data/nmf_framework/corrnmf/alpha_init_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_alpha_init.npy b/tests/test_data/nmf_framework/corrnmf/alpha_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_alpha_init.npy rename to tests/test_data/nmf_framework/corrnmf/alpha_init_nsigs2_dim2.npy diff --git 
a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_alpha_updated.npy b/tests/test_data/nmf_framework/corrnmf/alpha_updated_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_alpha_updated.npy rename to tests/test_data/nmf_framework/corrnmf/alpha_updated_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_alpha_updated.npy b/tests/test_data/nmf_framework/corrnmf/alpha_updated_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_alpha_updated.npy rename to tests/test_data/nmf_framework/corrnmf/alpha_updated_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/counts.csv b/tests/test_data/nmf_framework/corrnmf/counts.csv new file mode 100644 index 0000000..05310ed --- /dev/null +++ b/tests/test_data/nmf_framework/corrnmf/counts.csv @@ -0,0 +1,97 @@ +Type,SP9251,SP6730,SP10084,SP5381,SP10635,SP2714,SP11235,SP8085,SP4593,SP4820 +A[C>A]A,94,54,239,28,35,180,78,103,32,112 +A[C>A]C,60,63,199,25,23,129,60,58,24,131 +A[C>A]G,10,10,28,4,8,17,10,7,5,22 +A[C>A]T,72,42,222,17,24,148,43,60,20,106 +C[C>A]A,57,47,163,19,24,159,78,65,17,119 +C[C>A]C,73,35,161,18,13,143,41,33,10,121 +C[C>A]G,13,9,16,6,9,15,10,10,3,20 +C[C>A]T,68,57,189,7,20,173,62,65,14,116 +G[C>A]A,75,69,122,23,22,89,100,95,20,65 +G[C>A]C,45,45,84,14,16,83,27,50,15,92 +G[C>A]G,8,6,13,5,8,16,10,11,5,16 +G[C>A]T,49,66,101,10,13,109,72,63,13,68 +T[C>A]A,65,87,171,19,33,204,99,135,30,84 +T[C>A]C,86,61,132,18,24,202,60,82,33,84 +T[C>A]G,6,12,24,6,11,10,11,16,4,13 +T[C>A]T,91,134,250,24,28,265,151,182,44,100 +A[C>G]A,95,24,95,8,17,150,46,26,10,105 +A[C>G]C,40,19,61,8,7,70,36,19,14,60 +A[C>G]G,11,4,25,0,0,35,3,5,5,23 +A[C>G]T,96,18,118,7,9,145,45,35,16,112 +C[C>G]A,66,15,47,2,10,151,12,17,13,60 +C[C>G]C,52,10,36,3,9,92,23,15,11,77 +C[C>G]G,22,11,14,3,3,39,6,11,2,42 +C[C>G]T,79,19,50,7,7,203,28,31,15,106 +G[C>G]A,36,6,28,3,4,78,18,6,8,68 
+G[C>G]C,30,9,43,3,8,64,22,8,7,61 +G[C>G]G,8,1,6,0,1,14,8,3,1,19 +G[C>G]T,51,8,34,6,10,107,23,18,9,93 +T[C>G]A,119,61,131,8,20,687,68,116,24,93 +T[C>G]C,85,32,80,9,11,448,52,56,13,110 +T[C>G]G,15,11,13,1,1,45,11,14,1,16 +T[C>G]T,239,90,236,14,28,1125,136,214,35,207 +A[C>T]A,126,91,238,24,46,187,80,70,54,140 +A[C>T]C,61,56,103,17,26,83,43,47,35,79 +A[C>T]G,149,272,257,92,119,185,147,168,134,145 +A[C>T]T,92,51,181,18,32,157,33,57,43,119 +C[C>T]A,75,76,112,21,46,140,53,66,40,100 +C[C>T]C,69,67,89,28,39,97,59,59,54,77 +C[C>T]G,93,163,139,45,72,108,110,108,85,88 +C[C>T]T,107,94,185,27,49,220,68,75,56,162 +G[C>T]A,68,75,86,13,37,99,74,58,38,123 +G[C>T]C,46,61,95,22,44,103,80,45,44,79 +G[C>T]G,90,230,176,71,81,155,124,118,102,116 +G[C>T]T,74,55,129,10,22,110,35,49,38,93 +T[C>T]A,139,178,198,28,63,520,136,224,97,106 +T[C>T]C,126,97,155,24,40,341,98,95,79,98 +T[C>T]G,80,128,117,35,68,101,79,103,53,87 +T[C>T]T,152,128,244,26,72,382,109,147,116,137 +A[T>A]A,43,66,115,17,16,86,44,26,13,57 +A[T>A]C,25,19,75,20,24,46,31,21,27,48 +A[T>A]G,37,30,99,10,17,76,22,29,16,64 +A[T>A]T,63,61,168,21,32,120,49,38,18,117 +C[T>A]A,31,32,85,4,9,61,15,16,3,63 +C[T>A]C,32,17,65,2,16,71,19,22,7,101 +C[T>A]G,55,24,105,9,14,108,34,22,9,68 +C[T>A]T,63,39,182,6,9,117,24,23,17,90 +G[T>A]A,22,17,42,5,4,47,10,9,7,40 +G[T>A]C,20,14,39,3,8,38,14,12,10,38 +G[T>A]G,23,16,33,5,14,48,18,15,5,39 +G[T>A]T,41,16,99,2,9,102,18,17,10,71 +T[T>A]A,31,63,124,16,29,122,51,60,18,76 +T[T>A]C,30,21,95,3,10,59,25,21,12,44 +T[T>A]G,19,15,57,2,5,43,9,9,7,39 +T[T>A]T,76,38,240,9,13,146,41,41,13,116 +A[T>C]A,90,101,150,29,49,189,79,57,52,157 +A[T>C]C,42,29,68,10,23,87,36,19,18,85 +A[T>C]G,58,47,85,10,24,108,55,31,26,131 +A[T>C]T,99,95,118,19,46,200,95,58,40,167 +C[T>C]A,39,50,57,7,15,69,44,14,20,73 +C[T>C]C,55,27,92,11,15,139,28,20,13,109 +C[T>C]G,43,37,42,10,14,59,28,25,15,88 +C[T>C]T,59,42,68,5,13,105,56,37,22,113 +G[T>C]A,40,63,79,11,32,101,50,22,12,81 +G[T>C]C,29,27,35,12,11,62,37,18,14,47 
+G[T>C]G,26,41,32,11,9,58,39,17,15,64 +G[T>C]T,57,49,70,14,36,103,55,28,20,73 +T[T>C]A,56,83,106,7,23,93,41,28,23,76 +T[T>C]C,47,52,73,14,16,91,50,24,9,73 +T[T>C]G,25,29,57,4,13,61,18,21,13,54 +T[T>C]T,54,75,92,17,39,133,55,46,27,99 +A[T>G]A,29,25,49,2,13,61,29,18,13,49 +A[T>G]C,12,12,15,4,5,26,11,4,7,35 +A[T>G]G,43,16,40,6,6,74,17,19,6,44 +A[T>G]T,37,24,57,13,13,56,24,15,9,46 +C[T>G]A,18,7,12,5,3,30,10,9,3,29 +C[T>G]C,23,13,21,8,4,42,6,10,3,27 +C[T>G]G,43,11,46,1,6,100,21,25,7,62 +C[T>G]T,20,14,49,8,9,70,14,25,7,64 +G[T>G]A,7,10,17,0,1,27,7,7,5,26 +G[T>G]C,9,6,12,3,6,19,6,3,4,18 +G[T>G]G,25,15,37,11,4,76,16,12,8,57 +G[T>G]T,27,6,24,6,5,63,20,11,5,41 +T[T>G]A,39,25,39,2,9,66,21,8,9,52 +T[T>G]C,19,9,30,3,4,37,15,6,8,31 +T[T>G]G,39,18,73,4,10,86,21,18,4,45 +T[T>G]T,58,38,81,10,20,110,48,38,16,109 diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_objective_init.npy b/tests/test_data/nmf_framework/corrnmf/objective_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_objective_init.npy rename to tests/test_data/nmf_framework/corrnmf/objective_init_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_objective_init.npy b/tests/test_data/nmf_framework/corrnmf/objective_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_objective_init.npy rename to tests/test_data/nmf_framework/corrnmf/objective_init_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_p.npy b/tests/test_data/nmf_framework/corrnmf/p_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_p.npy rename to tests/test_data/nmf_framework/corrnmf/p_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_p.npy b/tests/test_data/nmf_framework/corrnmf/p_nsigs2_dim2.npy similarity index 100% rename from 
tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_p.npy rename to tests/test_data/nmf_framework/corrnmf/p_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_sigma_sq_init.npy b/tests/test_data/nmf_framework/corrnmf/sigma_sq_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_sigma_sq_init.npy rename to tests/test_data/nmf_framework/corrnmf/sigma_sq_init_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_sigma_sq_init.npy b/tests/test_data/nmf_framework/corrnmf/sigma_sq_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_sigma_sq_init.npy rename to tests/test_data/nmf_framework/corrnmf/sigma_sq_init_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_sigma_sq_updated.npy b/tests/test_data/nmf_framework/corrnmf/sigma_sq_updated_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_sigma_sq_updated.npy rename to tests/test_data/nmf_framework/corrnmf/sigma_sq_updated_nsigs1_dim1.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_sigma_sq_updated.npy b/tests/test_data/nmf_framework/corrnmf/sigma_sq_updated_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_sigma_sq_updated.npy rename to tests/test_data/nmf_framework/corrnmf/sigma_sq_updated_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_surrogate_objective_init.npy b/tests/test_data/nmf_framework/corrnmf/surrogate_objective_init_nsigs1_dim1.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs1_dim1_surrogate_objective_init.npy rename to tests/test_data/nmf_framework/corrnmf/surrogate_objective_init_nsigs1_dim1.npy diff --git 
a/tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_surrogate_objective_init.npy b/tests/test_data/nmf_framework/corrnmf/surrogate_objective_init_nsigs2_dim2.npy similarity index 100% rename from tests/test_data/nmf_framework/corrnmf/corrnmf_nsigs2_dim2_surrogate_objective_init.npy rename to tests/test_data/nmf_framework/corrnmf/surrogate_objective_init_nsigs2_dim2.npy diff --git a/tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_H_init.npy b/tests/test_data/nmf_framework/klnmf/H_init_nsigs1.npy similarity index 100% rename from tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_H_init.npy rename to tests/test_data/nmf_framework/klnmf/H_init_nsigs1.npy diff --git a/tests/test_data/nmf_framework/klnmf/klnmf_nsigs2_H_init.npy b/tests/test_data/nmf_framework/klnmf/H_init_nsigs2.npy similarity index 100% rename from tests/test_data/nmf_framework/klnmf/klnmf_nsigs2_H_init.npy rename to tests/test_data/nmf_framework/klnmf/H_init_nsigs2.npy diff --git a/tests/test_data/nmf_framework/klnmf/WH_updated_mu-joint_nsigs1.pkl b/tests/test_data/nmf_framework/klnmf/WH_updated_mu-joint_nsigs1.pkl new file mode 100644 index 0000000000000000000000000000000000000000..79869f4588d83fbaf3acb3a914d8e83cd9efeafd GIT binary patch literal 1007 zcmX|=drZ?;6vugWAPAd)7zk#D4MG)|5i^>Be4F8hqfoTYAsSt2X(2$0^5kZ}6oVYPEup5s!kv&8p3=!5>Jb&EJy}9SybI(2ZDaKZ36F zfV1z(xO0aN&Yce%$tenmb|&{nElFRTpA@THyO&&!t#=H<$I3!PUbqtSN4^x714kmF zM$=HAxRQS_M2=x$ZdlLVk6+N?vR$8XnE~$)8u(Rnd3cq+la)Jq4v+NO1xH>rBz?G1 zH_Jxc?fX7+ns)g2q~{Hk5}4x8EUQn;k2w11*=m7&6P%5CHL ztqNp)9{t-ttu@%z%B%K{A>i;_ruFiy#Kkg6?Bv!8j5BU6?^WvXF6nPyueowuB4>`z zrZwT|n(lf4w+@lrnemj0Pr+e687@j~#GuYA^7Y{oJZPAiCi&-LT^Ua7*wJ#GW z-Ie&=l2?L2T|{Eyi9A~!#h8DroyZGRfZ{UdU#=1$+4AmuymK|{P2> z*Fd3ODiH=1V*dV)nZc1#6#X33aNIo~alP;FBJVFlwKll!l2<7vzF1!!=n`YGm2*`a zErINqx_{-fQY>^8{*o|OhU=%+B33>sw%tbz=bOOcd9oM5gI^x?`glp*7q_HDw0JnCZ-Cd*&LHN zR1RDUp{yt;7(_75QDq-R(CPHJO_Ny$nw^kf(U>%Q1);K!A%rG3o?~#`VXFe|?9H6& zqgs~wZmKzMQ-96Mqcv8JU9j%|H)>^LmzB&e>W10htC_lPrdm%Wc&;>inyZ4~d6-ji 
G!u|*4kj*v# literal 0 HcmV?d00001 diff --git a/tests/test_data/nmf_framework/klnmf/WH_updated_mu-joint_nsigs2.pkl b/tests/test_data/nmf_framework/klnmf/WH_updated_mu-joint_nsigs2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6943d2022e10c22a284db84daa5deaf66aac08da GIT binary patch literal 1857 zcmYL~dpOit7{}cbwaq+_q9_$BX}8ryq;OtEb~36>Tf1~2%$Vt-F=odIWyTmY41@WN zMPhPk$5SndP`ZdrOKh#$WJ|S*9;7J^OS1IOlwx@B2OHea_{1<7OzFX6k%u z`!PGIN!U+ESyA@E$JjJ`?Ucr524m9M0d)4Uqd`B>=`=Pb$Lun73_m+eTFdr##>wF` z#i6k)Bm3OMGJ;X7@=RMpFu}Rc4^|o8O@%fR{~PUvPg#gZ_`t%Vg(#*25 zQAH*JcZb7?zq90Mv_3voZXHDU$p(1?jj=K)_2;K|)n_6WW)tMmm`v!kkq7Q{yT%jT za{RF6W4%CvD?OH<-EEvmaHL*E-0{H-aEZ0W^B2jH4u$pm?I9W5&Ncpp_H6~a$hN83 z$erH5Z+E*g?oBpyMopscEG5WrW7l%EC=t4C{q_4^Rw7yYucy_NXm~MiU<>cD1Z@<$ z*tGfwK-Q>bI2sEZ_kZ`Cz1k04T9*du zHOW|iMiZo+*O*0?ir}u*JGJWlO%PGxrr-TQupT_+GB_3w+x;C#-|o$WA!hMp#$OMB zsUey6rYVT?oNYFkJi-+b?9;+|--UKm9}#Je)YS=P#y5+kcYL=XbKB&pHIhaoPN7KF6 z*>&du%jlk;j8leOE;ob+UYUMxj{|ad=dz=Tcs*x9JbzCd!S^?}|H&as3D)dwJidT@ zoM6YhMHI)n7_eBZo|Aki9vyLOObYQ4fT`c=%&#ho(ZVHyl_r!#V23j5dxsN{!qn}( zvw{UiGPm=d_Ib!Ru-T3sO9!V2;Y7Wg>L9jl&^gebpd*MtwHz?5pDqpgdN z$kavwlPd?wVCvSxVSFUirPOXV;=}MZl>KZ#j663imiWg^|F`@g#r0d(a(KS$I;o;B z1;tv3n#koRiF|z-R)wpYgrMn&nH(>xLpEPC=|^H?3H@R8dEL!!2~fxt@Iuq#Q7dKO zq#;QN18yf&gvs7BY<9dO8vxKA=-3DW31E9X2bl8OU zE<8<{QGIUFujr{?|4jSngJ51FCM`QqhSuyaU)Ev15Bf)|Xjhx%=;jp3dy82FSik#d z^u8$pX71>yQ}B+$7S|WEt93S!mH_{-=!g4@#0Znl_*5^nV+bdTMbj1ssUF=KlZX>B zU(AlGFV%I$wr}6=jcePdj;h1Kf~bL1JrO3=^}smVFowQert)(I(Rp5S4#%HVxZ7Uf z=8H`qBKMe&I9_|v{r>_JW2ar{S{=c>nh6}=B`0{oUdwIy1cx~ehv{f;x$peS#q;Cb zBKLyHNgU5DbJrd0rS({d(QBRv#{T-bI%Zd3v3FLKrFxTH^WRe(udZ=bn4+xnBk4 z?2N0E-Agb!bonk*YyQG-CI zN3i$l)Xrrc>S$uM{}>C`Zp%Dc7&>%p*?G|RvJP2IZ9gUN(n1mIHMozSfjGXq=OvpC zU!Pk%U_4U+WuUOoA`WjId!Rzonazc-yXeu9crZSMRf!`9LIQuR z$^duI;|cc;BiuXhH&Qdy5Cy08$JnGV&OaKjUAmK|!KPbd!iU-->^pxc;t#nL7ehy) zV@9)3pRypo6QRPeP!`#9`@x#-kqo@c`Ey#{{OYyntHq2&bjicA?d?4 
z!weee~_2lU*=SMq-+8nb)0Cp@7x?!xvFakk7U2{ctm+zcyTosRqkEN{|f8+9-18)?#!leV#Z6D}wdzHp}2h8H#@nYdGd5M?&wr+o^k%sICn+T<|T!;{Q zw+L@iHDRh>#c1!}aoX)H8z)}<=dW}w4=dXP+f)=O*k79NyA6pjMYnkPH>abd-m7lM zz;P@;3HZkAun;Xnaf)xLc}U&8!EM`9A$a!xn~jJN%3J^TOV<0Y1W8crdRb+rP@2bC z6$jzMrIFfl4ap?K8LotL44IsqoUpQL_JQFfBw37fhO?R^oa0F0svFNW-0-$TflkhA zobIFBm!@v|THK21ii1aL92`II*#B?T!Nx8J)4S-)YyQ5?^k-|-ndl_Xli|$pRFgdK JwN#wQ{{iQo%{c%7 literal 0 HcmV?d00001 diff --git a/tests/test_data/nmf_framework/klnmf/WH_updated_mu-standard_nsigs2.pkl b/tests/test_data/nmf_framework/klnmf/WH_updated_mu-standard_nsigs2.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b2062e8530a348541f7d01b16d974bff81e56609 GIT binary patch literal 1857 zcmYL~dt8iJ7{{l^h!K)Y#AZVpZ6cIhI?qJg#UQ$fbf0DvqRTYRsF|8BbIHt<=510! zp%fxEs}O~#L@^>N)VgLBE4QYao#nG{=a2KA@8|dY&Uwx`@A*85q%2!K!jI;MQKW{u zV!|W1R`lIb3@h!B5fw(qq+m&w5=^6Jz(V`QjPW18o{NrMJ| zGy2yO1&ZW%937YyjQjo9&k^zX>0oT&?sw^t5_$0hBVOFxi|e{WU0$2)lX0HVH2)tP zVj#{VT~d~}jTYfNzhtiSo1RQ4s(zjAGGBps^*w6-$8_lEzNR$XQjWUQ57rj4hu1%o zYVu~vyBw%lX+c}ESb{iLdc|%c5d?@jJPF^FaRdZ8YEHnp6y-L^Icf4~6Kq^xZ=D1P)M2ySf$&k3gngMt0j`4tT$qe#2pv z1oGD$6_;s3ft9gfYQo@I)O?|%rPNITv&WiF)Sq4q@5!>2&4ZEf?eEL8GEeM<&fMg` zOj=k_P#bgEbnO{5sdLLcQm+RjyD#vJRv$pDrI$ji4{^aI)O=6DgiPpLri|KYodIqR zTSC6b8(>~b(e%8NLYUe%NzCu7h7HxRc0NskP_bW0E>|amlg=)h{iZ{ZedF9Q@5c|| z5p~gQ=DsZaJel>yO{`cU&K;YTB}JnOaAs2ncdFIt(AVEoD=$kz4$O8(+ASHbpKa}z zmyJ)uc|u5qYDN);^DFoEQQ_;dU?^>C-urcXQ85?{+zcJQe<4}7jNmH4<0DV(u&B75 zgL7@feYWMyVV&3Kn{IpjAdHb(<{!r{z>E#+Oi?igbW+gfHL@cZkmB5*@eMmO zeh<#8Nfsm8YvQ4^)Jh&3myGf~Yx3`|O@i&BI1t z(Dg97h%8LoK37c>gJj`Q5^IeJ6>l%8%$YNsPmHRuG2&_j^c*`N*ft{#bwB9!ocU1( z0_$nbJ3a}J@PwuR6Se~MbqwC-@dU_wu#I0hN&r*eF<0jjC8*_oTt!#>@Ov9=9$#)f zTLF5_-O@Ue9Bp>+!fu^Q!1oVJHkuSSq=2l1xs6#}gNO+frV*cq>y#C*=2h9J!t)X1 z2UrEkC}o+le$TiRc&a(?I?hLk!UPvT4nB#64a9l@)JqZX;GcS~JTC5^<`Ep-nV$l* zDXJj1@u`SS_AnXuG8NYkb7#$P_D+JaT|r;s%u>*qsPgNVr*Ys`9`T*G_I&F6?Vc_t zrGeb7XjQg_1o6Al4BO9B^<62dvx<6W)e*id<;r!(htF}g7Q#bGA]A,94,54,239,28,35,180,78,103,32,112 
+A[C>A]C,60,63,199,25,23,129,60,58,24,131 +A[C>A]G,10,10,28,4,8,17,10,7,5,22 +A[C>A]T,72,42,222,17,24,148,43,60,20,106 +C[C>A]A,57,47,163,19,24,159,78,65,17,119 +C[C>A]C,73,35,161,18,13,143,41,33,10,121 +C[C>A]G,13,9,16,6,9,15,10,10,3,20 +C[C>A]T,68,57,189,7,20,173,62,65,14,116 +G[C>A]A,75,69,122,23,22,89,100,95,20,65 +G[C>A]C,45,45,84,14,16,83,27,50,15,92 +G[C>A]G,8,6,13,5,8,16,10,11,5,16 +G[C>A]T,49,66,101,10,13,109,72,63,13,68 +T[C>A]A,65,87,171,19,33,204,99,135,30,84 +T[C>A]C,86,61,132,18,24,202,60,82,33,84 +T[C>A]G,6,12,24,6,11,10,11,16,4,13 +T[C>A]T,91,134,250,24,28,265,151,182,44,100 +A[C>G]A,95,24,95,8,17,150,46,26,10,105 +A[C>G]C,40,19,61,8,7,70,36,19,14,60 +A[C>G]G,11,4,25,0,0,35,3,5,5,23 +A[C>G]T,96,18,118,7,9,145,45,35,16,112 +C[C>G]A,66,15,47,2,10,151,12,17,13,60 +C[C>G]C,52,10,36,3,9,92,23,15,11,77 +C[C>G]G,22,11,14,3,3,39,6,11,2,42 +C[C>G]T,79,19,50,7,7,203,28,31,15,106 +G[C>G]A,36,6,28,3,4,78,18,6,8,68 +G[C>G]C,30,9,43,3,8,64,22,8,7,61 +G[C>G]G,8,1,6,0,1,14,8,3,1,19 +G[C>G]T,51,8,34,6,10,107,23,18,9,93 +T[C>G]A,119,61,131,8,20,687,68,116,24,93 +T[C>G]C,85,32,80,9,11,448,52,56,13,110 +T[C>G]G,15,11,13,1,1,45,11,14,1,16 +T[C>G]T,239,90,236,14,28,1125,136,214,35,207 +A[C>T]A,126,91,238,24,46,187,80,70,54,140 +A[C>T]C,61,56,103,17,26,83,43,47,35,79 +A[C>T]G,149,272,257,92,119,185,147,168,134,145 +A[C>T]T,92,51,181,18,32,157,33,57,43,119 +C[C>T]A,75,76,112,21,46,140,53,66,40,100 +C[C>T]C,69,67,89,28,39,97,59,59,54,77 +C[C>T]G,93,163,139,45,72,108,110,108,85,88 +C[C>T]T,107,94,185,27,49,220,68,75,56,162 +G[C>T]A,68,75,86,13,37,99,74,58,38,123 +G[C>T]C,46,61,95,22,44,103,80,45,44,79 +G[C>T]G,90,230,176,71,81,155,124,118,102,116 +G[C>T]T,74,55,129,10,22,110,35,49,38,93 +T[C>T]A,139,178,198,28,63,520,136,224,97,106 +T[C>T]C,126,97,155,24,40,341,98,95,79,98 +T[C>T]G,80,128,117,35,68,101,79,103,53,87 +T[C>T]T,152,128,244,26,72,382,109,147,116,137 +A[T>A]A,43,66,115,17,16,86,44,26,13,57 +A[T>A]C,25,19,75,20,24,46,31,21,27,48 
+A[T>A]G,37,30,99,10,17,76,22,29,16,64 +A[T>A]T,63,61,168,21,32,120,49,38,18,117 +C[T>A]A,31,32,85,4,9,61,15,16,3,63 +C[T>A]C,32,17,65,2,16,71,19,22,7,101 +C[T>A]G,55,24,105,9,14,108,34,22,9,68 +C[T>A]T,63,39,182,6,9,117,24,23,17,90 +G[T>A]A,22,17,42,5,4,47,10,9,7,40 +G[T>A]C,20,14,39,3,8,38,14,12,10,38 +G[T>A]G,23,16,33,5,14,48,18,15,5,39 +G[T>A]T,41,16,99,2,9,102,18,17,10,71 +T[T>A]A,31,63,124,16,29,122,51,60,18,76 +T[T>A]C,30,21,95,3,10,59,25,21,12,44 +T[T>A]G,19,15,57,2,5,43,9,9,7,39 +T[T>A]T,76,38,240,9,13,146,41,41,13,116 +A[T>C]A,90,101,150,29,49,189,79,57,52,157 +A[T>C]C,42,29,68,10,23,87,36,19,18,85 +A[T>C]G,58,47,85,10,24,108,55,31,26,131 +A[T>C]T,99,95,118,19,46,200,95,58,40,167 +C[T>C]A,39,50,57,7,15,69,44,14,20,73 +C[T>C]C,55,27,92,11,15,139,28,20,13,109 +C[T>C]G,43,37,42,10,14,59,28,25,15,88 +C[T>C]T,59,42,68,5,13,105,56,37,22,113 +G[T>C]A,40,63,79,11,32,101,50,22,12,81 +G[T>C]C,29,27,35,12,11,62,37,18,14,47 +G[T>C]G,26,41,32,11,9,58,39,17,15,64 +G[T>C]T,57,49,70,14,36,103,55,28,20,73 +T[T>C]A,56,83,106,7,23,93,41,28,23,76 +T[T>C]C,47,52,73,14,16,91,50,24,9,73 +T[T>C]G,25,29,57,4,13,61,18,21,13,54 +T[T>C]T,54,75,92,17,39,133,55,46,27,99 +A[T>G]A,29,25,49,2,13,61,29,18,13,49 +A[T>G]C,12,12,15,4,5,26,11,4,7,35 +A[T>G]G,43,16,40,6,6,74,17,19,6,44 +A[T>G]T,37,24,57,13,13,56,24,15,9,46 +C[T>G]A,18,7,12,5,3,30,10,9,3,29 +C[T>G]C,23,13,21,8,4,42,6,10,3,27 +C[T>G]G,43,11,46,1,6,100,21,25,7,62 +C[T>G]T,20,14,49,8,9,70,14,25,7,64 +G[T>G]A,7,10,17,0,1,27,7,7,5,26 +G[T>G]C,9,6,12,3,6,19,6,3,4,18 +G[T>G]G,25,15,37,11,4,76,16,12,8,57 +G[T>G]T,27,6,24,6,5,63,20,11,5,41 +T[T>G]A,39,25,39,2,9,66,21,8,9,52 +T[T>G]C,19,9,30,3,4,37,15,6,8,31 +T[T>G]G,39,18,73,4,10,86,21,18,4,45 +T[T>G]T,58,38,81,10,20,110,48,38,16,109 diff --git a/tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_W_updated.npy b/tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_W_updated.npy deleted file mode 100644 index d2c5ea8d694b3d98b7fdc99cf98861f4a3d504bf..0000000000000000000000000000000000000000 
GIT binary patch literal 0 HcmV?d00001 literal 896 zcmbV~>rava6vk1@EZe$ijaJ*cT9Az_q%6~Xv~^*XU8cFYbSN*pp-B*m2I?Zw7cT)S zQL!#Fo$Er`g)Yl#W7SFoiBTWzn^vXNIluE9)3#kHJ2HY8 zMT`=*NL)ZF*lP%O0=1gWA=s2mQK-n-Cy*&b;=lUFocsdum|h^sk&DOb#c{D5f*Zvl z%837sDP49C$Ftvr!r|n8rrZdh(VlI)K_dEXl)B)$0ecTQF1nQlL@pqkDY6EM&;6OU z2qCP~vLY{ujZobg1>-xg_!8Ss8n@Qw|MI_U067rzKEq|i z@gEBxG9^L?X6)=f-($hg_TZ4tLLG*KHng~TCI9+6h^&cHRTbLf$=c8$4dncqDV;af zpufjBnV%&PL~P+Ta%r4!j4rc3Ec}<*3)7BJ?^1F4s`or=Z5f3A#PE>`C$N{StD_Pd zz&+a0Vd)j&OHxV5;JLHt@1Jv=(wLEc_yNx@rr~UL+2eeT;C9Tp@~Gd44EkBI^Rgc7 z7)2AySBCjq+J;bUz%Erh*e7Itd-Rm6*lc7Z&~oMxi2m|yc#2d zn&DyY8SIW%=7y+`p?;IX(34t^#d9gO_UUOvTVI(rTD2g|E|(>%1nFfhix`_Vm^707 z((A1SYkaN>+GwScXDtnIlw(DH%gUr!9hfhJhYnZC;UB-d47X>myKgzRAAL&>@RN2Q z#LTU_Rrb&--1WsLFs7@qL3)&xbWIMLNm}V%JwjiHXKVME0^AyK2Ij|R=jqUTW8l}k z2osug-I1M^gLwX7E}7a^0mpj%n_j96HFrJB8A~fs^&;4J@wXN`zDDKo+KVyu@}QKL zt3q>{Zdqep1=dW)sKZ%_SiUmL$uAG=a{!qs`tIAF>VW^nxVgPTWJ?$0xA>$ZV`dRb z%kQ0Xe=CMd)+FT7BCK9_=#0A|7oQBzWVX05j?(KZ7*C6!*R~qm(j0X4vCdKxGC|O7 mPkNJ;2zxR#P`2XgiUUDuh#jOhXUWq@x6^JVU diff --git a/tests/test_data/nmf_framework/klnmf/klnmf_nsigs2_H_updated.npy b/tests/test_data/nmf_framework/klnmf/klnmf_nsigs2_H_updated.npy deleted file mode 100644 index 7890ef44306fdd688403adba311bf5b69dec1d0f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 288 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(4=E~51qv5u zBo?Fsxf(_~3Wf%nItsN4WCN~;D_8&je`kxsV^8fl9imt59Bcml|G#IW0|OW|*wt$M z|NnpQK{&s1=kNdjgD1oJfAxU;%k$yUM1d?fz= z#+F0N5xH~GJdpS=Vo;pq&ll<4g0xi zkW`|odS4+GC54Ra^PnYzbC!$~QN%+TeRsmEG~6)1&2oabSwuRNcB_u*i?h+}B>(bU zD;Lw%b{Uzi8pmTiuQ|PYmy|2UIs4rUjpY(CHenPrWX@#5W;Q8?eVB(zsOep3E**H8 z2WH(aK13r5Uq2#;ZuA$Cb1s>_zX|796^{q>IjEiFf1re&2)Z0R%jZ)S8n&>$$*^ET zyHGS=y1+qG6f-LuM;GW-&l74X(&43flZ7ld32IkgzI(q=jH23e1wOHKSR;55ys`_S z`s&ALuKILHO)cv@W6J@vA=jFj&)zWVtJx!5uR!(VX4iMx#e$@dKp|RJ2;{nz?YGy1 zfFm$*;M8-Wr7A(qRGL)P@a}I`8A_Gi5-$HVuh10qP3}{;0ls&4R 
z2~&geX7r;Ophdqi{an-q9awB*F^m`u;Y_G!BaNfV(77+>6U*DXX=fN}5Ec53aXbPyTH6+EqoL*`+d zm5;QgV!BX#EMP<-8RMfx$U#`d$5=O{tN9PBEWk0I_m$gn5yM`CJM6V_|F5@2Wt_0! zY}CI!H89nA6XPEx7EWK;v!RR3e8zdgJQ&k+IDM+98d6uC79{TRA*QsDxZS7}BCm~! zZC$Tp{(4Ktx%_|aOtp0n_J?bcQmVCx$iHz+UL^@a14gbaStTM;?rV*f{TuU~R+_9ZX$%D26#P!*y~*e+ z>7Vp*bE&W-mZb9kj72N$d|-z1pw`uTp&=?3<@qNX2Fu2R+7R!qfiwqIzA%?IW^Vkq zzFVQpVGSNEns%2xmQ6!v^)Iz?%cHULGpF70dYvq2ET*h*o12lt)>{-@L?os!6(3*! zmdJ)4S!a)$!g$n26s;*Lut29w3C^lVwrD3Zy*{K{tD z_I?E{Ae&Tsixk-?`yTFK{987rYsV~qzh)l?&JDG{b!f6s_>_|L78N?k2t5afzwBpq zah2(bONNYINlgQ74r+A=;(c})828;dE^nNI*e`bqwt9KOvO$BX4lfp24+wu*{EG_n zgh*|R8#HKZ^*OvBN5}dZ-VXcb1K$JOl&kK(3I74%S)-zT#aH3N`n!{B*$pU>s^)ee zjRq4XxTc=-nXtI!`lRzX4Spx)92q}RhC(~-WXRzZ5YM+7$wq?a0Q?6Krs6w^-2*^8@z45bK;$0_34a87l=EMDwYi^vZ>(U*~UK uC)za_BW>$?(VdHAd+N-GR~Rs>Bo(eI`vi5)Jk@!f&IWbC?n4swN$4M0QA]A,94,54,239,28,35,180,78,103,32,112 +A[C>A]C,60,63,199,25,23,129,60,58,24,131 +A[C>A]G,10,10,28,4,8,17,10,7,5,22 +A[C>A]T,72,42,222,17,24,148,43,60,20,106 +C[C>A]A,57,47,163,19,24,159,78,65,17,119 +C[C>A]C,73,35,161,18,13,143,41,33,10,121 +C[C>A]G,13,9,16,6,9,15,10,10,3,20 +C[C>A]T,68,57,189,7,20,173,62,65,14,116 +G[C>A]A,75,69,122,23,22,89,100,95,20,65 +G[C>A]C,45,45,84,14,16,83,27,50,15,92 +G[C>A]G,8,6,13,5,8,16,10,11,5,16 +G[C>A]T,49,66,101,10,13,109,72,63,13,68 +T[C>A]A,65,87,171,19,33,204,99,135,30,84 +T[C>A]C,86,61,132,18,24,202,60,82,33,84 +T[C>A]G,6,12,24,6,11,10,11,16,4,13 +T[C>A]T,91,134,250,24,28,265,151,182,44,100 +A[C>G]A,95,24,95,8,17,150,46,26,10,105 +A[C>G]C,40,19,61,8,7,70,36,19,14,60 +A[C>G]G,11,4,25,0,0,35,3,5,5,23 +A[C>G]T,96,18,118,7,9,145,45,35,16,112 +C[C>G]A,66,15,47,2,10,151,12,17,13,60 +C[C>G]C,52,10,36,3,9,92,23,15,11,77 +C[C>G]G,22,11,14,3,3,39,6,11,2,42 +C[C>G]T,79,19,50,7,7,203,28,31,15,106 +G[C>G]A,36,6,28,3,4,78,18,6,8,68 +G[C>G]C,30,9,43,3,8,64,22,8,7,61 +G[C>G]G,8,1,6,0,1,14,8,3,1,19 +G[C>G]T,51,8,34,6,10,107,23,18,9,93 +T[C>G]A,119,61,131,8,20,687,68,116,24,93 +T[C>G]C,85,32,80,9,11,448,52,56,13,110 
+T[C>G]G,15,11,13,1,1,45,11,14,1,16 +T[C>G]T,239,90,236,14,28,1125,136,214,35,207 +A[C>T]A,126,91,238,24,46,187,80,70,54,140 +A[C>T]C,61,56,103,17,26,83,43,47,35,79 +A[C>T]G,149,272,257,92,119,185,147,168,134,145 +A[C>T]T,92,51,181,18,32,157,33,57,43,119 +C[C>T]A,75,76,112,21,46,140,53,66,40,100 +C[C>T]C,69,67,89,28,39,97,59,59,54,77 +C[C>T]G,93,163,139,45,72,108,110,108,85,88 +C[C>T]T,107,94,185,27,49,220,68,75,56,162 +G[C>T]A,68,75,86,13,37,99,74,58,38,123 +G[C>T]C,46,61,95,22,44,103,80,45,44,79 +G[C>T]G,90,230,176,71,81,155,124,118,102,116 +G[C>T]T,74,55,129,10,22,110,35,49,38,93 +T[C>T]A,139,178,198,28,63,520,136,224,97,106 +T[C>T]C,126,97,155,24,40,341,98,95,79,98 +T[C>T]G,80,128,117,35,68,101,79,103,53,87 +T[C>T]T,152,128,244,26,72,382,109,147,116,137 +A[T>A]A,43,66,115,17,16,86,44,26,13,57 +A[T>A]C,25,19,75,20,24,46,31,21,27,48 +A[T>A]G,37,30,99,10,17,76,22,29,16,64 +A[T>A]T,63,61,168,21,32,120,49,38,18,117 +C[T>A]A,31,32,85,4,9,61,15,16,3,63 +C[T>A]C,32,17,65,2,16,71,19,22,7,101 +C[T>A]G,55,24,105,9,14,108,34,22,9,68 +C[T>A]T,63,39,182,6,9,117,24,23,17,90 +G[T>A]A,22,17,42,5,4,47,10,9,7,40 +G[T>A]C,20,14,39,3,8,38,14,12,10,38 +G[T>A]G,23,16,33,5,14,48,18,15,5,39 +G[T>A]T,41,16,99,2,9,102,18,17,10,71 +T[T>A]A,31,63,124,16,29,122,51,60,18,76 +T[T>A]C,30,21,95,3,10,59,25,21,12,44 +T[T>A]G,19,15,57,2,5,43,9,9,7,39 +T[T>A]T,76,38,240,9,13,146,41,41,13,116 +A[T>C]A,90,101,150,29,49,189,79,57,52,157 +A[T>C]C,42,29,68,10,23,87,36,19,18,85 +A[T>C]G,58,47,85,10,24,108,55,31,26,131 +A[T>C]T,99,95,118,19,46,200,95,58,40,167 +C[T>C]A,39,50,57,7,15,69,44,14,20,73 +C[T>C]C,55,27,92,11,15,139,28,20,13,109 +C[T>C]G,43,37,42,10,14,59,28,25,15,88 +C[T>C]T,59,42,68,5,13,105,56,37,22,113 +G[T>C]A,40,63,79,11,32,101,50,22,12,81 +G[T>C]C,29,27,35,12,11,62,37,18,14,47 +G[T>C]G,26,41,32,11,9,58,39,17,15,64 +G[T>C]T,57,49,70,14,36,103,55,28,20,73 +T[T>C]A,56,83,106,7,23,93,41,28,23,76 +T[T>C]C,47,52,73,14,16,91,50,24,9,73 +T[T>C]G,25,29,57,4,13,61,18,21,13,54 
+T[T>C]T,54,75,92,17,39,133,55,46,27,99 +A[T>G]A,29,25,49,2,13,61,29,18,13,49 +A[T>G]C,12,12,15,4,5,26,11,4,7,35 +A[T>G]G,43,16,40,6,6,74,17,19,6,44 +A[T>G]T,37,24,57,13,13,56,24,15,9,46 +C[T>G]A,18,7,12,5,3,30,10,9,3,29 +C[T>G]C,23,13,21,8,4,42,6,10,3,27 +C[T>G]G,43,11,46,1,6,100,21,25,7,62 +C[T>G]T,20,14,49,8,9,70,14,25,7,64 +G[T>G]A,7,10,17,0,1,27,7,7,5,26 +G[T>G]C,9,6,12,3,6,19,6,3,4,18 +G[T>G]G,25,15,37,11,4,76,16,12,8,57 +G[T>G]T,27,6,24,6,5,63,20,11,5,41 +T[T>G]A,39,25,39,2,9,66,21,8,9,52 +T[T>G]C,19,9,30,3,4,37,15,6,8,31 +T[T>G]G,39,18,73,4,10,86,21,18,4,45 +T[T>G]T,58,38,81,10,20,110,48,38,16,109 diff --git a/tests/test_data/nmf_framework/mvnmf/mvnmf_nsigs1_objective_init.npy b/tests/test_data/nmf_framework/mvnmf/objective_init_nsigs1.npy similarity index 100% rename from tests/test_data/nmf_framework/mvnmf/mvnmf_nsigs1_objective_init.npy rename to tests/test_data/nmf_framework/mvnmf/objective_init_nsigs1.npy diff --git a/tests/test_data/nmf_framework/mvnmf/mvnmf_nsigs2_objective_init.npy b/tests/test_data/nmf_framework/mvnmf/objective_init_nsigs2.npy similarity index 100% rename from tests/test_data/nmf_framework/mvnmf/mvnmf_nsigs2_objective_init.npy rename to tests/test_data/nmf_framework/mvnmf/objective_init_nsigs2.npy diff --git a/tests/test_data/utils/objective_input_nsigs1_H.npy b/tests/test_data/nmf_framework/utils_klnmf/H_nsigs1.npy similarity index 100% rename from tests/test_data/utils/objective_input_nsigs1_H.npy rename to tests/test_data/nmf_framework/utils_klnmf/H_nsigs1.npy diff --git a/tests/test_data/utils/objective_input_nsigs2_H.npy b/tests/test_data/nmf_framework/utils_klnmf/H_nsigs2.npy similarity index 100% rename from tests/test_data/utils/objective_input_nsigs2_H.npy rename to tests/test_data/nmf_framework/utils_klnmf/H_nsigs2.npy diff --git a/tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_H_updated.npy b/tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-joint_nsigs1.npy similarity index 61% rename from 
tests/test_data/nmf_framework/klnmf/klnmf_nsigs1_H_updated.npy rename to tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-joint_nsigs1.npy index 74270423221a937d8fa4df97d7d79a1a172405f1..8a278d3619246b10ab08ab3c62a455d1ee920db6 100644 GIT binary patch delta 87 zcmcb>c!6<3LjV&45d2%`zzCx24>*9Rh7<_RAiCY*KL}JDfXKJ(g76=0hw%9pIQ#&~ NU!D)v$Dnh>0RS1_AnX7D delta 87 zcmcb>c!6<3LjWTK5ZvA30HPT7Zglt$0(%ca=-|l^`tp2;{LkYMdhP}Yov^{-A4t9M MLWe&fx_7Sw0Qbl%R{#J2 diff --git a/tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-joint_nsigs2.npy b/tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-joint_nsigs2.npy new file mode 100644 index 0000000000000000000000000000000000000000..de9d783cce836cd46badad156943c09478af3501 GIT binary patch literal 288 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(4=E~51qv5u zBo?Fsxf(_~3Wf%nItsN4WCJdzrmj>om(>oR-mbpvbFS7Q(n6jgJ$9YLwd4==KIb<& zux;`7keRi@;lJvw+aIg9I28FkP+o7b(ZS+aU4e7qbO*iPVJy|#*Eys&U8!}Nw$b6h zQGeTa?NoWi4Oku%f=Bq+^q^!QtJ^zn2ex-tGVZ2XkZ) literal 0 HcmV?d00001 diff --git a/tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-standard_nsigs1.npy b/tests/test_data/nmf_framework/utils_klnmf/H_updated_mu-standard_nsigs1.npy new file mode 100644 index 0000000000000000000000000000000000000000..8a278d3619246b10ab08ab3c62a455d1ee920db6 GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-ItqpcnmP)#3giMVCI%q*x6XkPMAsj108tGo5Sl@ByTgAFs5k(TZ`lRm XKiUrA^DS`r0g}HwAFPi-=ZFIU4t6om(>oR-mbpvbFS7Q(n6jgJ$9YLwd4==KIb<& zux;`7keRi@;lJvw+aIg9I28FkP+o7b(ZS+aU4e7qbO*iPVJy|#*Eys&U8!}Nw$b6h zQGeTa?NoWi4Oku%f=Bq+^q^!QtJ^zn2ex-tGVZ2XkZ) literal 0 HcmV?d00001 diff --git a/tests/test_data/utils/objective_input_nsigs1_W.npy b/tests/test_data/nmf_framework/utils_klnmf/W_nsigs1.npy similarity index 100% rename from tests/test_data/utils/objective_input_nsigs1_W.npy rename to tests/test_data/nmf_framework/utils_klnmf/W_nsigs1.npy diff --git a/tests/test_data/utils/objective_input_nsigs2_W.npy 
b/tests/test_data/nmf_framework/utils_klnmf/W_nsigs2.npy similarity index 100% rename from tests/test_data/utils/objective_input_nsigs2_W.npy rename to tests/test_data/nmf_framework/utils_klnmf/W_nsigs2.npy diff --git a/tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-joint_nsigs1.npy b/tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-joint_nsigs1.npy new file mode 100644 index 0000000000000000000000000000000000000000..331809a04cceecddaba82abb515f3fc7f1b4a822 GIT binary patch literal 896 zcmbVK`%?=59G^>;O3dMu>7y`rsr90cWcEt8+u7`qWv!aESxYsGrf9Bp zth$|7ZaP%!JTAfw%_DRXU5<3(@~D%!8-IZBZ{N@7htKEx`FyMB+mljuOd;eEW3V8DCZyG$*K9G$+Dvq`%;D`)6oZ;of7|I*4-)ms zk88Dh*t?_aHl=g0{=smw+0=wcTldv=ngJC`eOF+z3THBw%_k^MVnLE4iXT*t!DI19 zDFGGW-60IUxOW&Eob#G`NA=*bXbU|SA9edKgNk*db%L#dAR^uP{dEzEG^oRVWX)KVoQ8!R|{e52s@Kz5xQ|Pnj8(| zQ*-5Dw~dUvCzZeytMdvRlw$5SIsNQ;5sYqsuF{A$dL`KIe7ibAS@iG{jIN?IxH7@R z&^Z|mvw-l!Zf3DrX`>lZ~Hz+6xmiY>!zzD!E# z5!xtqR>_r4qGnmT6cW91uj;Ama_Dk}o}Pc<`TX*Jz2Cool)D0V?q(Pf(g|6V=$OtqX@!@z;JjoPJGB1r6k;G2sMaTTF`$r_E#(e6jaS_~@PyY3fYdp#Sa`z+)$^Y-n zW}GzJDr(A+xKI9;`;`JZqOM@fvJ^EwzeTK6AAy6DjI)LufFwiHjLIPaVHJ*BN!bPH z8N~4(vK*DWMYn8{Ww_nBnqp#E4C-(@hY%};yVEFD;w6C(i6^f-8-r#$i^yBAS#Tc; zpVb-^a5mE&^Ysux<7wd?@G}n?T`eaJb0rAgdo=lbi)hTh(UBM3D25-wA>`hVad7yf ztzv7{PV}8wtz}* zl(6;Ypq>h(p7yOl-g@=Rnsl^pb?{CI4#>@qk@oDTJ zt5f|$o6(}7uB4t7!rJ4*1G+;s#)CRDudNlp#QAjll@JNyLo-|?2W#N`AHVP7msTOf zbk1#bx)Ob@^I6s>vQQ#&W!m!hqr=~3`GxH&?9d%Lr+i#QHAyV{V&qN8V^+|Gei(xcBV%46~fQ-ZG^%>yW%W5MZ4W^ekc6?5&L6cN_ zbtgxG7WenlW9@kue?GV6x~3dHCmdZb1RO=PtwGmf+<-Cve(4;Qf$oW+75x!Xyb2Gm zHJeI7an++vQ*#kibf#_abOEGu>?=Pa^cI29~C!=2g;efad` zSV$fw?2cc#R(u%LslHWD*csqk-7)k|FhL#ilI~g~flF)hb|Vux*4wb_AM~Uk?T@eN ze<}m9>9*bOcgjpyU0-X~@tF{RkF|bS7aanU>)n)MLpHKULab_fqVOlrI(FGoF+%2Q z2a+leV4>IRuH_6l1YJ`r+wPOfXSr>m$eg4ui!&atf4N7MNY4W#X7FU1|3!7f%WI31!d15!K=1 zGMt``>sJWfCEA}LuDE3HQ7y*JHqFO%gGUf&`DW1VC>QP1!{nip4D@%6c4X>{FggBj z)p*c;+`QPx_I`H=-Y?AkXoup_6SV2LTYUy55=^;kZ41%0vm~};Fb7r}77578gE%5n 
zWUc&uFV;UV4*JzH0p%Mt4(~a>(Ay6@PKgqt-%LAv&wD?*A6Dh7kBC8Ks71@_=~yh7 zv`hAlfWVwiZc-*d^qIBhup|sLq9B;UOa;+~uCJCxW61*6<^c0ty!a*YVpnf5UUFZY zuP6zE+Vg;kMN~A5C&qfWG;q*y>i2RdH5-4qnUCuDEX?#?3~>06i;rqz6Z3pLNatja zmJTGLw(;D8(O^CvxQ~Q+Mejjrb}!4aO^k&`1=g>LX|UK(Elvsy!%P2WXM!6SZRt;o X$8AKYSmz!jCMl3OYpcJLz{7t5;2qs? literal 0 HcmV?d00001 diff --git a/tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-standard_nsigs1.npy b/tests/test_data/nmf_framework/utils_klnmf/W_updated_mu-standard_nsigs1.npy new file mode 100644 index 0000000000000000000000000000000000000000..331809a04cceecddaba82abb515f3fc7f1b4a822 GIT binary patch literal 896 zcmbVK`%?=59G^>;O3dMu>7y`rsr90cWcEt8+u7`qWv!aESxYsGrf9Bp zth$|7ZaP%!JTAfw%_DRXU5<3(@~D%!8-IZBZ{N@7htKEx`FyMB+mljuOd;eEW3V8DCZyG$*K9G$+Dvq`%;D`)6oZ;of7|I*4-)ms zk88Dh*t?_aHl=g0{=smw+0=wcTldv=ngJC`eOF+z3THBw%_k^MVnLE4iXT*t!DI19 zDFGGW-60IUxOW&Eob#G`NA=*bXbU|SA9edKgNk*db%L#dAR^uP{dEzEG^oRVWX)KVoQ8!R|{e52s@Kz5xQ|Pnj8(| zQ*-5Dw~dUvCzZeytMdvRlw$5SIsNQ;5sYqsuF{A$dL`KIe7ibAS@iG{jIN?IxH7@R z&^Z|mvw-l!Zf3DrX`>lZ~Hz+o>9f4il#<%RBgauM$J2}NC}mmFhezT1jo#x!4MGh#Ad+L+1V z^eA6esIv-JdgM{fmsLke!b|tmJ#{V@oi5g0_b)u}-#+hOKX2KgVE+J`4k4Y8MTv_~ zjpb0>Nfb|}8^wu4VJ36ZIMGRw$(*?OzkLvg8~;(K9*Jhhf1KNP?{Xsj)4_?vC;i{( zFWD-;lGjy2O;i_78=r=it%SLY6bOdg=74EVA>;vqJLbs}?40Of5?Dp>Fy)k2wZ&uG z{W6-<*AuE5vrj;w$1%E^av;3!i23@#smq(l#Uw-}S7I3m2B7fcrrCrPWcK z)1BM+seLi_jzn*bc*{k0{ryiQX$5$?J*eR520n&VjJ~U{*pS+^P3O~(V2yfnm6eo_ z+H*Vh+MSj{t_mE_+m-e!J}bX-JpKOnFTrw14oPjND#}+WQ>eJ!4g%KVv9lJOTEZ zc(DIO9YoFc2|-QeSi5kixwEDP3;(d`-RM<^4u|)P(>;0csJF4d9C!v(zW0l?E;PK1 zifYhbNI?~qZV|dzh!xgT+b&`r8bXU&*b9DmZCTrWqx2-wOolas^K{sCCLh$%El27v z+o-?Ff?<5qRlj!?AIA3gQc5*BnB$nSHW&z@>o(WFiB=B&B}d1%$2m~W`qKBMW?;|E zx04CSDiJYbcA0N0hN#lu-cX+iC4oAcp5Qc8=tNL$P8Z{XM3bWz5DgzY8@J|U0S@G; zOl-CkLzX$;XyI87>HgLFSG_VYOBf|oJdZ+eH({Vm^$q&Flz)1RpF;QI1nKd2H2n7L zUy5fT$Fcio!@XYl33#2T$9AK zMRVbu!a5Om$5DHmRF2`k9TZ*TQaIU8xrp7w=&&@3xic3IvcCFTZ$|;L`#b72+A^$~ zSCiJ=Dh276F8P6Kf9%WCQ_F51#zyWE^FK-n{4*Qd23srO9C&M;mlX>p|~LuP49RwsHQZK 
zmGbz-kNsQ5e8s+XEeS&RaKi7KxNxRCUGpQ&0c%ykFOSURBKxGI?}y+f&>L5T!CRz| zoo_x&9rZ&1sWsJ?aS@h|%MYp6HPEv=*V7#)hW-Ecex6`ZjR&0{vdrqU(C2Hu>GDAZ zgsQw{=DjS4#2;c5gQ2*vy*_$tZmI z+;~}|=BcIJg^5FW|t?1A$F<>5f{^9Ku?7wRnuq4aCAJd)79&uq1 zPKKG(-HruA+c1_?b>ibc3(j4`smRhBaesJ?hl`n%H$+AP{E8L3?s_F+S$aFivOuw&Zgo%oRmlS5BF4`qepVXww$EI9>r z>U5dqD-l?|TenW6=RhRCYUNlXM1t|_aq<~9M*F9FGloiV=Sp*=`%*sCA^U2{=Xn_L zFJqOB=RzWpW;yQ+LY0rw`aRPdqx!0e2kyt=N^2EtI8O!NkV?8BON0+GGn7J41#xw_ zU!b7~bL^il$;(2pc7A$re-jgL$%a#EZa7vei5KXX642P(zIrN@3t7%!xN(;doA=ZR XlY%2~J$<(HnYjQ-D;7h;i5&a|YKGhh literal 0 HcmV?d00001 diff --git a/tests/test_data/utils/counts.csv b/tests/test_data/nmf_framework/utils_klnmf/counts.csv similarity index 100% rename from tests/test_data/utils/counts.csv rename to tests/test_data/nmf_framework/utils_klnmf/counts.csv diff --git a/tests/test_data/utils/kl_divergence_nsigs1_result.npy b/tests/test_data/nmf_framework/utils_klnmf/kl_divergence_nsigs1.npy similarity index 100% rename from tests/test_data/utils/kl_divergence_nsigs1_result.npy rename to tests/test_data/nmf_framework/utils_klnmf/kl_divergence_nsigs1.npy diff --git a/tests/test_data/utils/kl_divergence_nsigs2_result.npy b/tests/test_data/nmf_framework/utils_klnmf/kl_divergence_nsigs2.npy similarity index 100% rename from tests/test_data/utils/kl_divergence_nsigs2_result.npy rename to tests/test_data/nmf_framework/utils_klnmf/kl_divergence_nsigs2.npy diff --git a/tests/test_data/utils/poisson_llh_nsigs1_result.npy b/tests/test_data/nmf_framework/utils_klnmf/poisson_llh_nsigs1.npy similarity index 100% rename from tests/test_data/utils/poisson_llh_nsigs1_result.npy rename to tests/test_data/nmf_framework/utils_klnmf/poisson_llh_nsigs1.npy diff --git a/tests/test_data/utils/poisson_llh_nsigs2_result.npy b/tests/test_data/nmf_framework/utils_klnmf/poisson_llh_nsigs2.npy similarity index 100% rename from tests/test_data/utils/poisson_llh_nsigs2_result.npy rename to tests/test_data/nmf_framework/utils_klnmf/poisson_llh_nsigs2.npy diff --git 
a/tests/test_data/utils/samplewise_kl_divergence_nsigs1_result.npy b/tests/test_data/nmf_framework/utils_klnmf/samplewise_kl_divergence_nsigs1.npy similarity index 100% rename from tests/test_data/utils/samplewise_kl_divergence_nsigs1_result.npy rename to tests/test_data/nmf_framework/utils_klnmf/samplewise_kl_divergence_nsigs1.npy diff --git a/tests/test_data/utils/samplewise_kl_divergence_nsigs2_result.npy b/tests/test_data/nmf_framework/utils_klnmf/samplewise_kl_divergence_nsigs2.npy similarity index 100% rename from tests/test_data/utils/samplewise_kl_divergence_nsigs2_result.npy rename to tests/test_data/nmf_framework/utils_klnmf/samplewise_kl_divergence_nsigs2.npy diff --git a/tests/test_klnmf.py b/tests/test_klnmf.py index 4d4c13f..a1f6919 100644 --- a/tests/test_klnmf.py +++ b/tests/test_klnmf.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pandas as pd import pytest @@ -10,31 +12,28 @@ @pytest.fixture def counts(): - return pd.read_csv(f"{PATH}/nmf_framework/counts.csv", index_col=0) + return pd.read_csv(f"{PATH_TEST_DATA}/counts.csv", index_col=0) @pytest.fixture(params=[1, 2]) -def model(request): - return klnmf.KLNMF(n_signatures=request.param) +def n_signatures(request): + return request.param @pytest.fixture -def path(model): - return f"{PATH_TEST_DATA}/klnmf_nsigs{model.n_signatures}" +def W_init(n_signatures): + return np.load(f"{PATH_TEST_DATA}/W_init_nsigs{n_signatures}.npy") @pytest.fixture -def W_init(path): - return np.load(f"{path}_W_init.npy") +def H_init(n_signatures): + return np.load(f"{PATH_TEST_DATA}/H_init_nsigs{n_signatures}.npy") @pytest.fixture -def H_init(path): - return np.load(f"{path}_H_init.npy") - - -@pytest.fixture -def model_init(model, counts, W_init, H_init): +def model_init(counts, W_init, H_init): + n_signatures = W_init.shape[1] + model = klnmf.KLNMF(n_signatures=n_signatures) model.X = counts.values model.W = W_init model.H = H_init @@ -42,37 +41,42 @@ def model_init(model, counts, W_init, H_init): 
@pytest.fixture -def objective_init(path): - return np.load(f"{path}_objective_init.npy") - - -@pytest.fixture -def W_updated(path): - return np.load(f"{path}_W_updated.npy") +def objective_init(n_signatures): + return np.load(f"{PATH_TEST_DATA}/objective_init_nsigs{n_signatures}.npy") -@pytest.fixture -def H_updated(path): - return np.load(f"{path}_H_updated.npy") +def test_objective_function(model_init, objective_init): + assert np.allclose(model_init.objective_function(), objective_init) -class TestKLNMF: - def test_objective_function(self, model_init, objective_init): - assert np.allclose(model_init.objective_function(), objective_init) +@pytest.mark.parametrize("update_method", ["mu-standard", "mu-joint"]) +class TestUpdatesKLNMF: + @pytest.fixture + def WH_updated(self, n_signatures, update_method): + with open( + f"{PATH_TEST_DATA}/WH_updated_{update_method}_nsigs{n_signatures}.pkl", "rb" + ) as f: + WH_updated = pickle.load(f) + return WH_updated - def test_update_W(self, model_init, W_updated): - model_init._update_W() + def test_update_WH(self, model_init, update_method, WH_updated): + model_init.update_method = update_method + model_init._update_WH() + W_updated, H_updated = WH_updated assert np.allclose(model_init.W, W_updated) - - def test_update_H(self, model_init, H_updated): - model_init._update_H() assert np.allclose(model_init.H, H_updated) - -@pytest.mark.parametrize("n_signatures", [1, 2]) -def test_given_signatures(counts, n_signatures): - given_signatures = counts.iloc[:, :n_signatures].astype(float).copy() - given_signatures /= given_signatures.sum(axis=0) - model = klnmf.KLNMF(n_signatures=n_signatures, min_iterations=3, max_iterations=3) - model.fit(counts, given_signatures=given_signatures) - assert np.allclose(given_signatures, model.signatures) + def test_given_signatures(self, n_signatures, update_method, counts): + for n_given_signatures in range(1, n_signatures + 1): + given_signatures = counts.iloc[:, 
:n_given_signatures].astype(float).copy() + given_signatures /= given_signatures.sum(axis=0) + model = klnmf.KLNMF( + n_signatures=n_signatures, + update_method=update_method, + min_iterations=3, + max_iterations=3, + ) + model.fit(counts, given_signatures=given_signatures) + assert np.allclose( + given_signatures, model.signatures.iloc[:, :n_given_signatures] + ) diff --git a/tests/test_multimodal_corrnmf.py b/tests/test_multimodal_corrnmf.py index a671c40..e276c17 100644 --- a/tests/test_multimodal_corrnmf.py +++ b/tests/test_multimodal_corrnmf.py @@ -11,22 +11,6 @@ DIM_EMBEDDINGS = 2 -@pytest.fixture -def U_init(): - """ - Initial joint sample embeddings. - """ - return np.load(f"{PATH_TEST_DATA}/U_init.npy") - - -@pytest.fixture -def sigma_sq_init(): - """ - Initial joint variance. - """ - return np.load(f"{PATH_TEST_DATA}/sigma_sq_init.npy") - - @pytest.fixture def counts(): """ @@ -53,6 +37,16 @@ def alphas_init(): ] +@pytest.fixture +def _ps(): + return [np.load(f"{PATH_TEST_DATA}/model{n}_p.npy") for n in range(N_MODALITIES)] + + +@pytest.fixture +def _auxs(counts, _ps): + return [np.einsum("vd,vkd->kd", data.values, p) for data, p in zip(counts, _ps)] + + @pytest.fixture def Ls_init(): return [ @@ -60,6 +54,22 @@ def Ls_init(): ] +@pytest.fixture +def U_init(): + """ + Initial joint sample embeddings. + """ + return np.load(f"{PATH_TEST_DATA}/U_init.npy") + + +@pytest.fixture +def sigma_sq_init(): + """ + Initial joint variance. 
+ """ + return np.load(f"{PATH_TEST_DATA}/sigma_sq_init.npy") + + @pytest.fixture def multi_model_init(counts, Ws_init, alphas_init, Ls_init, U_init, sigma_sq_init): models = [] @@ -89,26 +99,27 @@ def multi_model_init(counts, Ws_init, alphas_init, Ls_init, U_init, sigma_sq_ini return multi_model -@pytest.fixture -def _ps(): - return [np.load(f"{PATH_TEST_DATA}/model{n}_p.npy") for n in range(N_MODALITIES)] - - -@pytest.fixture -def _auxs(counts, _ps): - return [np.einsum("vd,vkd->kd", data.values, p) for data, p in zip(counts, _ps)] - - @pytest.fixture def objective_init(): return np.load(f"{PATH_TEST_DATA}/objective_init.npy") +def test_objective_function(multi_model_init, objective_init): + assert np.allclose(multi_model_init.objective_function(), objective_init) + + @pytest.fixture def surrogate_objective_init(): return np.load(f"{PATH_TEST_DATA}/surrogate_objective_init.npy") +def test_surrogate_objective_function(multi_model_init, _ps, surrogate_objective_init): + assert np.allclose( + multi_model_init._surrogate_objective_function(_ps), + surrogate_objective_init, + ) + + @pytest.fixture def Ws_updated(): return [ @@ -141,21 +152,9 @@ def sigma_sq_updated(): return np.load(f"{PATH_TEST_DATA}/sigma_sq_updated.npy") -class TestMultimodalCorrNMFDet: - def test_objective_function(self, multi_model_init, objective_init): - assert np.allclose(multi_model_init.objective_function(), objective_init) - - def test_surrogate_objective_function( - self, multi_model_init, _ps, surrogate_objective_init - ): - assert np.allclose( - multi_model_init._surrogate_objective_function(_ps), - surrogate_objective_init, - ) - +class TestUpdatesMultimodalCorrNMFDet: def test_update_W(self, multi_model_init, Ws_updated): - given_signatures = [None for _ in range(N_MODALITIES)] - multi_model_init._update_Ws(given_signatures) + multi_model_init._update_Ws() for model, W_updated in zip(multi_model_init.models, Ws_updated): assert np.allclose(model.W, W_updated) @@ -192,58 +191,64 @@ def 
test_update_sigma_sq(self, multi_model_init, sigma_sq_updated): for model in multi_model_init.models: assert np.allclose(model.sigma_sq, sigma_sq_updated) - -@pytest.mark.parametrize("ns_signatures", [[1, 2], [2, 2]]) -def test_given_signatures(counts, ns_signatures): - given_signatures0 = counts[0].iloc[:, : ns_signatures[0]].astype(float).copy() - given_signatures0 /= given_signatures0.sum(axis=0) - given_signatures = [given_signatures0, None] - multi_model = multimodal_corrnmf.MultimodalCorrNMF( - n_modalities=2, - ns_signatures=ns_signatures, - dim_embeddings=2, - min_iterations=3, - max_iterations=3, - ) - multi_model.fit(counts, given_signatures=given_signatures) - assert np.allclose(given_signatures0, multi_model.models[0].W) - assert not np.allclose(given_signatures0, multi_model.models[1].W) - - -@pytest.mark.parametrize( - "ns_signatures,dim_embeddings", [([1, 2], 1), ([2, 2], 1), ([2, 2], 2)] -) -def test_given_signature_embeddings(counts, ns_signatures, dim_embeddings): - given_signature_embeddings0 = np.random.uniform( - size=(dim_embeddings, ns_signatures[0]) + @pytest.mark.parametrize("ns_signatures", [[1, 2], [2, 2]]) + def test_given_signatures(self, ns_signatures, counts): + for n_given_signatures in range(1, ns_signatures[0] + 1): + given_signatures0 = ( + counts[0].iloc[:, :n_given_signatures].astype(float).copy() + ) + given_signatures0 /= given_signatures0.sum(axis=0) + given_signatures = [given_signatures0, None] + multi_model = multimodal_corrnmf.MultimodalCorrNMF( + n_modalities=2, + ns_signatures=ns_signatures, + dim_embeddings=2, + min_iterations=3, + max_iterations=3, + ) + multi_model.fit(counts, given_signatures=given_signatures) + assert np.allclose( + given_signatures0, + multi_model.models[0].signatures.iloc[:, :n_given_signatures], + ) + assert not np.allclose( + given_signatures0, + multi_model.models[1].signatures.iloc[:, :n_given_signatures], + ) + + @pytest.mark.parametrize( + "ns_signatures,dim_embeddings", [([1, 2], 1), ([2, 
2], 1), ([2, 2], 2)] ) - given_signature_embeddings = [given_signature_embeddings0, None] - multi_model = multimodal_corrnmf.MultimodalCorrNMF( - n_modalities=2, - ns_signatures=ns_signatures, - dim_embeddings=dim_embeddings, - min_iterations=3, - max_iterations=3, - ) - multi_model.fit(counts, given_signature_embeddings=given_signature_embeddings) - assert np.allclose(given_signature_embeddings0, multi_model.models[0].L) - assert not np.allclose(given_signature_embeddings0, multi_model.models[1].L) - + def test_given_signature_embeddings(self, ns_signatures, dim_embeddings, counts): + given_signature_embeddings0 = np.random.uniform( + size=(dim_embeddings, ns_signatures[0]) + ) + given_signature_embeddings = [given_signature_embeddings0, None] + multi_model = multimodal_corrnmf.MultimodalCorrNMF( + n_modalities=2, + ns_signatures=ns_signatures, + dim_embeddings=dim_embeddings, + min_iterations=3, + max_iterations=3, + ) + multi_model.fit(counts, given_signature_embeddings=given_signature_embeddings) + assert np.allclose(given_signature_embeddings0, multi_model.models[0].L) + assert not np.allclose(given_signature_embeddings0, multi_model.models[1].L) -@pytest.mark.parametrize( - "ns_signatures,dim_embeddings", [([1, 2], 1), ([2, 2], 1), ([2, 2], 2)] -) -def test_given_sample_embeddings(counts, ns_signatures, dim_embeddings): - n_samples = len(counts[0].columns) - given_sample_embeddings = np.random.uniform(size=(dim_embeddings, n_samples)) - multi_model = multimodal_corrnmf.MultimodalCorrNMF( - n_modalities=2, - ns_signatures=ns_signatures, - dim_embeddings=dim_embeddings, - min_iterations=3, - max_iterations=3, + @pytest.mark.parametrize( + "ns_signatures,dim_embeddings", [([1, 2], 1), ([2, 2], 1), ([2, 2], 2)] ) - multi_model.fit(counts, given_sample_embeddings=given_sample_embeddings) + def test_given_sample_embeddings(self, ns_signatures, dim_embeddings, counts): + n_samples = len(counts[0].columns) + given_sample_embeddings = 
np.random.uniform(size=(dim_embeddings, n_samples)) + multi_model = multimodal_corrnmf.MultimodalCorrNMF( + n_modalities=2, + ns_signatures=ns_signatures, + dim_embeddings=dim_embeddings, + min_iterations=3, + max_iterations=3, + ) + multi_model.fit(counts, given_sample_embeddings=given_sample_embeddings) - for model in multi_model.models: - assert np.allclose(given_sample_embeddings, model.U) + for model in multi_model.models: + assert np.allclose(given_sample_embeddings, model.U) diff --git a/tests/test_mvnmf.py b/tests/test_mvnmf.py index 412957f..cd11a2a 100644 --- a/tests/test_mvnmf.py +++ b/tests/test_mvnmf.py @@ -10,58 +10,52 @@ @pytest.fixture def counts(): - return pd.read_csv(f"{PATH}/nmf_framework/counts.csv", index_col=0) + return pd.read_csv(f"{PATH_TEST_DATA}/counts.csv", index_col=0) @pytest.fixture(params=[1, 2]) -def model(request): - return mvnmf.MvNMF(n_signatures=request.param) +def n_signatures(request): + return request.param @pytest.fixture -def path(model): - return f"{PATH_TEST_DATA}/mvnmf_nsigs{model.n_signatures}" +def W_init(n_signatures): + return np.load(f"{PATH_TEST_DATA}/W_init_nsigs{n_signatures}.npy") @pytest.fixture -def W_init(path): - return np.load(f"{path}_W_init.npy") +def H_init(n_signatures): + return np.load(f"{PATH_TEST_DATA}/H_init_nsigs{n_signatures}.npy") @pytest.fixture -def H_init(path): - return np.load(f"{path}_H_init.npy") - - -@pytest.fixture -def model_init(model, counts, W_init, H_init): +def model_init(counts, W_init, H_init): + n_signatures = W_init.shape[1] + model = mvnmf.MvNMF(n_signatures=n_signatures, lam=1.0, delta=1.0) model.X = counts.values model.W = W_init model.H = H_init - model.lam = 1.0 - model.delta = 1.0 model._gamma = 1.0 return model @pytest.fixture -def objective_init(path): - return np.load(f"{path}_objective_init.npy") +def objective_init(n_signatures): + return np.load(f"{PATH_TEST_DATA}/objective_init_nsigs{n_signatures}.npy") -@pytest.fixture -def W_updated(path): - return 
np.load(f"{path}_W_updated.npy") +def test_objective_function(model_init, objective_init): + assert np.allclose(model_init.objective_function(), objective_init) -@pytest.fixture -def H_updated(path): - return np.load(f"{path}_H_updated.npy") - +class TestUpdatesMVNMF: + @pytest.fixture + def W_updated(self, n_signatures): + return np.load(f"{PATH_TEST_DATA}/W_updated_nsigs{n_signatures}.npy") -class TestMVNMF: - def test_objective_function(self, model_init, objective_init): - assert np.allclose(model_init.objective_function(), objective_init) + @pytest.fixture + def H_updated(self, n_signatures): + return np.load(f"{PATH_TEST_DATA}/H_updated_nsigs{n_signatures}.npy") def test_update_W(self, model_init, W_updated): model_init._update_W() @@ -71,11 +65,14 @@ def test_update_H(self, model_init, H_updated): model_init._update_H() assert np.allclose(model_init.H, H_updated) - -@pytest.mark.parametrize("n_signatures", [1, 2]) -def test_given_signatures(counts, n_signatures): - given_signatures = counts.iloc[:, :n_signatures].astype(float).copy() - given_signatures /= given_signatures.sum(axis=0) - model = mvnmf.MvNMF(n_signatures=n_signatures, min_iterations=3, max_iterations=3) - model.fit(counts, given_signatures=given_signatures) - assert np.allclose(given_signatures, model.signatures) + def test_given_signatures(self, n_signatures, counts): + for n_given_signatures in range(1, n_signatures + 1): + given_signatures = counts.iloc[:, :n_given_signatures].astype(float).copy() + given_signatures /= given_signatures.sum(axis=0) + model = mvnmf.MvNMF( + n_signatures=n_signatures, min_iterations=3, max_iterations=3 + ) + model.fit(counts, given_signatures=given_signatures) + assert np.allclose( + given_signatures, model.signatures.iloc[:, :n_given_signatures] + ) diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index c5f7f03..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,62 +0,0 @@ -import numpy as np -import pandas as pd -import pytest 
- -from salamander.utils import kl_divergence, poisson_llh, samplewise_kl_divergence - -PATH_TEST_DATA = "tests/test_data" -PATH_TEST_DATA_UTILS = f"{PATH_TEST_DATA}/utils" - - -@pytest.fixture -def counts(): - return pd.read_csv(f"{PATH_TEST_DATA_UTILS}/counts.csv", index_col=0) - - -@pytest.fixture(params=[1, 2]) -def n_signatures(request): - return request.param - - -@pytest.fixture -def objective_inputs(counts, n_signatures): - path = f"{PATH_TEST_DATA_UTILS}/objective_input_nsigs{n_signatures}" - W = np.load(f"{path}_W.npy") - H = np.load(f"{path}_H.npy") - - return (counts.values, W, H) - - -@pytest.fixture -def kl_divergence_output(n_signatures): - path = f"{PATH_TEST_DATA_UTILS}/kl_divergence_nsigs{n_signatures}_result.npy" - return np.load(path) - - -def test_kl_divergence(objective_inputs, kl_divergence_output): - assert np.allclose(kl_divergence(*objective_inputs), kl_divergence_output) - - -@pytest.fixture -def samplewise_kl_divergence_output(n_signatures): - path = ( - f"{PATH_TEST_DATA_UTILS}/" - f"samplewise_kl_divergence_nsigs{n_signatures}_result.npy" - ) - return np.load(path) - - -def test_samplewise_kl_divergence(objective_inputs, samplewise_kl_divergence_output): - assert np.allclose( - samplewise_kl_divergence(*objective_inputs), samplewise_kl_divergence_output - ) - - -@pytest.fixture -def poisson_llh_output(n_signatures): - path = f"{PATH_TEST_DATA_UTILS}/poisson_llh_nsigs{n_signatures}_result.npy" - return np.load(path) - - -def test_poisson_llh(objective_inputs, poisson_llh_output): - assert np.allclose(poisson_llh(*objective_inputs), poisson_llh_output) diff --git a/tests/test_utils_klnmf.py b/tests/test_utils_klnmf.py new file mode 100644 index 0000000..0590766 --- /dev/null +++ b/tests/test_utils_klnmf.py @@ -0,0 +1,127 @@ +import numpy as np +import pandas as pd +import pytest + +from salamander.nmf_framework import _utils_klnmf + +PATH_TEST_DATA = "tests/test_data" +PATH_TEST_DATA_UTILS_KLNMF = 
f"{PATH_TEST_DATA}/nmf_framework/utils_klnmf" + + +@pytest.fixture +def counts(): + return pd.read_csv(f"{PATH_TEST_DATA_UTILS_KLNMF}/counts.csv", index_col=0) + + +@pytest.fixture(params=[1, 2]) +def n_signatures(request): + return request.param + + +@pytest.fixture +def matrices_input(counts, n_signatures): + W = np.load(f"{PATH_TEST_DATA_UTILS_KLNMF}/W_nsigs{n_signatures}.npy") + H = np.load(f"{PATH_TEST_DATA_UTILS_KLNMF}/H_nsigs{n_signatures}.npy") + + return (counts.values, W, H) + + +@pytest.fixture +def kl_divergence_output(n_signatures): + path = f"{PATH_TEST_DATA_UTILS_KLNMF}/kl_divergence_nsigs{n_signatures}.npy" + return np.load(path) + + +def test_kl_divergence(matrices_input, kl_divergence_output): + assert np.allclose( + _utils_klnmf.kl_divergence(*matrices_input), kl_divergence_output + ) + + +@pytest.fixture +def samplewise_kl_divergence_output(n_signatures): + path = ( + f"{PATH_TEST_DATA_UTILS_KLNMF}/" + f"samplewise_kl_divergence_nsigs{n_signatures}.npy" + ) + return np.load(path) + + +def test_samplewise_kl_divergence(matrices_input, samplewise_kl_divergence_output): + assert np.allclose( + _utils_klnmf.samplewise_kl_divergence(*matrices_input), + samplewise_kl_divergence_output, + ) + + +@pytest.fixture +def poisson_llh_output(n_signatures): + path = f"{PATH_TEST_DATA_UTILS_KLNMF}/poisson_llh_nsigs{n_signatures}.npy" + return np.load(path) + + +def test_poisson_llh(matrices_input, poisson_llh_output): + assert np.allclose(_utils_klnmf.poisson_llh(*matrices_input), poisson_llh_output) + + +@pytest.fixture +def W_updated(n_signatures): + path = f"{PATH_TEST_DATA_UTILS_KLNMF}/W_updated_mu-standard_nsigs{n_signatures}.npy" + return np.load(path) + + +def test_update_W(matrices_input, W_updated): + W_updated_utils = _utils_klnmf.update_W(*matrices_input) + assert np.allclose(W_updated_utils, W_updated) + + +def test_given_signatures_update_W(matrices_input): + X, W, H = matrices_input + n_signatures = W.shape[1] + + for n_given_signatures in 
range(1, n_signatures + 1): + W_updated = _utils_klnmf.update_W( + X, W.copy(), H, n_given_signatures=n_given_signatures + ) + assert np.array_equal( + W_updated[:, :n_given_signatures], W[:, :n_given_signatures] + ) + + +@pytest.fixture +def H_updated(n_signatures): + path = f"{PATH_TEST_DATA_UTILS_KLNMF}/H_updated_mu-standard_nsigs{n_signatures}.npy" + return np.load(path) + + +def test_update_H(matrices_input, H_updated): + H_updated_utils = _utils_klnmf.update_H(*matrices_input) + assert np.allclose(H_updated_utils, H_updated) + + +@pytest.fixture +def WH_updated(n_signatures): + suffix = f"_updated_mu-standard_nsigs{n_signatures}.npy" + path_W = f"{PATH_TEST_DATA_UTILS_KLNMF}/W{suffix}" + path_H = f"{PATH_TEST_DATA_UTILS_KLNMF}/H{suffix}" + return np.load(path_W), np.load(path_H) + + +def test_update_WH(matrices_input, WH_updated): + W_updated, H_updated = WH_updated + W_updated_utils, H_updated_utils = _utils_klnmf.update_WH(*matrices_input) + assert np.allclose(W_updated_utils, W_updated) + assert np.allclose(H_updated_utils, H_updated) + + +def test_given_signatures_update_WH(matrices_input): + X, W, H = matrices_input + n_signatures = W.shape[1] + + for n_given_signatures in range(1, n_signatures + 1): + W_updated, _ = _utils_klnmf.update_WH( + X, W.copy(), H, n_given_signatures=n_given_signatures + ) + assert np.array_equal( + W_updated[:, :n_given_signatures], W[:, :n_given_signatures] + ) From 4f8d7c85343c88f5fafe5c14da491e394f3dc4d4 Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Fri, 20 Oct 2023 18:14:14 -0400 Subject: [PATCH 09/13] remove surrogate objective during fit CorrNMF no longer computes the surrogate objective function value in each iteration. 
This decreases the runtime --- src/salamander/nmf_framework/corrnmf.py | 8 +++----- src/salamander/nmf_framework/corrnmf_det.py | 8 ++------ src/salamander/nmf_framework/multimodal_corrnmf.py | 13 +++++-------- tests/test_corrnmf.py | 4 ++-- tests/test_multimodal_corrnmf.py | 4 ++-- 5 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index 332eebd..3b7506e 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -245,13 +245,11 @@ def objective_function(self, penalize_sample_embeddings=True) -> float: def objective(self) -> str: return "maximize" - def _surrogate_objective_function( - self, p, penalize_sample_embeddings=True - ) -> float: + def _surrogate_objective_function(self, penalize_sample_embeddings=True) -> float: """ - The surrogate lower bound of the ELBO after - introducing the auxiliary parameters p. + The surrogate lower bound of the ELBO. 
""" + p = self._update_p() exposures = self.exposures.values aux = np.log(self.W)[:, :, None] + np.log(exposures)[None, :, :] - np.log(p) sof_value = np.einsum("VD,VKD,VKD->", self.X, p, aux, optimize="greedy").item() diff --git a/src/salamander/nmf_framework/corrnmf_det.py b/src/salamander/nmf_framework/corrnmf_det.py index bbaf3c0..6855456 100644 --- a/src/salamander/nmf_framework/corrnmf_det.py +++ b/src/salamander/nmf_framework/corrnmf_det.py @@ -220,8 +220,6 @@ def fit( init_kwargs=init_kwargs, ) of_values = [self.objective_function()] - sof_values = [self.objective_function()] - n_iteration = 0 converged = False @@ -239,16 +237,14 @@ def fit( if self.n_given_signatures < self.n_signatures: self._update_W() + prev_of_value = of_values[-1] of_values.append(self.objective_function()) - prev_sof_value = sof_values[-1] - sof_values.append(self._surrogate_objective_function(p)) - rel_change = (sof_values[-1] - prev_sof_value) / np.abs(prev_sof_value) + rel_change = (of_values[-1] - prev_of_value) / np.abs(prev_of_value) converged = ( rel_change < self.tol and n_iteration >= self.min_iterations ) or (n_iteration >= self.max_iterations) if history: self.history["objective_function"] = of_values[1:] - self.history["surrogate_objective_function"] = sof_values[1:] return self diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index 7d2aa82..d7e716d 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -126,13 +126,14 @@ def objective_function(self) -> float: def objective(self) -> str: return "maximize" - def _surrogate_objective_function(self, ps) -> float: + def _surrogate_objective_function(self) -> float: """ The surrogate lower bound of the ELBO. 
""" + ps = self._update_ps() sof_value = np.sum( [ - model._surrogate_objective_function(p, penalize_sample_embeddings=False) + model._surrogate_objective_function(penalize_sample_embeddings=False) for model, p in zip(self.models, ps) ] ) @@ -387,8 +388,6 @@ def fit( init_kwargs=init_kwargs, ) of_values = [self.objective_function()] - sof_values = [self.objective_function()] - n_iteration = 0 converged = False @@ -404,17 +403,15 @@ def fit( self._update_sigma_sq() self._update_Ws() + prev_of_value = of_values[-1] of_values.append(self.objective_function()) - prev_sof_value = sof_values[-1] - sof_values.append(self._surrogate_objective_function(ps)) - rel_change = (sof_values[-1] - prev_sof_value) / np.abs(prev_sof_value) + rel_change = (of_values[-1] - prev_of_value) / np.abs(prev_of_value) converged = ( rel_change < self.tol and n_iteration >= self.min_iterations ) or (n_iteration >= self.max_iterations) if history: self.history["objective_function"] = of_values[1:] - self.history["surrogate_objective_function"] = sof_values[1:] return self diff --git a/tests/test_corrnmf.py b/tests/test_corrnmf.py index cbc3353..bba71e9 100644 --- a/tests/test_corrnmf.py +++ b/tests/test_corrnmf.py @@ -97,9 +97,9 @@ def surrogate_objective_init(path_suffix): return np.load(f"{PATH_TEST_DATA}/surrogate_objective_init_{path_suffix}") -def test_surrogate_objective_function(model_init, _p, surrogate_objective_init): +def test_surrogate_objective_function(model_init, surrogate_objective_init): assert np.allclose( - model_init._surrogate_objective_function(_p), surrogate_objective_init + model_init._surrogate_objective_function(), surrogate_objective_init ) diff --git a/tests/test_multimodal_corrnmf.py b/tests/test_multimodal_corrnmf.py index e276c17..ecbbae7 100644 --- a/tests/test_multimodal_corrnmf.py +++ b/tests/test_multimodal_corrnmf.py @@ -113,9 +113,9 @@ def surrogate_objective_init(): return np.load(f"{PATH_TEST_DATA}/surrogate_objective_init.npy") -def 
test_surrogate_objective_function(multi_model_init, _ps, surrogate_objective_init): +def test_surrogate_objective_function(multi_model_init, surrogate_objective_init): assert np.allclose( - multi_model_init._surrogate_objective_function(_ps), + multi_model_init._surrogate_objective_function(), surrogate_objective_init, ) From 25f05370ba1be7bcda7c56f5fca21bfd738056f4 Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Fri, 20 Oct 2023 19:00:37 -0400 Subject: [PATCH 10/13] improve signature plot The bar chart colors can now also be overwritten for signatures with the standard 96 channels --- src/salamander/plot.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/salamander/plot.py b/src/salamander/plot.py index 8beea5d..0486309 100644 --- a/src/salamander/plot.py +++ b/src/salamander/plot.py @@ -295,31 +295,31 @@ def corr_plot( return clustergrid -def _get_colors_signature_plot(colors, mutation_types): +def _get_colors_signature_plot(mutation_types, colors=None): """ - Given the colors argument of sigplot_bar and the mutation types, return the - colors used in the signature bar chart. + Given the mutation types and the colors argument of sigplot_bar, return the + final colors used in the signature bar chart. """ n_features = len(mutation_types) - if colors == "SBS96" or (n_features == 96 and all(mutation_types == SBS_TYPES_96)): + if colors == "SBS96" or ( + n_features == 96 and all(mutation_types == SBS_TYPES_96) and colors is None + ): if n_features != 96: raise ValueError( "The standard SBS colors can only be used " "when the signatures have 96 features." ) - colors = COLORS_SBS96 elif colors == "Indel83" or ( - n_features == 83 and all(mutation_types == INDEL_TYPES_83) + n_features == 83 and all(mutation_types == INDEL_TYPES_83) and colors is None ): if n_features != 83: raise ValueError( "The standard Indel colors can only be used " "when the signatures have 83 features." 
) - colors = COLORS_INDEL83 elif type(colors) in [str, tuple]: @@ -365,7 +365,7 @@ def _signature_plot( signature_normalized = signature / signature.sum(axis=0) mutation_types = signature.index - colors = _get_colors_signature_plot(colors, mutation_types) + colors = _get_colors_signature_plot(mutation_types, colors) ax.set_title(signature_normalized.columns[0]) ax.spines["left"].set_visible(False) From 9ec94d64e0de125759602f76b0bc4d15fb65f683 Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Fri, 20 Oct 2023 22:10:42 -0400 Subject: [PATCH 11/13] improve exposure plot The exposure plot now allows a custom sample order. This is necessary for the multimodal exposure visualization to be consistent across modalities. --- .../nmf_framework/multimodal_corrnmf.py | 21 ++- src/salamander/nmf_framework/signature_nmf.py | 3 +- src/salamander/plot.py | 123 +++++++++++++++--- tests/test_plot.py | 64 +++++++++ 4 files changed, 190 insertions(+), 21 deletions(-) create mode 100644 tests/test_plot.py diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index d7e716d..451e14e 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -16,7 +16,13 @@ from scipy import optimize from scipy.spatial.distance import squareform -from ..plot import corr_plot, embeddings_plot, salamander_style, signatures_plot +from ..plot import ( + _get_sample_order, + corr_plot, + embeddings_plot, + salamander_style, + signatures_plot, +) from ..utils import type_checker, value_checker from . 
import _utils_corrnmf from .corrnmf_det import CorrNMFDet @@ -455,6 +461,7 @@ def plot_signatures( @salamander_style def plot_exposures( self, + sample_order=None, reorder_signatures=True, annotate_samples=True, colors=None, @@ -480,10 +487,20 @@ def plot_exposures( if colors is None: colors = [None for _ in range(self.n_modalities)] + if sample_order is None: + all_exposures = pd.concat([model.exposures for model in self.models]) + sample_order = _get_sample_order(all_exposures) + for n, (ax, model, cols) in enumerate(zip(axes, self.models, colors)): + if n < self.n_modalities - 1: + annotate = False + else: + annotate = annotate_samples + ax = model.plot_exposures( + sample_order=sample_order, reorder_signatures=reorder_signatures, - annotate_samples=annotate_samples, + annotate_samples=annotate, colors=cols, ncol_legend=ncol_legend, ax=ax, diff --git a/src/salamander/nmf_framework/signature_nmf.py b/src/salamander/nmf_framework/signature_nmf.py index 8aca3ac..0ba9481 100644 --- a/src/salamander/nmf_framework/signature_nmf.py +++ b/src/salamander/nmf_framework/signature_nmf.py @@ -339,6 +339,7 @@ def plot_signatures( @salamander_style def plot_exposures( self, + sample_order=None, reorder_signatures=True, annotate_samples=True, colors=None, @@ -353,6 +354,7 @@ def plot_exposures( """ ax = exposures_plot( exposures=self.exposures, + sample_order=sample_order, reorder_signatures=reorder_signatures, annotate_samples=annotate_samples, colors=colors, @@ -360,7 +362,6 @@ def plot_exposures( ax=ax, **kwargs, ) - if outfile is not None: plt.savefig(outfile, bbox_inches="tight") diff --git a/src/salamander/plot.py b/src/salamander/plot.py index 0486309..7400c97 100644 --- a/src/salamander/plot.py +++ b/src/salamander/plot.py @@ -513,27 +513,77 @@ def signatures_plot( return axes -def _reorder_exposures(exposures: pd.DataFrame, reorder_signatures=True): +def _get_sample_order(exposures: pd.DataFrame, normalize=True): """ - Reorder the samples using hierarchical clustering 
and - reorder the signatures by their total relative exposure. + Compute the aesthetically most pleasing order of the samples + for a stacked bar chart of the exposures. + + Parameters + ---------- + exposures : pd.DataFrame of shape (n_signatures, n_samples) + The named exposure matrix + + normalize : bool, default=True + If True, the exposures are normalized before computing the + hierarchical clustering. + + Returns + ------- + sample_order : np.ndarray + The ordered sample names """ - exposures_normalized = exposures / exposures.sum(axis=0) + if normalize: + # not in-place + exposures = exposures / exposures.sum(axis=0) - d = pdist(exposures_normalized.T) + d = pdist(exposures.T) linkage = fastcluster.linkage(d) # get the optimal sample order that is consistent # with the hierarchical clustering linkage sample_order = hierarchy.leaves_list(hierarchy.optimal_leaf_ordering(linkage, d)) - samples_reordered = exposures_normalized.columns[sample_order] - exposures_reordered = exposures_normalized[samples_reordered] + sample_order = exposures.columns[sample_order].to_numpy() + return sample_order + + +def _reorder_exposures( + exposures: pd.DataFrame, sample_order=None, reorder_signatures=True +): + """ + Reorder the samples with hierarchical clustering and + reorder the signatures by their total relative exposure. + + Parameters + ---------- + exposures : pd.DataFrame of shape (n_signatures, n_samples) + The named exposure matrix + + sample_order : np.ndarray, default=None + A predefined order of the samples as a list of sample names. + If None, hierarchical clustering is used to compute the + aesthetically most pleasing order. + + reorder_signatures : bool, default=True + If True, the signatures will be reordered such that the + total relative exposures of the signatures decrease from the bottom + to the top signature in the stacked bar chart. 
- # order the signatures by their total exposure + Returns + ------- + exposures_reordered : pd.DataFrame of shape (n_signatures, n_samples) + The reordered named exposure matrix + """ + if sample_order is None: + sample_order = _get_sample_order(exposures) + + exposures_reordered = exposures[sample_order] + + # order the signatures by their total relative exposure if reorder_signatures: - signatures_reordered = ( - exposures_reordered.sum(axis=1).sort_values(ascending=False).index + exposures_normalized = exposures_reordered / exposures_reordered.sum(axis=0) + signature_order = ( + exposures_normalized.sum(axis=1).sort_values(ascending=False).index ) - exposures_reordered = exposures_reordered.reindex(signatures_reordered) + exposures_reordered = exposures_reordered.reindex(signature_order) return exposures_reordered @@ -541,6 +591,7 @@ def _reorder_exposures(exposures: pd.DataFrame, reorder_signatures=True): @salamander_style def exposures_plot( exposures: pd.DataFrame, + sample_order=None, reorder_signatures=True, annotate_samples=True, colors=None, @@ -549,13 +600,49 @@ def exposures_plot( **kwargs, ): """ - Visualize the exposures using a stacked bar chart. + Visualize the exposures with a stacked bar chart. + + Parameters + ---------- + exposures : pd.DataFrame of shape (n_signatures, n_samples) + The named exposure matrix. + + sample_order : np.ndarray, default=None + A predefined order of the samples as a list of sample names. + If None, hierarchical clustering is used to compute the + aesthetically most pleasing order. + + reorder_signatures : bool, default=True + If True, the signatures will be reordered such that the + total relative exposures of the signatures decrease from the bottom + to the top signature in the stacked bar chart. + + annotate_samples : bool, default=True + If True, the x-axis is annotated with the sample names. + + colors : list of length n_signatures, default=None + Colors to pass to matplotlib's ax.bar, one per signature.
+ + ncol_legend : int, default=1 + The number of columns of the legend. + + ax : matplotlib.axes.Axes, default=None + Pre-existing axes for the plot. Otherwise, create an axis internally. + + kwargs : dict + Any keyword arguments to be passed to matplotlib's ax.bar. + + Returns + ------- + ax : matplotlib.axes.Axes + The matplotlib axes containing the plot. """ n_signatures, n_samples = exposures.shape - exposures_reordered = _reorder_exposures( - exposures, reorder_signatures=reorder_signatures + # not in-place + exposures = exposures / exposures.sum(axis=0) + exposures = _reorder_exposures( + exposures, sample_order=sample_order, reorder_signatures=reorder_signatures ) - samples = exposures_reordered.columns if ax is None: _, ax = plt.subplots(figsize=(0.3 * n_samples, 4)) @@ -565,8 +652,8 @@ def exposures_plot( bottom = np.zeros(n_samples) - for signature, color in zip(exposures_reordered.T, colors): - signature_exposures = exposures_reordered.T[signature].to_numpy() + for signature, color in zip(exposures.T, colors): + signature_exposures = exposures.T[signature].to_numpy() ax.bar( np.arange(n_samples), signature_exposures, @@ -581,7 +668,7 @@ def exposures_plot( if annotate_samples: ax.set_xticks(np.arange(n_samples)) - ax.set_xticklabels(samples, rotation=90, ha="center", fontsize=10) + ax.set_xticklabels(exposures.columns, rotation=90, ha="center", fontsize=10) else: ax.get_xaxis().set_visible(False) diff --git a/tests/test_plot.py b/tests/test_plot.py new file mode 100644 index 0000000..ae1878c --- /dev/null +++ b/tests/test_plot.py @@ -0,0 +1,64 @@ +import numpy as np +import pandas as pd +import pytest + +from salamander import plot + + +@pytest.fixture +def exposures(): + mat = np.array([[1, 2, 3, 4], [1, 3, 2, 4]]) + exposures = pd.DataFrame(mat, columns=["a", "b", "c", "d"]) + return exposures + + +def test_get_sample_order_normalized(exposures): + sample_order = plot._get_sample_order(exposures, normalize=True) + + # A next to D + position_a =
np.where(sample_order == "a")[0][0] + position_d = np.where(sample_order == "d")[0][0] + assert np.abs(position_a - position_d) == 1 + + # B as far away from C as possible + position_b = np.where(sample_order == "b")[0][0] + position_c = np.where(sample_order == "c")[0][0] + assert np.abs(position_b - position_c) == 3 + + +def test_get_sample_order_unnormalized(exposures): + sample_order = plot._get_sample_order(exposures, normalize=False) + + # A as far away from D as possible + position_a = np.where(sample_order == "a")[0][0] + position_d = np.where(sample_order == "d")[0][0] + assert np.abs(position_a - position_d) == 3 + + # B next to C + position_b = np.where(sample_order == "b")[0][0] + position_c = np.where(sample_order == "c")[0][0] + assert np.abs(position_b - position_c) == 1 + + +def test_reorder_exposures(exposures): + # reordering is based on the relative exposures + exposures_reordered = plot._reorder_exposures(exposures) + sample_order = exposures_reordered.columns.to_numpy() + + # A next to D + position_a = np.where(sample_order == "a")[0][0] + position_d = np.where(sample_order == "d")[0][0] + assert np.abs(position_a - position_d) == 1 + + # B as far away from C as possible + position_b = np.where(sample_order == "b")[0][0] + position_c = np.where(sample_order == "c")[0][0] + assert np.abs(position_b - position_c) == 3 + + +def test_reorder_custom(exposures): + custom_sample_order = ["b", "a", "c", "d"] + exposures_reordered = plot._reorder_exposures( + exposures, sample_order=custom_sample_order + ) + assert np.array_equal(exposures_reordered.columns, custom_sample_order) From f5de504b41b301f9ea62a221184cd0741bd2c9d8 Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Sat, 21 Oct 2023 12:21:44 -0400 Subject: [PATCH 12/13] improve documentation --- src/salamander/nmf_framework/corrnmf.py | 4 +- .../nmf_framework/multimodal_corrnmf.py | 47 +++++++++++++++++-- src/salamander/nmf_framework/signature_nmf.py | 4 +- 3 files changed, 48 insertions(+), 7
deletions(-) diff --git a/src/salamander/nmf_framework/corrnmf.py b/src/salamander/nmf_framework/corrnmf.py index 3b7506e..16e542b 100644 --- a/src/salamander/nmf_framework/corrnmf.py +++ b/src/salamander/nmf_framework/corrnmf.py @@ -52,8 +52,8 @@ class CorrNMF(SignatureNMF): - fit: Run CorrNMF for a given mutation count data. Every - fit method should also implement a "refitting version", where the signatures - W are known in advance and fixed + fit method should also implement a version that allows fixing + arbitrarily many a priori known signatures. The following attributes are implemented in the abstract class CNMF: diff --git a/src/salamander/nmf_framework/multimodal_corrnmf.py b/src/salamander/nmf_framework/multimodal_corrnmf.py index 451e14e..49890ce 100755 --- a/src/salamander/nmf_framework/multimodal_corrnmf.py +++ b/src/salamander/nmf_framework/multimodal_corrnmf.py @@ -679,15 +679,56 @@ def plot_feature_change( self, in_modality=None, out_modalities="all", - normalize=True, colors=None, annotate_mutation_types=False, figsize=None, outfile=None, **kwargs, ): + """ + For the signatures of one modality, plot the co-occurring spectra + in other modalities. This is achieved by interpreting a signature + embedding as a sample embedding and using the resulting exposures to + compute distributions over mutation types in different modalities. + + Parameters + ---------- + in_modality : str + The modality name of the signatures of interest, e.g. "SBS". + + out_modalities : list or str, default="all" + A list of modalities to convert the 'in_modality' signatures + into, e.g. ["Indel", "SV"]. A single string can also be provided + to select only one 'out_modality'. By default, all modalities + other than the 'in_modality' are selected. + + colors : list, default=None + A list of length '1 + len(out_modalities)' of colors to use + for the signature plots of the input modality signatures + and the co-occurring spectra in the output modalities.
+ + annotate_mutation_types : bool, default=False + If True, the x-axis of the spectra plots will be annotated + with the mutation types of the respective modalities. + + figsize : tuple, default=None + The size of the matplotlib figure. If None, the figure size + is computed internally based on the number of input + signatures and output modalities. + + outfile : str, default=None + If not None, the figure will be saved to the provided path. + + kwargs : dict + Any keyword arguments to be passed to matplotlib's ax.bar. + + Returns + ------- + axes : np.ndarray + An array of matplotlib axes containing the plots. + """ # result[0] are the 'in_modality' signatures - results = self.feature_change(in_modality, out_modalities, normalize) + results = self.feature_change(in_modality, out_modalities) n_signatures = results[0].shape[1] n_feature_spaces = len(results) @@ -695,7 +736,7 @@ def plot_feature_change( colors = [None for _ in range(n_feature_spaces)] if figsize is None: - figsize = (8 * n_feature_spaces, 2 * n_signatures) + figsize = (4 * n_feature_spaces, n_signatures) fig, axes = plt.subplots(n_signatures, n_feature_spaces, figsize=figsize) fig.suptitle("Signature feature change") diff --git a/src/salamander/nmf_framework/signature_nmf.py b/src/salamander/nmf_framework/signature_nmf.py index 0ba9481..56f54f9 100644 --- a/src/salamander/nmf_framework/signature_nmf.py +++ b/src/salamander/nmf_framework/signature_nmf.py @@ -68,8 +68,8 @@ class SignatureNMF(ABC): - fit: Run the NMF algorithm for a given mutation count data. Every - fit method should also implement a "refitting version", where the signatures - W are known in advance and fixed. + fit method should also implement a version that allows fixing + arbitrarily many a priori known signatures. - plot_embeddings: Plot the sample (and potentially the signature) embeddings in 2D.
From 19af3e8cec7a05a94c8a4448edd2ecf41d0e329d Mon Sep 17 00:00:00 2001 From: BeGeiger Date: Sat, 21 Oct 2023 12:40:08 -0400 Subject: [PATCH 13/13] bump version --- CHANGELOG.md | 7 +++++-- pyproject.toml | 2 +- src/salamander/__init__.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6df702..3deeed6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,14 @@ All noteable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] - --- --- +## 0.2.0 - 2023-10 +### Added + - Support fixing arbitrarily many a priori known signatures during inference. + - Improved performance with just-in-time compiled update rules. + ## 0.1.0 - 2023-10 ### Added - First release of the non-negative matrix factorization (NMF) framework. Implemented algorithms: NMF with the generalized Kullback-Leibler divergence [(KL-NMF)](https://proceedings.neurips.cc/paper_files/paper/2000/file/f9d1152547c0bde01830b7e8bd60024c-Paper.pdf), minimum-volume NMF [(mvNMF)](https://arxiv.org/pdf/1907.02404.pdf), a version of correlated NMF [(CorrNMF)](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=87224164eef14589b137547a3fa81f06eef9bbf4), a multimodal version of correlated NMF [(MultimodalCorrNMF)](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=87224164eef14589b137547a3fa81f06eef9bbf4).
diff --git a/pyproject.toml b/pyproject.toml index 8c53da9..aa0ef2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "salamander-learn" -version = "0.1.0" +version = "0.2.0" description = "Salamander is a non-negative matrix factorization framework for signature analysis" license = "MIT" authors = ["Benedikt Geiger"] diff --git a/src/salamander/__init__.py b/src/salamander/__init__.py index 88173ef..ea8fcac 100644 --- a/src/salamander/__init__.py +++ b/src/salamander/__init__.py @@ -7,5 +7,5 @@ from .nmf_framework.multimodal_corrnmf import MultimodalCorrNMF from .nmf_framework.mvnmf import MvNMF -__version__ = "0.1.0" +__version__ = "0.2.0" __all__ = ["CorrNMFDet", "KLNMF", "MvNMF", "MultimodalCorrNMF"]