diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index c9135ce2..9bc374e9 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -1,6 +1,6 @@ import os -from dataclasses import asdict, dataclass, fields -from typing import List, Optional, Sequence +from dataclasses import asdict, dataclass, field, fields +from typing import List, Optional, Sequence, Union import numpy as np @@ -11,7 +11,7 @@ from scvelo.core import get_n_jobs, parallelize from scvelo.preprocessing.moments import get_connectivities from ._core import BaseInference -from ._em_model_core import _flatten, align_dynamics, DynamicsRecovery +from ._em_model_core import DynamicsRecovery from ._steady_state_model import SteadyStateModel from .utils import make_unique_list @@ -20,20 +20,61 @@ class EMParams: """EM parameters.""" - alpha: np.ndarray - beta: np.ndarray - gamma: np.ndarray - t_: np.ndarray - scaling: np.ndarray - std_u: np.ndarray - std_s: np.ndarray - likelihood: np.ndarray - u0: np.ndarray - s0: np.ndarray - pval_steady: np.ndarray - steady_u: np.ndarray - steady_s: np.ndarray - variance: np.ndarray + r2: np.ndarray = field(metadata={"is_matrix": False}) + alpha: np.ndarray = field(metadata={"is_matrix": False}) + beta: np.ndarray = field(metadata={"is_matrix": False}) + gamma: np.ndarray = field(metadata={"is_matrix": False}) + t_: np.ndarray = field(metadata={"is_matrix": False}) + scaling: np.ndarray = field(metadata={"is_matrix": False}) + std_u: np.ndarray = field(metadata={"is_matrix": False}) + std_s: np.ndarray = field(metadata={"is_matrix": False}) + likelihood: np.ndarray = field(metadata={"is_matrix": False}) + u0: np.ndarray = field(metadata={"is_matrix": False}) + s0: np.ndarray = field(metadata={"is_matrix": False}) + pval_steady: np.ndarray = field(metadata={"is_matrix": False}) + steady_u: np.ndarray = field(metadata={"is_matrix": False}) + steady_s: np.ndarray = field(metadata={"is_matrix": False}) + variance: np.ndarray = field(metadata={"is_matrix": False}) + alignment_scaling: np.ndarray = field(metadata={"is_matrix": False}) + T: np.ndarray = field(metadata={"is_matrix": True}) + Tau: np.ndarray = field(metadata={"is_matrix": True}) + Tau_: np.ndarray = field(metadata={"is_matrix": True}) + + @classmethod + def from_adata(cls, adata: AnnData, key: str = "fit"): + parameter_dict = {} + for parameter in fields(cls): + para_name = parameter.name + if parameter.metadata["is_matrix"]: + if f"{key}_{para_name.lower()}" in adata.layers.keys(): + parameter_dict[para_name] = adata.layers[ + f"{key}_{para_name.lower()}" + ] + else: + _vals = np.empty(adata.shape) + _vals.fill(np.nan) + parameter_dict[para_name] = _vals + else: + if f"{key}_{para_name.lower()}" in adata.var.keys(): + parameter_dict[para_name] = adata.var[ + f"{key}_{para_name.lower()}" + ].values + else: + _vals = np.empty(adata.n_vars) + _vals.fill(np.nan) + parameter_dict[para_name] = _vals + return cls(**parameter_dict) + + # TODO: Atm, fields are also written if they contain only NaN values. Is this useful? + def export_to_adata(self, adata: AnnData, key: str = "fit"): + for parameter in fields(self): + para_name = parameter.name + value = getattr(self, para_name) + if parameter.metadata["is_matrix"]: + adata.layers[f"{key}_{para_name.lower()}"] = value + else: + adata.var[f"{key}_{para_name.lower()}"] = value + return adata # TODO: Implement abstract methods @@ -119,22 +160,10 @@ def __init__( self._n_jobs = get_n_jobs(n_jobs=n_jobs) self._backend = backend - def _initialize_state_dict( - self, adata: AnnData, parameters: Optional[List[str]] = None, key: str = "fit" - ): - if parameters is None: - parameters = [field.name for field in fields(EMParams)] - parameter_dict = {} - - for parameter in parameters: - if f"{key}_{parameter}" in adata.var.keys(): - parameter_dict[parameter] = adata.var[f"{key}_{parameter}"].values - else: - _vals = np.empty(adata.n_vars) - _vals.fill(np.nan) - parameter_dict[parameter] = _vals - - self._state_dict = EMParams(**parameter_dict) + if len(set(self._adata.var_names)) != len(self._adata.var_names): + logg.warn("Duplicate var_names found. Making them unique.") + self._adata.var_names_make_unique() + self._state_dict = EMParams.from_adata(adata) def _prepare_genes(self): """Initialize genes to use for the fitting.""" @@ -153,7 +182,7 @@ def _prepare_genes(self): ) velo.fit() var_names = adata.var_names[velo.state_dict()["velocity_genes"]] - self.r2 = velo.state_dict()["r2"] + self._state_dict.r2 = velo.state_dict()["r2"] else: raise ValueError("Variable name not found in var keys.") if not isinstance(var_names, str): @@ -175,9 +204,13 @@ def state_dict(self): def export_results_adata(self, copy: bool = True, add_key: str = "fit"): """Export the results to the AnnData object and return it.""" adata = self._adata.copy() if copy else self._adata - adata.var[f"{add_key}_r2"] = self.r2 - for key, value in asdict(self._state_dict).items(): - adata.var[f"{add_key}_{key}"] = value + self._state_dict.export_to_adata(adata, add_key) + adata.uns["recover_dynamics"] = { + "fit_connected_states": self._fit_connected_states, + "fit_basal_transcription": self._fit_basal_transcription, + "use_raw": self._use_raw, + } + adata.varm["loss"] = self._loss return adata # TODO: Remove `use_raw` argument @@ -197,37 +230,23 @@ def fit( f"recovering dynamics (using {self._n_jobs}/{os.cpu_count()} cores)", r=True ) - # TODO: Remove or move to `__init__` - if len(set(self._adata.var_names)) != len(self._adata.var_names): - logg.warn("Duplicate var_names found. Making them unique.") - self._adata.var_names_make_unique() - if ( "Ms" not in self._adata.layers.keys() or "Mu" not in self._adata.layers.keys() ): use_raw = True + # TODO: Refactor the definition of 'use_raw'; Move to init? + self._use_raw = use_raw if self._fit_connected_states is None: - fit_connected_states = not use_raw + self._fit_connected_states = not use_raw self._prepare_genes() if return_model is None: return_model = len(self._var_names) < 5 - self._initialize_state_dict(self._adata) sd = self._state_dict idx, L = [], [] - T = np.zeros(self._adata.shape) * np.nan - Tau = np.zeros(self._adata.shape) * np.nan - Tau_ = np.zeros(self._adata.shape) * np.nan - if f"{self._fit_key}_t" in self._adata.layers.keys(): - T = self._adata.layers[f"{self._fit_key}_t"] - if f"{self._fit_key}_tau" in self._adata.layers.keys(): - Tau = self._adata.layers[f"{self._fit_key}_tau"] - if f"{self._fit_key}_tau_" in self._adata.layers.keys(): - Tau_ = self._adata.layers[f"{self._fit_key}_tau_"] - - conn = get_connectivities(self._adata) if fit_connected_states else None + conn = get_connectivities(self._adata) if self._fit_connected_states else None res = parallelize( self._fit, @@ -236,7 +255,7 @@ def fit( unit="gene", as_array=False, backend=self._backend, - show_progress_bar=len(self._var_names) > 9, + show_progress_bar=False, # len(self._var_names) > 9, )( use_raw=use_raw, load_pars=load_pars, @@ -248,7 +267,7 @@ def fit( idx, dms = map(_flatten, zip(*res)) for ix, dm in zip(idx, dms): - T[:, ix], Tau[:, ix], Tau_[:, ix] = dm.t, dm.tau, dm.tau_ + sd.T[:, ix], sd.Tau[:, ix], sd.Tau_[:, ix] = dm.t, dm.tau, dm.tau_ ( sd.alpha[ix], sd.beta[ix], @@ -265,43 +284,34 @@ def fit( sd.likelihood[ix], sd.variance[ix] = dm.likelihood, dm.varx L.append(dm.loss) - adata = self.export_results_adata(copy=copy, add_key=self._fit_key) - - adata.uns["recover_dynamics"] = { - "fit_connected_states": fit_connected_states, - "fit_basal_transcription": self._fit_basal_transcription, - "use_raw": use_raw, - } + adata = self._adata - if f"{self._fit_key}_t" in adata.layers.keys(): - adata.layers[f"{self._fit_key}_t"][:, idx] = ( - T[:, idx] if conn is None else conn.dot(T[:, idx]) - ) - else: - adata.layers[f"{self._fit_key}_t"] = T if conn is None else conn.dot(T) - adata.layers[f"{self._fit_key}_tau"] = Tau - adata.layers[f"{self._fit_key}_tau_"] = Tau_ + if conn is not None: + if f"{self._fit_key}_t" in adata.layers.keys(): + sd.T[:, idx] = conn.dot(sd.T[:, idx]) + else: + sd.T = conn.dot(sd.T) # is False if only one invalid / irrecoverable gene was given in var_names if L: cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2 max_len = max(np.max([len(loss) for loss in L]), cur_len) if L else cur_len - loss = np.ones((adata.n_vars, max_len)) * np.nan + self._loss = np.empty((adata.n_vars, max_len)) + self._loss.fill(np.nan) if "loss" in adata.varm.keys(): - loss[:, :cur_len] = adata.varm["loss"] + self._loss[:, :cur_len] = adata.varm["loss"] - loss[idx] = np.vstack( + self._loss[idx] = np.vstack( [ np.concatenate([loss, np.ones(max_len - len(loss)) * np.nan]) for loss in L ] ) - adata.varm["loss"] = loss # TODO: Fix s.t. `self._t_max` is only integer if self._t_max is not False: - dm = align_dynamics(adata, t_max=self._t_max, dm=dm, idx=idx) + dm = self._align_dynamics(t_max=self._t_max, dm=dm, idx=idx) logg.info( " finished", time=True, end=" " if settings.verbosity > 2 else "\n" @@ -317,6 +327,109 @@ def fit( return dm if return_model else adata if copy else None + def _align_dynamics( + self, + t_max: Optional[Union[float, bool]] = None, + dm: Optional[DynamicsRecovery] = None, + idx: Optional[List[bool]] = None, + mode: Optional[str] = None, + remove_outliers: bool = False, + ): + """Align dynamics to a common set of parameters. + + Arguments + --------- + t_max: `float`, `False` or `None` (default: `None`) + Total range for time assignments. + dm: :class:`~DynamicsRecovery` + DynamicsRecovery object to perform alignment on. + idx: list of `bool` or `None` (default: `None`) + Mask for indices used for alignment. + mode: `str` or None (default: `'align_total_time`) + What to align. Takes the following arguments: + common_splicing_rate, common_scaling, align_increments, align_total_time + remove_outliers: `bool` or `None` (default: `None`) + Whether to remove outliers. + copy: `bool` (default: `False`) + Return a copy instead of writing to `adata`. + + Returns + ------- + `alpha`, `beta`, `gamma`, `t_`, `alignment_scaling`: `.var` + Aligned parameters + `fit_t`, `fit_tau`, `fit_tau_`: `.layer` + Aligned time + """ + sd = self._state_dict + if idx is None: + idx = ~np.isnan(np.sum(sd.T, axis=0)) + if np.all(np.isnan(sd.alignment_scaling)): + sd.alignment_scaling.fill(1) + if mode is None: + mode = "align_total_time" + + m = np.ones(self._adata.n_vars) + mz = sd.alignment_scaling + mz_prev = np.array(mz) + + if dm is not None: # newly fitted + mz[idx] = 1 + + if mode == "align_total_time" and t_max is not False: + # transient 'on' + T_max = np.max(sd.T[:, idx] * (sd.T[:, idx] < sd.t_[idx]), axis=0) + # 'off' + T_max += np.max( + (sd.T[:, idx] - sd.t_[idx]) * (sd.T[:, idx] > sd.t_[idx]), axis=0 + ) + + denom = 1 - np.sum( + (sd.T[:, idx] == sd.t_[idx]) | (sd.T[:, idx] == 0), axis=0 + ) / len(sd.T) + denom += denom == 0 + + T_max = T_max / denom + T_max += T_max == 0 + + if t_max is None: + t_max = 20 + m[idx] = t_max / T_max + mz *= m + else: + m = 1 / mz + mz = np.ones(self.adata.n_vars) + + if remove_outliers: + mu, std = np.nanmean(mz), np.nanstd(mz) + mz = np.clip(mz, mu - 3 * std, mu + 3 * std) + m = mz / mz_prev + + if idx is None: + sd.alpha, sd.beta, sd.gamma = sd.alpha / m, sd.beta / m, sd.gamma / m + sd.T, sd.t_, sd.Tau, sd.Tau_ = sd.T * m, sd.t_ * m, sd.Tau * m, sd.Tau_ * m + else: + m_ = m[idx] + sd.alpha[idx] = sd.alpha[idx] / m_ + sd.beta[idx] = sd.beta[idx] / m_ + sd.gamma[idx] = sd.gamma[idx] / m_ + sd.T[:, idx], sd.t_[idx] = sd.T[:, idx] * m_, sd.t_[idx] * m_ + sd.Tau[:, idx], sd.Tau_[:, idx] = sd.Tau[:, idx] * m_, sd.Tau_[:, idx] * m_ + + mz[mz == 1] = np.nan + + if dm is not None and dm.recoverable: + dm.m = m[idx] + dm.alpha = dm.alpha / dm.m[-1] + dm.beta = dm.beta / dm.m[-1] + dm.gamma = dm.gamma / dm.m[-1] + dm.pars[:3] = dm.pars[:3] / dm.m[-1] + + dm.t = dm.t * dm.m[-1] + dm.tau = dm.tau * dm.m[-1] + dm.t_ = dm.t_ * dm.m[-1] + dm.pars[4] = dm.pars[4] * dm.m[-1] + return dm + # TODO: Add docstrings def _fit( self, @@ -362,3 +475,7 @@ def _fit( queue.put(None) return idx, dms + + +def _flatten(iterable): + return [i for it in iterable for i in it]