From 5019ac4292c78e7cfcfefcde34258753fc0150bc Mon Sep 17 00:00:00 2001 From: Johanna Schneeberger Date: Wed, 3 May 2023 17:54:30 +0200 Subject: [PATCH 1/9] Add usage of `EMParams` Class `EMParams` is now used by the `ExpectationMaximizationModel`. --- scvelo/tools/_em_model.py | 54 ++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index c96af7f7..31ad3182 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -1,5 +1,5 @@ import os -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import List, Optional, Sequence import numpy as np @@ -42,7 +42,6 @@ class EMParams: variance: np.ndarray -# TODO: Refactor to use `EMParams` # TODO: Implement abstract methods class ExpectationMaximizationModel(BaseInference): """EM 'Dynamical' model for velocity estimation. @@ -126,6 +125,10 @@ def __init__( self._n_jobs = get_n_jobs(n_jobs=n_jobs) self._backend = backend + def _initialize_state_dict(self, adata, pars_names=None, key="fit"): + pars = _read_pars(adata, pars_names, key) + self._state_dict = EMParams(*pars) + def _prepare_genes(self): """Initialize genes to use for the fitting.""" var_names = self._var_names_key @@ -202,10 +205,8 @@ def fit( if return_model is None: return_model = len(self._var_names) < 5 - pars = _read_pars(self._adata) - alpha, beta, gamma, t_, scaling, std_u, std_s, likelihood = pars[:8] - u0, s0, pval, steady_u, steady_s, varx = pars[8:] - # likelihood[np.isnan(likelihood)] = 0 + self._initialize_state_dict(self._adata) + sd = self._state_dict idx, L = [], [] T = np.zeros(self._adata.shape) * np.nan Tau = np.zeros(self._adata.shape) * np.nan @@ -239,33 +240,22 @@ def fit( for ix, dm in zip(idx, dms): T[:, ix], Tau[:, ix], Tau_[:, ix] = dm.t, dm.tau, dm.tau_ - alpha[ix], beta[ix], gamma[ix], t_[ix], scaling[ix] = dm.pars[:, -1] - u0[ix], s0[ix], pval[ix] = dm.u0, dm.s0, dm.pval_steady - steady_u[ix], steady_s[ix] = dm.steady_u, dm.steady_s - beta[ix] /= scaling[ix] - steady_u[ix] *= scaling[ix] - - std_u[ix], std_s[ix] = dm.std_u, dm.std_s - likelihood[ix], varx[ix] = dm.likelihood, dm.varx + ( + sd.alpha[ix], + sd.beta[ix], + sd.gamma[ix], + sd.t_[ix], + sd.scaling[ix], + ) = dm.pars[:, -1] + sd.u0[ix], sd.s0[ix], sd.pval_steady[ix] = dm.u0, dm.s0, dm.pval_steady + sd.steady_u[ix], sd.steady_s[ix] = dm.steady_u, dm.steady_s + sd.beta[ix] /= sd.scaling[ix] + sd.steady_u[ix] *= sd.scaling[ix] + + sd.std_u[ix], sd.std_s[ix] = dm.std_u, dm.std_s + sd.likelihood[ix], sd.variance[ix] = dm.likelihood, dm.varx L.append(dm.loss) - _pars = [ - alpha, - beta, - gamma, - t_, - scaling, - std_u, - std_s, - likelihood, - u0, - s0, - pval, - steady_u, - steady_s, - varx, - ] - adata = self._adata.copy() if copy else self._adata adata.uns["recover_dynamics"] = { @@ -274,7 +264,7 @@ def fit( "use_raw": use_raw, } - _write_pars(adata, _pars, add_key=self._fit_key) + _write_pars(adata, list(asdict(sd).values()), add_key=self._fit_key) if f"{self._fit_key}_t" in adata.layers.keys(): adata.layers[f"{self._fit_key}_t"][:, idx] = ( T[:, idx] if conn is None else conn.dot(T[:, idx]) From 2bd5f427f770360b0e301ff43ae00c5f2571500a Mon Sep 17 00:00:00 2001 From: Johanna Schneeberger Date: Mon, 15 May 2023 12:35:14 +0200 Subject: [PATCH 2/9] Remove use of `_read_pars()` from `_em_model.py` Since _read_pars() should be deleted in the future, _initialize_state_dict() does not rely on it anymore. --- scvelo/tools/_em_model.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index 31ad3182..c3414c64 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -1,5 +1,5 @@ import os -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, fields from typing import List, Optional, Sequence import numpy as np @@ -11,13 +11,7 @@ from scvelo.core import get_n_jobs, parallelize from scvelo.preprocessing.moments import get_connectivities from ._core import BaseInference -from ._em_model_core import ( - _flatten, - _read_pars, - _write_pars, - align_dynamics, - DynamicsRecovery, -) +from ._em_model_core import _flatten, _write_pars, align_dynamics, DynamicsRecovery from ._steady_state_model import SteadyStateModel from .utils import make_unique_list @@ -126,7 +120,14 @@ def __init__( self._backend = backend def _initialize_state_dict(self, adata, pars_names=None, key="fit"): - pars = _read_pars(adata, pars_names, key) + pars = [] + emparams_fields = [field.name for field in fields(EMParams)] + for name in emparams_fields if pars_names is None else pars_names: + pkey = f"{key}_{name}" + par = np.zeros(adata.n_vars) * np.nan + if pkey in adata.var.keys(): + par = adata.var[pkey].values + pars.append(par) self._state_dict = EMParams(*pars) def _prepare_genes(self): From 36de25384b720b245a814cdd1858f1b5240d054e Mon Sep 17 00:00:00 2001 From: Johanna Schneeberger Date: Mon, 15 May 2023 22:03:43 +0200 Subject: [PATCH 3/9] Implement abstact methods Add implementation of state_dict() and export_results_adata() and remove use of _write_pars() in ExpectationMaximizationModel.fit(). --- scvelo/tools/_em_model.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index c3414c64..123c9472 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -11,7 +11,7 @@ from scvelo.core import get_n_jobs, parallelize from scvelo.preprocessing.moments import get_connectivities from ._core import BaseInference -from ._em_model_core import _flatten, _write_pars, align_dynamics, DynamicsRecovery +from ._em_model_core import _flatten, align_dynamics, DynamicsRecovery from ._steady_state_model import SteadyStateModel from .utils import make_unique_list @@ -162,15 +162,17 @@ def _prepare_genes(self): var_names = var_names[np.argsort(np.sum(X, 0))[::-1][: self._n_top_genes]] self._var_names = var_names - # TODO: Implement def state_dict(self): """Return the state of the model.""" - raise NotImplementedError + return asdict(self._state_dict) - # TODO: Implement - def export_results_adata(self): - """Export the results to the AnnData object.""" - raise NotImplementedError + def export_results_adata(self, copy: bool = True, add_key: str = "fit"): + """Export the results to the AnnData object and return it.""" + adata = self._adata.copy() if copy else self._adata + adata.var[f"{add_key}_r2"] = self.r2 + for key, value in asdict(self._state_dict).items(): + adata.var[f"{add_key}_{key}"] = value + return adata # TODO: Remove `use_raw` argument # TODO: Remove `return_model` argument @@ -257,7 +259,7 @@ def fit( sd.likelihood[ix], sd.variance[ix] = dm.likelihood, dm.varx L.append(dm.loss) - adata = self._adata.copy() if copy else self._adata + adata = self.export_results_adata(copy=copy, add_key=self._fit_key) adata.uns["recover_dynamics"] = { "fit_connected_states": fit_connected_states, @@ -265,7 +267,6 @@ def fit( "use_raw": use_raw, } - _write_pars(adata, list(asdict(sd).values()), add_key=self._fit_key) if f"{self._fit_key}_t" in adata.layers.keys(): adata.layers[f"{self._fit_key}_t"][:, idx] = ( T[:, idx] if conn is None else conn.dot(T[:, idx]) From 3ac2c3b169ebc10ba152906df67abf9d5933393a Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Mon, 22 May 2023 21:07:34 -0400 Subject: [PATCH 4/9] Add type hints to `_initialize_state_dict` --- scvelo/tools/_em_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index 123c9472..86ee9e75 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -119,7 +119,9 @@ def __init__( self._n_jobs = get_n_jobs(n_jobs=n_jobs) self._backend = backend - def _initialize_state_dict(self, adata, pars_names=None, key="fit"): + def _initialize_state_dict( + self, adata: AnnData, pars_names: Optional[List[str]] = None, key: str = "fit" + ): pars = [] emparams_fields = [field.name for field in fields(EMParams)] for name in emparams_fields if pars_names is None else pars_names: From b4c27d3c241c97ac064e14a8fba2792b7dac2a77 Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Mon, 22 May 2023 21:19:02 -0400 Subject: [PATCH 5/9] Refactor `_initialize_state_dict` * Rename variables (`pars_names` to `parameters`, `pars` to `parameter_dict`). * Restructure code to reduce number of variables. * Update definition of parameters initialized with as nan. --- scvelo/tools/_em_model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index 86ee9e75..c9135ce2 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -120,17 +120,21 @@ def __init__( self._backend = backend def _initialize_state_dict( - self, adata: AnnData, pars_names: Optional[List[str]] = None, key: str = "fit" + self, adata: AnnData, parameters: Optional[List[str]] = None, key: str = "fit" ): - pars = [] - emparams_fields = [field.name for field in fields(EMParams)] - for name in emparams_fields if pars_names is None else pars_names: - pkey = f"{key}_{name}" - par = np.zeros(adata.n_vars) * np.nan - if pkey in adata.var.keys(): - par = adata.var[pkey].values - pars.append(par) - self._state_dict = EMParams(*pars) + if parameters is None: + parameters = [field.name for field in fields(EMParams)] + parameter_dict = {} + + for parameter in parameters: + if f"{key}_{parameter}" in adata.var.keys(): + parameter_dict[parameter] = adata.var[f"{key}_{parameter}"].values + else: + _vals = np.empty(adata.n_vars) + _vals.fill(np.nan) + parameter_dict[parameter] = _vals + + self._state_dict = EMParams(**parameter_dict) def _prepare_genes(self): """Initialize genes to use for the fitting.""" From 15808e0a08a51781d51888350db9b86212561d93 Mon Sep 17 00:00:00 2001 From: Johanna Schneeberger Date: Tue, 22 Aug 2023 18:12:01 +0200 Subject: [PATCH 6/9] Refine `ExpectationMaximizationModel` * Add metadata the attributes to `EMParams` * Add methods read/write `EMParams` from/to AnnData * Adapt the method `fit()` such that it uses the class `EMParams` * Include `_align_dynamics()` as a method in `ExpectationMaximizationModel` and adapt it to also use `EMParams` * Include `_flatten()` from `_em_model_core.py` --- scvelo/tools/_em_model.py | 271 +++++++++++++++++++++++++++----------- 1 file changed, 194 insertions(+), 77 deletions(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index c9135ce2..9bc374e9 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -1,6 +1,6 @@ import os -from dataclasses import asdict, dataclass, fields -from typing import List, Optional, Sequence +from dataclasses import asdict, dataclass, field, fields +from typing import List, Optional, Sequence, Union import numpy as np @@ -11,7 +11,7 @@ from scvelo.core import get_n_jobs, parallelize from scvelo.preprocessing.moments import get_connectivities from ._core import BaseInference -from ._em_model_core import _flatten, align_dynamics, DynamicsRecovery +from ._em_model_core import DynamicsRecovery from ._steady_state_model import SteadyStateModel from .utils import make_unique_list @@ -20,20 +20,61 @@ class EMParams: """EM parameters.""" - alpha: np.ndarray - beta: np.ndarray - gamma: np.ndarray - t_: np.ndarray - scaling: np.ndarray - std_u: np.ndarray - std_s: np.ndarray - likelihood: np.ndarray - u0: np.ndarray - s0: np.ndarray - pval_steady: np.ndarray - steady_u: np.ndarray - steady_s: np.ndarray - variance: np.ndarray + r2: np.ndarray = field(metadata={"is_matrix": False}) + alpha: np.ndarray = field(metadata={"is_matrix": False}) + beta: np.ndarray = field(metadata={"is_matrix": False}) + gamma: np.ndarray = field(metadata={"is_matrix": False}) + t_: np.ndarray = field(metadata={"is_matrix": False}) + scaling: np.ndarray = field(metadata={"is_matrix": False}) + std_u: np.ndarray = field(metadata={"is_matrix": False}) + std_s: np.ndarray = field(metadata={"is_matrix": False}) + likelihood: np.ndarray = field(metadata={"is_matrix": False}) + u0: np.ndarray = field(metadata={"is_matrix": False}) + s0: np.ndarray = field(metadata={"is_matrix": False}) + pval_steady: np.ndarray = field(metadata={"is_matrix": False}) + steady_u: np.ndarray = field(metadata={"is_matrix": False}) + steady_s: np.ndarray = field(metadata={"is_matrix": False}) + variance: np.ndarray = field(metadata={"is_matrix": False}) + alignment_scaling: np.ndarray = field(metadata={"is_matrix": False}) + T: np.ndarray = field(metadata={"is_matrix": True}) + Tau: np.ndarray = field(metadata={"is_matrix": True}) + Tau_: np.ndarray = field(metadata={"is_matrix": True}) + + @classmethod + def from_adata(cls, adata: AnnData, key: str = "fit"): + parameter_dict = {} + for parameter in fields(cls): + para_name = parameter.name + if parameter.metadata["is_matrix"]: + if f"{key}_{para_name.lower()}" in adata.layers.keys(): + parameter_dict[para_name] = adata.layers[ + f"{key}_{para_name.lower()}" + ] + else: + _vals = np.empty(adata.shape) + _vals.fill(np.nan) + parameter_dict[para_name] = _vals + else: + if f"{key}_{para_name.lower()}" in adata.var.keys(): + parameter_dict[para_name] = adata.var[ + f"{key}_{para_name.lower()}" + ].values + else: + _vals = np.empty(adata.n_vars) + _vals.fill(np.nan) + parameter_dict[para_name] = _vals + return cls(**parameter_dict) + + # TODO: Atm, fields are also written if they contain only NaN values. Is this useful? + def export_to_adata(self, adata: AnnData, key: str = "fit"): + for parameter in fields(self): + para_name = parameter.name + value = getattr(self, para_name) + if parameter.metadata["is_matrix"]: + adata.layers[f"{key}_{para_name.lower()}"] = value + else: + adata.var[f"{key}_{para_name.lower()}"] = value + return adata # TODO: Implement abstract methods @@ -119,22 +160,10 @@ def __init__( self._n_jobs = get_n_jobs(n_jobs=n_jobs) self._backend = backend - def _initialize_state_dict( - self, adata: AnnData, parameters: Optional[List[str]] = None, key: str = "fit" - ): - if parameters is None: - parameters = [field.name for field in fields(EMParams)] - parameter_dict = {} - - for parameter in parameters: - if f"{key}_{parameter}" in adata.var.keys(): - parameter_dict[parameter] = adata.var[f"{key}_{parameter}"].values - else: - _vals = np.empty(adata.n_vars) - _vals.fill(np.nan) - parameter_dict[parameter] = _vals - - self._state_dict = EMParams(**parameter_dict) + if len(set(self._adata.var_names)) != len(self._adata.var_names): + logg.warn("Duplicate var_names found. Making them unique.") + self._adata.var_names_make_unique() + self._state_dict = EMParams.from_adata(adata) def _prepare_genes(self): """Initialize genes to use for the fitting.""" @@ -153,7 +182,7 @@ def _prepare_genes(self): ) velo.fit() var_names = adata.var_names[velo.state_dict()["velocity_genes"]] - self.r2 = velo.state_dict()["r2"] + self._state_dict.r2 = velo.state_dict()["r2"] else: raise ValueError("Variable name not found in var keys.") if not isinstance(var_names, str): @@ -175,9 +204,13 @@ def state_dict(self): def export_results_adata(self, copy: bool = True, add_key: str = "fit"): """Export the results to the AnnData object and return it.""" adata = self._adata.copy() if copy else self._adata - adata.var[f"{add_key}_r2"] = self.r2 - for key, value in asdict(self._state_dict).items(): - adata.var[f"{add_key}_{key}"] = value + self._state_dict.export_to_adata(adata, add_key) + adata.uns["recover_dynamics"] = { + "fit_connected_states": self._fit_connected_states, + "fit_basal_transcription": self._fit_basal_transcription, + "use_raw": self._use_raw, + } + adata.varm["loss"] = self._loss return adata # TODO: Remove `use_raw` argument @@ -197,37 +230,23 @@ def fit( f"recovering dynamics (using {self._n_jobs}/{os.cpu_count()} cores)", r=True ) - # TODO: Remove or move to `__init__` - if len(set(self._adata.var_names)) != len(self._adata.var_names): - logg.warn("Duplicate var_names found. Making them unique.") - self._adata.var_names_make_unique() - if ( "Ms" not in self._adata.layers.keys() or "Mu" not in self._adata.layers.keys() ): use_raw = True + # TODO: Refactor the definition of 'use_raw'; Move to init? + self._use_raw = use_raw if self._fit_connected_states is None: - fit_connected_states = not use_raw + self._fit_connected_states = not use_raw self._prepare_genes() if return_model is None: return_model = len(self._var_names) < 5 - self._initialize_state_dict(self._adata) sd = self._state_dict idx, L = [], [] - T = np.zeros(self._adata.shape) * np.nan - Tau = np.zeros(self._adata.shape) * np.nan - Tau_ = np.zeros(self._adata.shape) * np.nan - if f"{self._fit_key}_t" in self._adata.layers.keys(): - T = self._adata.layers[f"{self._fit_key}_t"] - if f"{self._fit_key}_tau" in self._adata.layers.keys(): - Tau = self._adata.layers[f"{self._fit_key}_tau"] - if f"{self._fit_key}_tau_" in self._adata.layers.keys(): - Tau_ = self._adata.layers[f"{self._fit_key}_tau_"] - - conn = get_connectivities(self._adata) if fit_connected_states else None + conn = get_connectivities(self._adata) if self._fit_connected_states else None res = parallelize( self._fit, @@ -236,7 +255,7 @@ def fit( unit="gene", as_array=False, backend=self._backend, - show_progress_bar=len(self._var_names) > 9, + show_progress_bar=False, # len(self._var_names) > 9, )( use_raw=use_raw, load_pars=load_pars, @@ -248,7 +267,7 @@ def fit( idx, dms = map(_flatten, zip(*res)) for ix, dm in zip(idx, dms): - T[:, ix], Tau[:, ix], Tau_[:, ix] = dm.t, dm.tau, dm.tau_ + sd.T[:, ix], sd.Tau[:, ix], sd.Tau_[:, ix] = dm.t, dm.tau, dm.tau_ ( sd.alpha[ix], sd.beta[ix], @@ -265,43 +284,34 @@ def fit( sd.likelihood[ix], sd.variance[ix] = dm.likelihood, dm.varx L.append(dm.loss) - adata = self.export_results_adata(copy=copy, add_key=self._fit_key) - - adata.uns["recover_dynamics"] = { - "fit_connected_states": fit_connected_states, - "fit_basal_transcription": self._fit_basal_transcription, - "use_raw": use_raw, - } + adata = self._adata - if f"{self._fit_key}_t" in adata.layers.keys(): - adata.layers[f"{self._fit_key}_t"][:, idx] = ( - T[:, idx] if conn is None else conn.dot(T[:, idx]) - ) - else: - adata.layers[f"{self._fit_key}_t"] = T if conn is None else conn.dot(T) - adata.layers[f"{self._fit_key}_tau"] = Tau - adata.layers[f"{self._fit_key}_tau_"] = Tau_ + if conn is not None: + if f"{self._fit_key}_t" in adata.layers.keys(): + sd.T[:, idx] = conn.dot(sd.T[:, idx]) + else: + sd.T = conn.dot(sd.T) # is False if only one invalid / irrecoverable gene was given in var_names if L: cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2 max_len = max(np.max([len(loss) for loss in L]), cur_len) if L else cur_len - loss = np.ones((adata.n_vars, max_len)) * np.nan + self._loss = np.empty((adata.n_vars, max_len)) + self._loss.fill(np.nan) if "loss" in adata.varm.keys(): - loss[:, :cur_len] = adata.varm["loss"] + self._loss[:, :cur_len] = adata.varm["loss"] - loss[idx] = np.vstack( + self._loss[idx] = np.vstack( [ np.concatenate([loss, np.ones(max_len - len(loss)) * np.nan]) for loss in L ] ) - adata.varm["loss"] = loss # TODO: Fix s.t. `self._t_max` is only integer if self._t_max is not False: - dm = align_dynamics(adata, t_max=self._t_max, dm=dm, idx=idx) + dm = self._align_dynamics(t_max=self._t_max, dm=dm, idx=idx) logg.info( " finished", time=True, end=" " if settings.verbosity > 2 else "\n" @@ -317,6 +327,109 @@ def fit( return dm if return_model else adata if copy else None + def _align_dynamics( + self, + t_max: Optional[Union[float, bool]] = None, + dm: Optional[DynamicsRecovery] = None, + idx: Optional[List[bool]] = None, + mode: Optional[str] = None, + remove_outliers: bool = False, + ): + """Align dynamics to a common set of parameters. + + Arguments + --------- + t_max: `float`, `False` or `None` (default: `None`) + Total range for time assignments. + dm: :class:`~DynamicsRecovery` + DynamicsRecovery object to perform alignment on. + idx: list of `bool` or `None` (default: `None`) + Mask for indices used for alignment. + mode: `str` or None (default: `'align_total_time`) + What to align. Takes the following arguments: + common_splicing_rate, common_scaling, align_increments, align_total_time + remove_outliers: `bool` or `None` (default: `None`) + Whether to remove outliers. + copy: `bool` (default: `False`) + Return a copy instead of writing to `adata`. + + Returns + ------- + `alpha`, `beta`, `gamma`, `t_`, `alignment_scaling`: `.var` + Aligned parameters + `fit_t`, `fit_tau`, `fit_tau_`: `.layer` + Aligned time + """ + sd = self._state_dict + if idx is None: + idx = ~np.isnan(np.sum(sd.T, axis=0)) + if np.all(np.isnan(sd.alignment_scaling)): + sd.alignment_scaling.fill(1) + if mode is None: + mode = "align_total_time" + + m = np.ones(self._adata.n_vars) + mz = sd.alignment_scaling + mz_prev = np.array(mz) + + if dm is not None: # newly fitted + mz[idx] = 1 + + if mode == "align_total_time" and t_max is not False: + # transient 'on' + T_max = np.max(sd.T[:, idx] * (sd.T[:, idx] < sd.t_[idx]), axis=0) + # 'off' + T_max += np.max( + (sd.T[:, idx] - sd.t_[idx]) * (sd.T[:, idx] > sd.t_[idx]), axis=0 + ) + + denom = 1 - np.sum( + (sd.T[:, idx] == sd.t_[idx]) | (sd.T[:, idx] == 0), axis=0 + ) / len(sd.T) + denom += denom == 0 + + T_max = T_max / denom + T_max += T_max == 0 + + if t_max is None: + t_max = 20 + m[idx] = t_max / T_max + mz *= m + else: + m = 1 / mz + mz = np.ones(self.adata.n_vars) + + if remove_outliers: + mu, std = np.nanmean(mz), np.nanstd(mz) + mz = np.clip(mz, mu - 3 * std, mu + 3 * std) + m = mz / mz_prev + + if idx is None: + sd.alpha, sd.beta, sd.gamma = sd.alpha / m, sd.beta / m, sd.gamma / m + sd.T, sd.t_, sd.Tau, sd.Tau_ = sd.T * m, sd.t_ * m, sd.Tau * m, sd.Tau_ * m + else: + m_ = m[idx] + sd.alpha[idx] = sd.alpha[idx] / m_ + sd.beta[idx] = sd.beta[idx] / m_ + sd.gamma[idx] = sd.gamma[idx] / m_ + sd.T[:, idx], sd.t_[idx] = sd.T[:, idx] * m_, sd.t_[idx] * m_ + sd.Tau[:, idx], sd.Tau_[:, idx] = sd.Tau[:, idx] * m_, sd.Tau_[:, idx] * m_ + + mz[mz == 1] = np.nan + + if dm is not None and dm.recoverable: + dm.m = m[idx] + dm.alpha = dm.alpha / dm.m[-1] + dm.beta = dm.beta / dm.m[-1] + dm.gamma = dm.gamma / dm.m[-1] + dm.pars[:3] = dm.pars[:3] / dm.m[-1] + + dm.t = dm.t * dm.m[-1] + dm.tau = dm.tau * dm.m[-1] + dm.t_ = dm.t_ * dm.m[-1] + dm.pars[4] = dm.pars[4] * dm.m[-1] + return dm + # TODO: Add docstrings def _fit( self, @@ -362,3 +475,7 @@ def _fit( queue.put(None) return idx, dms + + +def _flatten(iterable): + return [i for it in iterable for i in it] From b2c8258cedb28cb29be2b64719cc67c457a1c607 Mon Sep 17 00:00:00 2001 From: Johanna Schneeberger Date: Wed, 20 Sep 2023 19:51:40 +0200 Subject: [PATCH 7/9] Modify export to AnnData After the change, parameters are not written if they contain only nan values. --- scvelo/tools/_em_model.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index 9bc374e9..e2291dec 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -65,15 +65,16 @@ def from_adata(cls, adata: AnnData, key: str = "fit"): parameter_dict[para_name] = _vals return cls(**parameter_dict) - # TODO: Atm, fields are also written if they contain only NaN values. Is this useful? def export_to_adata(self, adata: AnnData, key: str = "fit"): for parameter in fields(self): para_name = parameter.name value = getattr(self, para_name) - if parameter.metadata["is_matrix"]: - adata.layers[f"{key}_{para_name.lower()}"] = value - else: - adata.var[f"{key}_{para_name.lower()}"] = value + # The parameter is only written if not all entries are nan. + if not np.all(np.isnan(value)): + if parameter.metadata["is_matrix"]: + adata.layers[f"{key}_{para_name.lower()}"] = value + else: + adata.var[f"{key}_{para_name.lower()}"] = value return adata @@ -164,6 +165,7 @@ def __init__( logg.warn("Duplicate var_names found. Making them unique.") self._adata.var_names_make_unique() self._state_dict = EMParams.from_adata(adata) + self._use_raw = False def _prepare_genes(self): """Initialize genes to use for the fitting.""" @@ -210,7 +212,9 @@ def export_results_adata(self, copy: bool = True, add_key: str = "fit"): "fit_basal_transcription": self._fit_basal_transcription, "use_raw": self._use_raw, } - adata.varm["loss"] = self._loss + # loss is only written after the execution of fit() + if hasattr(self, "_loss"): + adata.varm["loss"] = self._loss return adata # TODO: Remove `use_raw` argument From d120e501e825522e9c1b2ce764304e9b4f9ae3ff Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Fri, 1 Dec 2023 14:29:40 +0100 Subject: [PATCH 8/9] Update `_em_model.py` Add argument `show_progress_bar` to `fit` method. --- scvelo/tools/_em_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scvelo/tools/_em_model.py b/scvelo/tools/_em_model.py index e2291dec..0ef5fb39 100644 --- a/scvelo/tools/_em_model.py +++ b/scvelo/tools/_em_model.py @@ -227,6 +227,7 @@ def fit( load_pars: bool = False, steady_state_prior: Optional[List[bool]] = None, assignment_mode: str = "projection", + show_progress_bar: bool = True, **kwargs, ): """Fit the model.""" @@ -259,7 +260,7 @@ def fit( unit="gene", as_array=False, backend=self._backend, - show_progress_bar=False, # len(self._var_names) > 9, + show_progress_bar=show_progress_bar, )( use_raw=use_raw, load_pars=load_pars, From 60f94cc0135d57297b8d36e688829b566c3adcdc Mon Sep 17 00:00:00 2001 From: Philipp Weiler Date: Fri, 1 Dec 2023 15:07:03 +0100 Subject: [PATCH 9/9] Update `tests/test_basic.py` --- tests/test_basic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_basic.py b/tests/test_basic.py index 234a83b4..169aac68 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -32,6 +32,7 @@ def test_dynamical_model(): adata=adata, var_names_key=adata.var_names[0] ) em_model.fit(return_model=False, copy=False) + adata = em_model.export_results_adata(adata) assert np.round(adata[:, adata.var_names[0]].var["fit_alpha"][0], 4) == 4.7409 @@ -45,6 +46,7 @@ def test_pipeline(): em_model = ExpectationMaximizationModel(adata=adata) em_model.fit(copy=False) + adata = em_model.export_results_adata(adata) scv.tl.velocity(adata) scv.tl.velocity(adata, vkey="dynamical_velocity", mode="dynamical") adata.var.velocity_genes = True