Refine ExpectationMaximizationModel
* Add metadata to the attributes of `EMParams`
* Add methods to read/write `EMParams` from/to AnnData (round-trip sketch below)
* Adapt the method `fit()` such that it uses the class `EMParams`
* Include `_align_dynamics()` as a method in
  `ExpectationMaximizationModel` and adapt it to also use `EMParams`
* Include `_flatten()` from `_em_model_core.py`
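
A minimal round-trip sketch of the new `EMParams` ↔ `AnnData` API, assuming `EMParams` is importable from `scvelo.tools._em_model`; the pancreas dataset is only an example input, since parameters missing from any `AnnData` are filled with NaNs:

```python
import scvelo as scv

from scvelo.tools._em_model import EMParams

adata = scv.datasets.pancreas()  # example AnnData; no fit_* fields yet

# Per-gene vectors are read from `.var`, cells-x-genes matrices from
# `.layers`; parameters missing from `adata` are filled with NaNs.
params = EMParams.from_adata(adata, key="fit")

# Writes back as e.g. `fit_alpha` in `.var` and `fit_t` in `.layers`.
adata = params.export_to_adata(adata, key="fit")
```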
johschnee committed Aug 22, 2023
1 parent b4c27d3 commit 15808e0
Showing 1 changed file with 194 additions and 77 deletions.
271 changes: 194 additions & 77 deletions scvelo/tools/_em_model.py
@@ -1,6 +1,6 @@
import os
from dataclasses import asdict, dataclass, fields
from typing import List, Optional, Sequence
from dataclasses import asdict, dataclass, field, fields
from typing import List, Optional, Sequence, Union

import numpy as np

@@ -11,7 +11,7 @@
from scvelo.core import get_n_jobs, parallelize
from scvelo.preprocessing.moments import get_connectivities
from ._core import BaseInference
from ._em_model_core import _flatten, align_dynamics, DynamicsRecovery
from ._em_model_core import DynamicsRecovery
from ._steady_state_model import SteadyStateModel
from .utils import make_unique_list

@@ -20,20 +20,61 @@
class EMParams:
"""EM parameters."""

alpha: np.ndarray
beta: np.ndarray
gamma: np.ndarray
t_: np.ndarray
scaling: np.ndarray
std_u: np.ndarray
std_s: np.ndarray
likelihood: np.ndarray
u0: np.ndarray
s0: np.ndarray
pval_steady: np.ndarray
steady_u: np.ndarray
steady_s: np.ndarray
variance: np.ndarray
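# The `is_matrix` metadata decides where a field round-trips in AnnData:
# matrix-valued fields (cells x genes) map to `.layers[f"{key}_{name}"]`,
# per-gene vectors map to `.var[f"{key}_{name}"]`, with names lowercased.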
r2: np.ndarray = field(metadata={"is_matrix": False})
alpha: np.ndarray = field(metadata={"is_matrix": False})
beta: np.ndarray = field(metadata={"is_matrix": False})
gamma: np.ndarray = field(metadata={"is_matrix": False})
t_: np.ndarray = field(metadata={"is_matrix": False})
scaling: np.ndarray = field(metadata={"is_matrix": False})
std_u: np.ndarray = field(metadata={"is_matrix": False})
std_s: np.ndarray = field(metadata={"is_matrix": False})
likelihood: np.ndarray = field(metadata={"is_matrix": False})
u0: np.ndarray = field(metadata={"is_matrix": False})
s0: np.ndarray = field(metadata={"is_matrix": False})
pval_steady: np.ndarray = field(metadata={"is_matrix": False})
steady_u: np.ndarray = field(metadata={"is_matrix": False})
steady_s: np.ndarray = field(metadata={"is_matrix": False})
variance: np.ndarray = field(metadata={"is_matrix": False})
alignment_scaling: np.ndarray = field(metadata={"is_matrix": False})
T: np.ndarray = field(metadata={"is_matrix": True})
Tau: np.ndarray = field(metadata={"is_matrix": True})
Tau_: np.ndarray = field(metadata={"is_matrix": True})

@classmethod
def from_adata(cls, adata: AnnData, key: str = "fit"):
parameter_dict = {}
for parameter in fields(cls):
para_name = parameter.name
if parameter.metadata["is_matrix"]:
if f"{key}_{para_name.lower()}" in adata.layers.keys():
parameter_dict[para_name] = adata.layers[
f"{key}_{para_name.lower()}"
]
else:
_vals = np.empty(adata.shape)
_vals.fill(np.nan)
parameter_dict[para_name] = _vals
else:
if f"{key}_{para_name.lower()}" in adata.var.keys():
parameter_dict[para_name] = adata.var[
f"{key}_{para_name.lower()}"
].values
else:
_vals = np.empty(adata.n_vars)
_vals.fill(np.nan)
parameter_dict[para_name] = _vals
return cls(**parameter_dict)

# TODO: At the moment, fields are also written if they contain only NaN values. Is this useful?
def export_to_adata(self, adata: AnnData, key: str = "fit"):
for parameter in fields(self):
para_name = parameter.name
value = getattr(self, para_name)
if parameter.metadata["is_matrix"]:
adata.layers[f"{key}_{para_name.lower()}"] = value
else:
adata.var[f"{key}_{para_name.lower()}"] = value
return adata


# TODO: Implement abstract methods
@@ -119,22 +160,10 @@ def __init__(
self._n_jobs = get_n_jobs(n_jobs=n_jobs)
self._backend = backend

def _initialize_state_dict(
self, adata: AnnData, parameters: Optional[List[str]] = None, key: str = "fit"
):
if parameters is None:
parameters = [field.name for field in fields(EMParams)]
parameter_dict = {}

for parameter in parameters:
if f"{key}_{parameter}" in adata.var.keys():
parameter_dict[parameter] = adata.var[f"{key}_{parameter}"].values
else:
_vals = np.empty(adata.n_vars)
_vals.fill(np.nan)
parameter_dict[parameter] = _vals

self._state_dict = EMParams(**parameter_dict)
if len(set(self._adata.var_names)) != len(self._adata.var_names):
logg.warn("Duplicate var_names found. Making them unique.")
self._adata.var_names_make_unique()
self._state_dict = EMParams.from_adata(adata)

def _prepare_genes(self):
"""Initialize genes to use for the fitting."""
@@ -153,7 +182,7 @@ def _prepare_genes(self):
)
velo.fit()
var_names = adata.var_names[velo.state_dict()["velocity_genes"]]
self.r2 = velo.state_dict()["r2"]
self._state_dict.r2 = velo.state_dict()["r2"]
else:
raise ValueError("Variable name not found in var keys.")
if not isinstance(var_names, str):
@@ -175,9 +204,13 @@ def state_dict(self):
def export_results_adata(self, copy: bool = True, add_key: str = "fit"):
"""Export the results to the AnnData object and return it."""
adata = self._adata.copy() if copy else self._adata
adata.var[f"{add_key}_r2"] = self.r2
for key, value in asdict(self._state_dict).items():
adata.var[f"{add_key}_{key}"] = value
self._state_dict.export_to_adata(adata, add_key)
adata.uns["recover_dynamics"] = {
"fit_connected_states": self._fit_connected_states,
"fit_basal_transcription": self._fit_basal_transcription,
"use_raw": self._use_raw,
}
adata.varm["loss"] = self._loss
return adata

# TODO: Remove `use_raw` argument
@@ -197,37 +230,23 @@ def fit(
f"recovering dynamics (using {self._n_jobs}/{os.cpu_count()} cores)", r=True
)

# TODO: Remove or move to `__init__`
if len(set(self._adata.var_names)) != len(self._adata.var_names):
logg.warn("Duplicate var_names found. Making them unique.")
self._adata.var_names_make_unique()

if (
"Ms" not in self._adata.layers.keys()
or "Mu" not in self._adata.layers.keys()
):
use_raw = True
# TODO: Refactor the definition of 'use_raw'; Move to init?
self._use_raw = use_raw
if self._fit_connected_states is None:
fit_connected_states = not use_raw
self._fit_connected_states = not use_raw

self._prepare_genes()
if return_model is None:
return_model = len(self._var_names) < 5

self._initialize_state_dict(self._adata)
sd = self._state_dict
idx, L = [], []
T = np.zeros(self._adata.shape) * np.nan
Tau = np.zeros(self._adata.shape) * np.nan
Tau_ = np.zeros(self._adata.shape) * np.nan
if f"{self._fit_key}_t" in self._adata.layers.keys():
T = self._adata.layers[f"{self._fit_key}_t"]
if f"{self._fit_key}_tau" in self._adata.layers.keys():
Tau = self._adata.layers[f"{self._fit_key}_tau"]
if f"{self._fit_key}_tau_" in self._adata.layers.keys():
Tau_ = self._adata.layers[f"{self._fit_key}_tau_"]

conn = get_connectivities(self._adata) if fit_connected_states else None
conn = get_connectivities(self._adata) if self._fit_connected_states else None

res = parallelize(
self._fit,
@@ -236,7 +255,7 @@
unit="gene",
as_array=False,
backend=self._backend,
show_progress_bar=len(self._var_names) > 9,
show_progress_bar=False, # len(self._var_names) > 9,
)(
use_raw=use_raw,
load_pars=load_pars,
@@ -248,7 +267,7 @@
idx, dms = map(_flatten, zip(*res))

for ix, dm in zip(idx, dms):
T[:, ix], Tau[:, ix], Tau_[:, ix] = dm.t, dm.tau, dm.tau_
sd.T[:, ix], sd.Tau[:, ix], sd.Tau_[:, ix] = dm.t, dm.tau, dm.tau_
(
sd.alpha[ix],
sd.beta[ix],
@@ -265,43 +284,34 @@
sd.likelihood[ix], sd.variance[ix] = dm.likelihood, dm.varx
L.append(dm.loss)

adata = self.export_results_adata(copy=copy, add_key=self._fit_key)

adata.uns["recover_dynamics"] = {
"fit_connected_states": fit_connected_states,
"fit_basal_transcription": self._fit_basal_transcription,
"use_raw": use_raw,
}
adata = self._adata

if f"{self._fit_key}_t" in adata.layers.keys():
adata.layers[f"{self._fit_key}_t"][:, idx] = (
T[:, idx] if conn is None else conn.dot(T[:, idx])
)
else:
adata.layers[f"{self._fit_key}_t"] = T if conn is None else conn.dot(T)
adata.layers[f"{self._fit_key}_tau"] = Tau
adata.layers[f"{self._fit_key}_tau_"] = Tau_
if conn is not None:
if f"{self._fit_key}_t" in adata.layers.keys():
sd.T[:, idx] = conn.dot(sd.T[:, idx])
else:
sd.T = conn.dot(sd.T)

# `L` is empty (falsy) if only invalid / irrecoverable genes were given in var_names
if L:
cur_len = adata.varm["loss"].shape[1] if "loss" in adata.varm.keys() else 2
max_len = max(np.max([len(loss) for loss in L]), cur_len) if L else cur_len
loss = np.ones((adata.n_vars, max_len)) * np.nan
self._loss = np.empty((adata.n_vars, max_len))
self._loss.fill(np.nan)

if "loss" in adata.varm.keys():
loss[:, :cur_len] = adata.varm["loss"]
self._loss[:, :cur_len] = adata.varm["loss"]

loss[idx] = np.vstack(
self._loss[idx] = np.vstack(
[
np.concatenate([loss, np.ones(max_len - len(loss)) * np.nan])
for loss in L
]
)
adata.varm["loss"] = loss

# TODO: Fix s.t. `self._t_max` is only integer
if self._t_max is not False:
dm = align_dynamics(adata, t_max=self._t_max, dm=dm, idx=idx)
dm = self._align_dynamics(t_max=self._t_max, dm=dm, idx=idx)

logg.info(
" finished", time=True, end=" " if settings.verbosity > 2 else "\n"
@@ -317,6 +327,109 @@

return dm if return_model else adata if copy else None

def _align_dynamics(
self,
t_max: Optional[Union[float, bool]] = None,
dm: Optional[DynamicsRecovery] = None,
idx: Optional[List[bool]] = None,
mode: Optional[str] = None,
remove_outliers: bool = False,
):
"""Align dynamics to a common set of parameters.
Arguments
---------
t_max: `float`, `False` or `None` (default: `None`)
Total range for time assignments.
dm: :class:`~DynamicsRecovery`
DynamicsRecovery object to perform alignment on.
idx: list of `bool` or `None` (default: `None`)
Mask for indices used for alignment.
mode: `str` or None (default: `'align_total_time`)
What to align. Takes the following arguments:
common_splicing_rate, common_scaling, align_increments, align_total_time
remove_outliers: `bool` or `None` (default: `None`)
Whether to remove outliers.
copy: `bool` (default: `False`)
Return a copy instead of writing to `adata`.
Returns
-------
`alpha`, `beta`, `gamma`, `t_`, `alignment_scaling`: `.var`
Aligned parameters
`fit_t`, `fit_tau`, `fit_tau_`: `.layer`
Aligned time
"""
sd = self._state_dict
if idx is None:
idx = ~np.isnan(np.sum(sd.T, axis=0))
if np.all(np.isnan(sd.alignment_scaling)):
sd.alignment_scaling.fill(1)
if mode is None:
mode = "align_total_time"

m = np.ones(self._adata.n_vars)
mz = sd.alignment_scaling
mz_prev = np.array(mz)
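# `m` holds the rescaling factor of this call; `mz` aliases
# `sd.alignment_scaling` and accumulates the total scaling across calls.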

if dm is not None: # newly fitted
mz[idx] = 1

if mode == "align_total_time" and t_max is not False:
# transient 'on'
T_max = np.max(sd.T[:, idx] * (sd.T[:, idx] < sd.t_[idx]), axis=0)
# 'off'
T_max += np.max(
(sd.T[:, idx] - sd.t_[idx]) * (sd.T[:, idx] > sd.t_[idx]), axis=0
)
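# Normalize T_max by the fraction of informative cells: cells exactly at
# the switching time or at t=0 carry no range information. The
# `+= (x == 0)` updates below guard against division by zero.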

denom = 1 - np.sum(
(sd.T[:, idx] == sd.t_[idx]) | (sd.T[:, idx] == 0), axis=0
) / len(sd.T)
denom += denom == 0

T_max = T_max / denom
T_max += T_max == 0

if t_max is None:
t_max = 20
m[idx] = t_max / T_max
mz *= m
else:
m = 1 / mz
mz = np.ones(self._adata.n_vars)

if remove_outliers:
mu, std = np.nanmean(mz), np.nanstd(mz)
mz = np.clip(mz, mu - 3 * std, mu + 3 * std)
m = mz / mz_prev

if idx is None:
sd.alpha, sd.beta, sd.gamma = sd.alpha / m, sd.beta / m, sd.gamma / m
sd.T, sd.t_, sd.Tau, sd.Tau_ = sd.T * m, sd.t_ * m, sd.Tau * m, sd.Tau_ * m
else:
m_ = m[idx]
sd.alpha[idx] = sd.alpha[idx] / m_
sd.beta[idx] = sd.beta[idx] / m_
sd.gamma[idx] = sd.gamma[idx] / m_
sd.T[:, idx], sd.t_[idx] = sd.T[:, idx] * m_, sd.t_[idx] * m_
sd.Tau[:, idx], sd.Tau_[:, idx] = sd.Tau[:, idx] * m_, sd.Tau_[:, idx] * m_

mz[mz == 1] = np.nan

if dm is not None and dm.recoverable:
dm.m = m[idx]
dm.alpha = dm.alpha / dm.m[-1]
dm.beta = dm.beta / dm.m[-1]
dm.gamma = dm.gamma / dm.m[-1]
dm.pars[:3] = dm.pars[:3] / dm.m[-1]

dm.t = dm.t * dm.m[-1]
dm.tau = dm.tau * dm.m[-1]
dm.t_ = dm.t_ * dm.m[-1]
dm.pars[4] = dm.pars[4] * dm.m[-1]
return dm

# TODO: Add docstrings
def _fit(
self,
@@ -362,3 +475,7 @@ def _fit(
queue.put(None)

return idx, dms


def _flatten(iterable):
    """Flatten one level of nesting."""
    return [i for it in iterable for i in it]
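
For orientation, a hedged end-to-end sketch of driving the refactored model. `fit()`, `state_dict()`, and the `EMParams` fields are taken from this diff; the dataset, the preprocessing calls, and the bare constructor call are assumptions:

```python
import scvelo as scv

from scvelo.tools._em_model import ExpectationMaximizationModel

adata = scv.datasets.pancreas()  # example dataset
scv.pp.filter_and_normalize(adata)
scv.pp.moments(adata)  # provides Ms/Mu so fit() does not fall back to raw counts

model = ExpectationMaximizationModel(adata)  # assumed minimal constructor call
model.fit()  # recovers per-gene dynamics into the EMParams state

params = model.state_dict()  # the EMParams instance, per this diff
print(params.alpha[:5], params.T.shape)  # per-gene rates; cells-x-genes time matrix
```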
