diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py
index c96af7f7..0ef5fb39 100644
--- a/scvelo/tools/_em_model.py
+++ b/scvelo/tools/_em_model.py
@@ -1,6 +1,6 @@
 import os
-from dataclasses import dataclass
-from typing import List, Optional, Sequence
+from dataclasses import asdict, dataclass, field, fields
+from typing import List, Optional, Sequence, Union
 
 import numpy as np
 
@@ -11,13 +11,7 @@ from scvelo.core import get_n_jobs, parallelize
 from scvelo.preprocessing.moments import get_connectivities
 
 from ._core import BaseInference
-from ._em_model_core import (
-    _flatten,
-    _read_pars,
-    _write_pars,
-    align_dynamics,
-    DynamicsRecovery,
-)
+from ._em_model_core import DynamicsRecovery
 from ._steady_state_model import SteadyStateModel
 from .utils import make_unique_list
 
@@ -26,23 +20,64 @@ class EMParams:
     """EM parameters."""
 
-    alpha: np.ndarray
-    beta: np.ndarray
-    gamma: np.ndarray
-    t_: np.ndarray
-    scaling: np.ndarray
-    std_u: np.ndarray
-    std_s: np.ndarray
-    likelihood: np.ndarray
-    u0: np.ndarray
-    s0: np.ndarray
-    pval_steady: np.ndarray
-    steady_u: np.ndarray
-    steady_s: np.ndarray
-    variance: np.ndarray
-
-
-# TODO: Refactor to use `EMParams`
+    r2: np.ndarray = field(metadata={"is_matrix": False})
+    alpha: np.ndarray = field(metadata={"is_matrix": False})
+    beta: np.ndarray = field(metadata={"is_matrix": False})
+    gamma: np.ndarray = field(metadata={"is_matrix": False})
+    t_: np.ndarray = field(metadata={"is_matrix": False})
+    scaling: np.ndarray = field(metadata={"is_matrix": False})
+    std_u: np.ndarray = field(metadata={"is_matrix": False})
+    std_s: np.ndarray = field(metadata={"is_matrix": False})
+    likelihood: np.ndarray = field(metadata={"is_matrix": False})
+    u0: np.ndarray = field(metadata={"is_matrix": False})
+    s0: np.ndarray = field(metadata={"is_matrix": False})
+    pval_steady: np.ndarray = field(metadata={"is_matrix": False})
+    steady_u: np.ndarray = field(metadata={"is_matrix": False})
+    steady_s: np.ndarray = field(metadata={"is_matrix": False})
+    variance: np.ndarray = field(metadata={"is_matrix": False})
+    alignment_scaling: np.ndarray = field(metadata={"is_matrix": False})
+    T: np.ndarray = field(metadata={"is_matrix": True})
+    Tau: np.ndarray = field(metadata={"is_matrix": True})
+    Tau_: np.ndarray = field(metadata={"is_matrix": True})
+
+    @classmethod
+    def from_adata(cls, adata: AnnData, key: str = "fit"):
+        parameter_dict = {}
+        for parameter in fields(cls):
+            para_name = parameter.name
+            if parameter.metadata["is_matrix"]:
+                if f"{key}_{para_name.lower()}" in adata.layers.keys():
+                    parameter_dict[para_name] = adata.layers[
+                        f"{key}_{para_name.lower()}"
+                    ]
+                else:
+                    _vals = np.empty(adata.shape)
+                    _vals.fill(np.nan)
+                    parameter_dict[para_name] = _vals
+            else:
+                if f"{key}_{para_name.lower()}" in adata.var.keys():
+                    parameter_dict[para_name] = adata.var[
+                        f"{key}_{para_name.lower()}"
+                    ].values
+                else:
+                    _vals = np.empty(adata.n_vars)
+                    _vals.fill(np.nan)
+                    parameter_dict[para_name] = _vals
+        return cls(**parameter_dict)
+
+    def export_to_adata(self, adata: AnnData, key: str = "fit"):
+        for parameter in fields(self):
+            para_name = parameter.name
+            value = getattr(self, para_name)
+            # The parameter is only written if not all entries are nan.
+            if not np.all(np.isnan(value)):
+                if parameter.metadata["is_matrix"]:
+                    adata.layers[f"{key}_{para_name.lower()}"] = value
+                else:
+                    adata.var[f"{key}_{para_name.lower()}"] = value
+        return adata
+
+
 # TODO: Implement abstract methods
 class ExpectationMaximizationModel(BaseInference):
     """EM 'Dynamical' model for velocity estimation.
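# --- Example (not part of the patch): the round trip `EMParams` enables.
# `from_adata` falls back to all-NaN defaults for missing keys, and
# `export_to_adata` skips any field that is still all-NaN. The toy AnnData
# and parameter values below are hypothetical.
import numpy as np
from anndata import AnnData

from scvelo.tools._em_model import EMParams

toy = AnnData(np.ones((5, 3)))
params = EMParams.from_adata(toy, key="fit")  # every field starts as NaN
params.alpha = np.array([0.5, 1.0, 2.0])      # per-gene vector -> toy.var
params.T = np.zeros((5, 3))                   # cell-by-gene matrix -> toy.layers
params.export_to_adata(toy, key="fit")
assert "fit_alpha" in toy.var and "fit_t" in toy.layers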
@@ -126,6 +161,12 @@ def __init__(
         self._n_jobs = get_n_jobs(n_jobs=n_jobs)
         self._backend = backend
 
+        if len(set(self._adata.var_names)) != len(self._adata.var_names):
+            logg.warn("Duplicate var_names found. Making them unique.")
+            self._adata.var_names_make_unique()
+        self._state_dict = EMParams.from_adata(adata)
+        self._use_raw = False
+
     def _prepare_genes(self):
         """Initialize genes to use for the fitting."""
         var_names = self._var_names_key
@@ -143,7 +184,7 @@ def _prepare_genes(self):
             )
             velo.fit()
             var_names = adata.var_names[velo.state_dict()["velocity_genes"]]
-            self.r2 = velo.state_dict()["r2"]
+            self._state_dict.r2 = velo.state_dict()["r2"]
         else:
             raise ValueError("Variable name not found in var keys.")
         if not isinstance(var_names, str):
@@ -158,15 +199,23 @@ def _prepare_genes(self):
             var_names = var_names[np.argsort(np.sum(X, 0))[::-1][: self._n_top_genes]]
         self._var_names = var_names
 
-    # TODO: Implement
     def state_dict(self):
         """Return the state of the model."""
-        raise NotImplementedError
+        return asdict(self._state_dict)
 
-    # TODO: Implement
-    def export_results_adata(self):
-        """Export the results to the AnnData object."""
-        raise NotImplementedError
+    def export_results_adata(self, copy: bool = True, add_key: str = "fit"):
+        """Export the results to the AnnData object and return it."""
+        adata = self._adata.copy() if copy else self._adata
+        self._state_dict.export_to_adata(adata, add_key)
+        adata.uns["recover_dynamics"] = {
+            "fit_connected_states": self._fit_connected_states,
+            "fit_basal_transcription": self._fit_basal_transcription,
+            "use_raw": self._use_raw,
+        }
+        # loss is only written after the execution of fit()
+        if hasattr(self, "_loss"):
+            adata.varm["loss"] = self._loss
+        return adata
 
     # TODO: Remove `use_raw` argument
     # TODO: Remove `return_model` argument
@@ -178,6 +227,7 @@ def fit(
         load_pars: bool = False,
         steady_state_prior: Optional[List[bool]] = None,
         assignment_mode: str = "projection",
+        show_progress_bar: bool = True,
         **kwargs,
     ):
         """Fit the model."""
@@ -185,39 +235,23 @@
             f"recovering dynamics (using {self._n_jobs}/{os.cpu_count()} cores)", r=True
         )
-        # TODO: Remove or move to `__init__`
-        if len(set(self._adata.var_names)) != len(self._adata.var_names):
-            logg.warn("Duplicate var_names found. Making them unique.")
-            self._adata.var_names_make_unique()
-
         if (
             "Ms" not in self._adata.layers.keys()
             or "Mu" not in self._adata.layers.keys()
         ):
             use_raw = True
+        # TODO: Refactor the definition of 'use_raw'; Move to init?
+        self._use_raw = use_raw
 
         if self._fit_connected_states is None:
-            fit_connected_states = not use_raw
+            self._fit_connected_states = not use_raw
 
         self._prepare_genes()
        if return_model is None:
             return_model = len(self._var_names) < 5
 
-        pars = _read_pars(self._adata)
-        alpha, beta, gamma, t_, scaling, std_u, std_s, likelihood = pars[:8]
-        u0, s0, pval, steady_u, steady_s, varx = pars[8:]
-        # likelihood[np.isnan(likelihood)] = 0
+        sd = self._state_dict
 
         idx, L = [], []
-        T = np.zeros(self._adata.shape) * np.nan
-        Tau = np.zeros(self._adata.shape) * np.nan
-        Tau_ = np.zeros(self._adata.shape) * np.nan
-        if f"{self._fit_key}_t" in self._adata.layers.keys():
-            T = self._adata.layers[f"{self._fit_key}_t"]
-        if f"{self._fit_key}_tau" in self._adata.layers.keys():
-            Tau = self._adata.layers[f"{self._fit_key}_tau"]
-        if f"{self._fit_key}_tau_" in self._adata.layers.keys():
-            Tau_ = self._adata.layers[f"{self._fit_key}_tau_"]
-
-        conn = get_connectivities(self._adata) if fit_connected_states else None
+        conn = get_connectivities(self._adata) if self._fit_connected_states else None
         res = parallelize(
             self._fit,
@@ -226,7 +260,7 @@
             unit="gene",
             as_array=False,
             backend=self._backend,
-            show_progress_bar=len(self._var_names) > 9,
+            show_progress_bar=show_progress_bar,
         )(
             use_raw=use_raw,
             load_pars=load_pars,
@@ -238,72 +272,51 @@
         idx, dms = map(_flatten, zip(*res))
         for ix, dm in zip(idx, dms):
-            T[:, ix], Tau[:, ix], Tau_[:, ix] = dm.t, dm.tau, dm.tau_
-            alpha[ix], beta[ix], gamma[ix], t_[ix], scaling[ix] = dm.pars[:, -1]
-            u0[ix], s0[ix], pval[ix] = dm.u0, dm.s0, dm.pval_steady
-            steady_u[ix], steady_s[ix] = dm.steady_u, dm.steady_s
-            beta[ix] /= scaling[ix]
-            steady_u[ix] *= scaling[ix]
-
-            std_u[ix], std_s[ix] = dm.std_u, dm.std_s
-            likelihood[ix], varx[ix] = dm.likelihood, dm.varx
+            sd.T[:, ix], sd.Tau[:, ix], sd.Tau_[:, ix] = dm.t, dm.tau, dm.tau_
+            (
+                sd.alpha[ix],
+                sd.beta[ix],
+                sd.gamma[ix],
+                sd.t_[ix],
+                sd.scaling[ix],
+            ) = dm.pars[:, -1]
+            sd.u0[ix], sd.s0[ix], sd.pval_steady[ix] = dm.u0, dm.s0, dm.pval_steady
+            sd.steady_u[ix], sd.steady_s[ix] = dm.steady_u, dm.steady_s
+            sd.beta[ix] /= sd.scaling[ix]
+            sd.steady_u[ix] *= sd.scaling[ix]
+
+            sd.std_u[ix], sd.std_s[ix] = dm.std_u, dm.std_s
+            sd.likelihood[ix], sd.variance[ix] = dm.likelihood, dm.varx
             L.append(dm.loss)
 
-        _pars = [
-            alpha,
-            beta,
-            gamma,
-            t_,
-            scaling,
-            std_u,
-            std_s,
-            likelihood,
-            u0,
-            s0,
-            pval,
-            steady_u,
-            steady_s,
-            varx,
-        ]
-
-        adata = self._adata.copy() if copy else self._adata
-
-        adata.uns["recover_dynamics"] = {
-            "fit_connected_states": fit_connected_states,
-            "fit_basal_transcription": self._fit_basal_transcription,
-            "use_raw": use_raw,
-        }
+        adata = self._adata
 
-        _write_pars(adata, _pars, add_key=self._fit_key)
-        if f"{self._fit_key}_t" in adata.layers.keys():
-            adata.layers[f"{self._fit_key}_t"][:, idx] = (
-                T[:, idx] if conn is None else conn.dot(T[:, idx])
-            )
-        else:
-            adata.layers[f"{self._fit_key}_t"] = T if conn is None else conn.dot(T)
-        adata.layers[f"{self._fit_key}_tau"] = Tau
-        adata.layers[f"{self._fit_key}_tau_"] = Tau_
+        if conn is not None:
+            if f"{self._fit_key}_t" in adata.layers.keys():
+                sd.T[:, idx] = conn.dot(sd.T[:, idx])
+            else:
+                sd.T = conn.dot(sd.T)
 
         # is False if only one invalid / irrecoverable gene was given in var_names
         if L:
             cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2
             max_len = max(np.max([len(loss) for loss in L]), cur_len) if L else cur_len
-            loss = np.ones((adata.n_vars, max_len)) * np.nan
+            self._loss = np.empty((adata.n_vars, max_len))
+            self._loss.fill(np.nan)
 
             if "loss" in adata.varm.keys():
-                loss[:, :cur_len] = adata.varm["loss"]
+                self._loss[:, :cur_len] = adata.varm["loss"]
 
-            loss[idx] = np.vstack(
+            self._loss[idx] = np.vstack(
                 [
                     np.concatenate([loss, np.ones(max_len - len(loss)) * np.nan])
                     for loss in L
                 ]
             )
-            adata.varm["loss"] = loss
 
         # TODO: Fix s.t. `self._t_max` is only integer
         if self._t_max is not False:
-            dm = align_dynamics(adata, t_max=self._t_max, dm=dm, idx=idx)
+            dm = self._align_dynamics(t_max=self._t_max, dm=dm, idx=idx)
 
         logg.info(
             "    finished", time=True, end=" " if settings.verbosity > 2 else "\n"
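# --- Example (not part of the patch): what `conn.dot(sd.T)` does above. The
# matrix returned by `get_connectivities` acts as a (roughly row-stochastic)
# cell-by-cell weight matrix, so the product replaces each cell's fitted time
# with a weighted average over its graph neighbors. The toy weights and
# values below are hypothetical.
import numpy as np
from scipy.sparse import csr_matrix

T = np.array([[0.0], [1.0], [10.0]])  # latent time of 3 cells for 1 gene
conn = csr_matrix(np.array([
    [0.5, 0.5, 0.0],  # cells 0 and 1 average one another
    [0.5, 0.5, 0.0],
    [0.0, 0.0, 1.0],  # cell 2 only sees itself
]))
print(conn.dot(T))  # -> [[0.5], [0.5], [10.0]]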
@@ -319,6 +332,109 @@
 
         return dm if return_model else adata if copy else None
 
+    def _align_dynamics(
+        self,
+        t_max: Optional[Union[float, bool]] = None,
+        dm: Optional[DynamicsRecovery] = None,
+        idx: Optional[List[bool]] = None,
+        mode: Optional[str] = None,
+        remove_outliers: bool = False,
+    ):
+        """Align dynamics to a common set of parameters.
+
+        Arguments
+        ---------
+        t_max: `float`, `False` or `None` (default: `None`)
+            Total range for time assignments.
+        dm: :class:`~DynamicsRecovery` or `None` (default: `None`)
+            DynamicsRecovery object to perform alignment on.
+        idx: list of `bool` or `None` (default: `None`)
+            Mask of genes to align; defaults to all genes with recovered times.
+        mode: `str` or `None` (default: `'align_total_time'`)
+            What to align. Takes the following arguments:
+            common_splicing_rate, common_scaling, align_increments,
+            align_total_time.
+        remove_outliers: `bool` (default: `False`)
+            Whether to clip outlier alignment scalings.
+
+        Returns
+        -------
+        dm: :class:`~DynamicsRecovery`
+            The updated DynamicsRecovery object. Aligned parameters (`alpha`,
+            `beta`, `gamma`, `t_`, `alignment_scaling`) and times (`T`, `Tau`,
+            `Tau_`) are written to the model's state dict.
+        """
+        sd = self._state_dict
+        if idx is None:
+            idx = ~np.isnan(np.sum(sd.T, axis=0))
+        if np.all(np.isnan(sd.alignment_scaling)):
+            sd.alignment_scaling.fill(1)
+        if mode is None:
+            mode = "align_total_time"
+
+        m = np.ones(self._adata.n_vars)
+        mz = sd.alignment_scaling
+        mz_prev = np.array(mz)
+
+        if dm is not None:  # newly fitted
+            mz[idx] = 1
+
+        if mode == "align_total_time" and t_max is not False:
+            # transient 'on'
+            T_max = np.max(sd.T[:, idx] * (sd.T[:, idx] < sd.t_[idx]), axis=0)
+            # 'off'
+            T_max += np.max(
+                (sd.T[:, idx] - sd.t_[idx]) * (sd.T[:, idx] > sd.t_[idx]), axis=0
+            )
+
+            denom = 1 - np.sum(
+                (sd.T[:, idx] == sd.t_[idx]) | (sd.T[:, idx] == 0), axis=0
+            ) / len(sd.T)
+            denom += denom == 0
+
+            T_max = T_max / denom
+            T_max += T_max == 0
+
+            if t_max is None:
+                t_max = 20
+            m[idx] = t_max / T_max
+            mz *= m
+        else:
+            m = 1 / mz
+            mz = np.ones(self._adata.n_vars)
+
+        if remove_outliers:
+            mu, std = np.nanmean(mz), np.nanstd(mz)
+            mz = np.clip(mz, mu - 3 * std, mu + 3 * std)
+            m = mz / mz_prev
+
+        if idx is None:
+            sd.alpha, sd.beta, sd.gamma = sd.alpha / m, sd.beta / m, sd.gamma / m
+            sd.T, sd.t_, sd.Tau, sd.Tau_ = sd.T * m, sd.t_ * m, sd.Tau * m, sd.Tau_ * m
+        else:
+            m_ = m[idx]
+            sd.alpha[idx] = sd.alpha[idx] / m_
+            sd.beta[idx] = sd.beta[idx] / m_
+            sd.gamma[idx] = sd.gamma[idx] / m_
+            sd.T[:, idx], sd.t_[idx] = sd.T[:, idx] * m_, sd.t_[idx] * m_
+            sd.Tau[:, idx], sd.Tau_[:, idx] = sd.Tau[:, idx] * m_, sd.Tau_[:, idx] * m_
+
+        mz[mz == 1] = np.nan
+
+        if dm is not None and dm.recoverable:
+            dm.m = m[idx]
+            dm.alpha = dm.alpha / dm.m[-1]
+            dm.beta = dm.beta / dm.m[-1]
+            dm.gamma = dm.gamma / dm.m[-1]
+            dm.pars[:3] = dm.pars[:3] / dm.m[-1]
+
+            dm.t = dm.t * dm.m[-1]
+            dm.tau = dm.tau * dm.m[-1]
+            dm.t_ = dm.t_ * dm.m[-1]
+            dm.pars[4] = dm.pars[4] * dm.m[-1]
+        return dm
+
     # TODO: Add docstrings
     def _fit(
         self,
@@ -364,3 +480,7 @@
         queue.put(None)
 
     return idx, dms
+
+
+def _flatten(iterable):
+    return [i for it in iterable for i in it]
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 234a83b4..169aac68 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -32,6 +32,7 @@ def test_dynamical_model():
         adata=adata, var_names_key=adata.var_names[0]
     )
     em_model.fit(return_model=False, copy=False)
+    adata = em_model.export_results_adata(copy=False)
 
     assert np.round(adata[:, adata.var_names[0]].var["fit_alpha"][0], 4) == 4.7409
 
@@ -45,6 +46,7 @@ def test_pipeline():
 
     em_model = ExpectationMaximizationModel(adata=adata)
     em_model.fit(copy=False)
+    adata = em_model.export_results_adata(copy=False)
     scv.tl.velocity(adata)
     scv.tl.velocity(adata, vkey="dynamical_velocity", mode="dynamical")
     adata.var.velocity_genes = True
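# --- Example (not part of the patch): end-to-end use of the refactored API,
# mirroring the updated tests. Any preprocessed AnnData with "Ms"/"Mu" moment
# layers works; the pancreas dataset and preprocessing parameters below are
# illustrative only.
import scvelo as scv
from scvelo.tools._em_model import ExpectationMaximizationModel

adata = scv.datasets.pancreas()
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=1000)
scv.pp.moments(adata)

em_model = ExpectationMaximizationModel(adata=adata)
em_model.fit(copy=False)                           # fills the model's state dict
adata = em_model.export_results_adata(copy=False)  # writes `fit_*` results to adata
print(adata.var["fit_likelihood"].head())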